]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Improve `Rewriter` implementation
authorl-hedgehog <l-hedgehog@outlook.com>
Mon, 11 Dec 2023 20:24:56 +0000 (15:24 -0500)
committersqla-tester <sqla-tester@sqlalchemy.org>
Mon, 11 Dec 2023 20:24:56 +0000 (15:24 -0500)
Fixes #1337

### Description
* Fix the chaining of more than two rewriters
* Wrap a callable so that it could be chained

This works in my local test, and I hope it makes sense to use the callable wrapper as the base class.

### Checklist
This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

Closes: #1368
Pull-request: https://github.com/sqlalchemy/alembic/pull/1368
Pull-request-sha: e62633dbbb9b8dd91a145a6f27efcdbe4e4c0b3b

Change-Id: I7642a3ec8c6b8923f70ae8e7b6dbd482cb15eff9

alembic/autogenerate/rewriter.py
docs/build/unreleased/1337.rst [new file with mode: 0644]
tests/test_script_production.py

index 68a93dd0ab8b35ab24e4bfee23056ad36bad6a70..3efb499b366cf111f2d0cf4ded5a7c8726214744 100644 (file)
@@ -4,7 +4,7 @@ from typing import Any
 from typing import Callable
 from typing import Iterator
 from typing import List
-from typing import Optional
+from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import Union
@@ -23,6 +23,10 @@ if TYPE_CHECKING:
     from ..runtime.environment import _GetRevArg
     from ..runtime.migration import MigrationContext
 
+ProcessRevisionDirectiveFn = Callable[
+    ["MigrationContext", "_GetRevArg", List["MigrationScript"]], None
+]
+
 
 class Rewriter:
     """A helper object that allows easy 'rewriting' of ops streams.
@@ -52,15 +56,21 @@ class Rewriter:
 
     _traverse = util.Dispatcher()
 
-    _chained: Optional[Rewriter] = None
+    _chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = ()
 
     def __init__(self) -> None:
         self.dispatch = util.Dispatcher()
 
-    def chain(self, other: Rewriter) -> Rewriter:
+    def chain(
+        self,
+        other: Union[
+            ProcessRevisionDirectiveFn,
+            Rewriter,
+        ],
+    ) -> Rewriter:
         """Produce a "chain" of this :class:`.Rewriter` to another.
 
-        This allows two rewriters to operate serially on a stream,
+        This allows two or more rewriters to operate serially on a stream,
         e.g.::
 
             writer1 = autogenerate.Rewriter()
@@ -89,7 +99,7 @@ class Rewriter:
         """
         wr = self.__class__.__new__(self.__class__)
         wr.__dict__.update(self.__dict__)
-        wr._chained = other
+        wr._chained += (other,)
         return wr
 
     def rewrites(
@@ -146,8 +156,8 @@ class Rewriter:
         directives: List[MigrationScript],
     ) -> None:
         self.process_revision_directives(context, revision, directives)
-        if self._chained:
-            self._chained(context, revision, directives)
+        for process_revision_directives in self._chained:
+            process_revision_directives(context, revision, directives)
 
     @_traverse.dispatch_for(ops.MigrationScript)
     def _traverse_script(
diff --git a/docs/build/unreleased/1337.rst b/docs/build/unreleased/1337.rst
new file mode 100644 (file)
index 0000000..2660e83
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, autogenerate
+    :tickets: 1337
+
+    Fixes `autogenerate.Rewriter` so that more than two instances could be
+    chained together correctly, and `process_revision_directives` callable
+    could also be chained.
index 3b5a6f60461eb7905c28b52c58bff7f31b36fe2a..7b7db814bd5d4c6fb56046085ec7857a6ce3a20c 100644 (file)
@@ -933,6 +933,11 @@ class RewriterTest(TestBase):
             idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
             return [op, idx_op]
 
+        def process_revision_directives(context, revision, generate_revisions):
+            generate_revisions[0].downgrade_ops = ops.DowngradeOps(
+                ops=[ops.DropColumnOp("t1", "x")]
+            )
+
         directives = [
             ops.MigrationScript(
                 util.rev_id(),
@@ -956,7 +961,8 @@ class RewriterTest(TestBase):
         ]
 
         ctx, rev = mock.Mock(), mock.Mock()
-        writer1.chain(writer2)(ctx, rev, directives)
+        writer = writer1.chain(process_revision_directives).chain(writer2)
+        writer(ctx, rev, directives)
 
         eq_(
             autogenerate.render_python_code(directives[0].upgrade_ops),
@@ -970,6 +976,13 @@ class RewriterTest(TestBase):
             "    # ### end Alembic commands ###",
         )
 
+        eq_(
+            autogenerate.render_python_code(directives[0].downgrade_ops),
+            "# ### commands auto generated by Alembic - please adjust! ###\n"
+            "    op.drop_column('t1', 'x')\n"
+            "    # ### end Alembic commands ###",
+        )
+
     def test_no_needless_pass(self):
         writer1 = autogenerate.Rewriter()