from typing import Callable
from typing import Iterator
from typing import List
-from typing import Optional
+from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from ..runtime.environment import _GetRevArg
from ..runtime.migration import MigrationContext
+ProcessRevisionDirectiveFn = Callable[
+ ["MigrationContext", "_GetRevArg", List["MigrationScript"]], None
+]
+
class Rewriter:
"""A helper object that allows easy 'rewriting' of ops streams.
_traverse = util.Dispatcher()
- _chained: Optional[Rewriter] = None
+ _chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = ()
def __init__(self) -> None:
self.dispatch = util.Dispatcher()
- def chain(self, other: Rewriter) -> Rewriter:
+ def chain(
+ self,
+ other: Union[
+ ProcessRevisionDirectiveFn,
+ Rewriter,
+ ],
+ ) -> Rewriter:
"""Produce a "chain" of this :class:`.Rewriter` to another.
- This allows two rewriters to operate serially on a stream,
+ This allows two or more rewriters to operate serially on a stream,
e.g.::
writer1 = autogenerate.Rewriter()
"""
wr = self.__class__.__new__(self.__class__)
wr.__dict__.update(self.__dict__)
- wr._chained = other
+ wr._chained += (other,)
return wr
def rewrites(
directives: List[MigrationScript],
) -> None:
self.process_revision_directives(context, revision, directives)
- if self._chained:
- self._chained(context, revision, directives)
+ for process_revision_directives in self._chained:
+ process_revision_directives(context, revision, directives)
@_traverse.dispatch_for(ops.MigrationScript)
def _traverse_script(
idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
return [op, idx_op]
+ def process_revision_directives(context, revision, generate_revisions):
+ generate_revisions[0].downgrade_ops = ops.DowngradeOps(
+ ops=[ops.DropColumnOp("t1", "x")]
+ )
+
directives = [
ops.MigrationScript(
util.rev_id(),
]
ctx, rev = mock.Mock(), mock.Mock()
- writer1.chain(writer2)(ctx, rev, directives)
+ writer = writer1.chain(process_revision_directives).chain(writer2)
+ writer(ctx, rev, directives)
eq_(
autogenerate.render_python_code(directives[0].upgrade_ops),
" # ### end Alembic commands ###",
)
+ eq_(
+ autogenerate.render_python_code(directives[0].downgrade_ops),
+ "# ### commands auto generated by Alembic - please adjust! ###\n"
+ " op.drop_column('t1', 'x')\n"
+ " # ### end Alembic commands ###",
+ )
+
def test_no_needless_pass(self):
writer1 = autogenerate.Rewriter()