From 48b225067e1a7536fa63a6811d507b2079ef0a3f Mon Sep 17 00:00:00 2001 From: l-hedgehog Date: Mon, 11 Dec 2023 15:24:56 -0500 Subject: [PATCH] Improve `Rewriter` implementation Fixes #1337 ### Description * Fix the chaining of more than two rewriters * Wrap a callable so that it could be chained This works in my local test, and I hope it makes sense to use the callable wrapper as the base class. ### Checklist This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [x] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. Closes: #1368 Pull-request: https://github.com/sqlalchemy/alembic/pull/1368 Pull-request-sha: e62633dbbb9b8dd91a145a6f27efcdbe4e4c0b3b Change-Id: I7642a3ec8c6b8923f70ae8e7b6dbd482cb15eff9 --- alembic/autogenerate/rewriter.py | 24 +++++++++++++++++------- docs/build/unreleased/1337.rst | 7 +++++++ tests/test_script_production.py | 15 ++++++++++++++- 3 files changed, 38 insertions(+), 8 deletions(-) create mode 100644 docs/build/unreleased/1337.rst diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py index 68a93dd0..3efb499b 100644 --- a/alembic/autogenerate/rewriter.py +++ b/alembic/autogenerate/rewriter.py @@ -4,7 +4,7 @@ from typing import Any from typing import Callable from typing import Iterator from typing import List -from typing import Optional +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import Union @@ -23,6 +23,10 @@ if TYPE_CHECKING: from ..runtime.environment import _GetRevArg from ..runtime.migration import MigrationContext +ProcessRevisionDirectiveFn = Callable[ + ["MigrationContext", "_GetRevArg", List["MigrationScript"]], None +] + class Rewriter: """A helper object that allows easy 'rewriting' of ops streams. @@ -52,15 +56,21 @@ class Rewriter: _traverse = util.Dispatcher() - _chained: Optional[Rewriter] = None + _chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = () def __init__(self) -> None: self.dispatch = util.Dispatcher() - def chain(self, other: Rewriter) -> Rewriter: + def chain( + self, + other: Union[ + ProcessRevisionDirectiveFn, + Rewriter, + ], + ) -> Rewriter: """Produce a "chain" of this :class:`.Rewriter` to another. - This allows two rewriters to operate serially on a stream, + This allows two or more rewriters to operate serially on a stream, e.g.:: writer1 = autogenerate.Rewriter() @@ -89,7 +99,7 @@ class Rewriter: """ wr = self.__class__.__new__(self.__class__) wr.__dict__.update(self.__dict__) - wr._chained = other + wr._chained += (other,) return wr def rewrites( @@ -146,8 +156,8 @@ class Rewriter: directives: List[MigrationScript], ) -> None: self.process_revision_directives(context, revision, directives) - if self._chained: - self._chained(context, revision, directives) + for process_revision_directives in self._chained: + process_revision_directives(context, revision, directives) @_traverse.dispatch_for(ops.MigrationScript) def _traverse_script( diff --git a/docs/build/unreleased/1337.rst b/docs/build/unreleased/1337.rst new file mode 100644 index 00000000..2660e831 --- /dev/null +++ b/docs/build/unreleased/1337.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, autogenerate + :tickets: 1337 + + Fixes `autogenerate.Rewriter` so that more than two instances could be + chained together correctly, and `process_revision_directives` callable + could also be chained. diff --git a/tests/test_script_production.py b/tests/test_script_production.py index 3b5a6f60..7b7db814 100644 --- a/tests/test_script_production.py +++ b/tests/test_script_production.py @@ -933,6 +933,11 @@ class RewriterTest(TestBase): idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name]) return [op, idx_op] + def process_revision_directives(context, revision, generate_revisions): + generate_revisions[0].downgrade_ops = ops.DowngradeOps( + ops=[ops.DropColumnOp("t1", "x")] + ) + directives = [ ops.MigrationScript( util.rev_id(), @@ -956,7 +961,8 @@ class RewriterTest(TestBase): ] ctx, rev = mock.Mock(), mock.Mock() - writer1.chain(writer2)(ctx, rev, directives) + writer = writer1.chain(process_revision_directives).chain(writer2) + writer(ctx, rev, directives) eq_( autogenerate.render_python_code(directives[0].upgrade_ops), @@ -970,6 +976,13 @@ class RewriterTest(TestBase): " # ### end Alembic commands ###", ) + eq_( + autogenerate.render_python_code(directives[0].downgrade_ops), + "# ### commands auto generated by Alembic - please adjust! ###\n" + " op.drop_column('t1', 'x')\n" + " # ### end Alembic commands ###", + ) + def test_no_needless_pass(self): writer1 = autogenerate.Rewriter() -- 2.47.2