]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
render .info in create_table
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Oct 2023 14:59:39 +0000 (10:59 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Oct 2023 15:37:37 +0000 (11:37 -0400)
Fixed regression caused by :ticket:`879` released in 1.7.0 where the
".info" dictionary of ``Table`` would not render in autogenerate create
table statements.  This can be useful for custom create table DDL rendering
schemes so it is restored.

Additionally upon seeing that Rewriter is failing typing that was
just imporved in the previous commit for #1325 /
Ibfb7a57a081818c290cf0964d12a72b85c2c1983, further correct the typing
of the "revision" argument for process_revision_directives which was
still inconsistent.

Change-Id: Ifa4c7bd1b730d51629f42bc159b994f42d157c04
Fixes: #1329
alembic/autogenerate/api.py
alembic/autogenerate/render.py
alembic/autogenerate/rewriter.py
alembic/context.pyi
alembic/runtime/environment.py
alembic/script/revision.py
docs/build/unreleased/1329.rst [new file with mode: 0644]
tests/test_autogen_render.py

index 13d025b28ee300f41a2d39e971d929fc1ac9cff2..7282487be240d54c1e40ff04ede5e8b4e70ae91b 100644 (file)
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import contextlib
 from typing import Any
-from typing import Callable
 from typing import Dict
 from typing import Iterator
 from typing import List
@@ -35,6 +34,7 @@ if TYPE_CHECKING:
     from ..operations.ops import UpgradeOps
     from ..runtime.environment import NameFilterParentNames
     from ..runtime.environment import NameFilterType
+    from ..runtime.environment import ProcessRevisionDirectiveFn
     from ..runtime.environment import RenderItemFn
     from ..runtime.migration import MigrationContext
     from ..script.base import Script
@@ -510,13 +510,16 @@ class RevisionContext:
     file generation operation."""
 
     generated_revisions: List[MigrationScript]
+    process_revision_directives: Optional[ProcessRevisionDirectiveFn]
 
     def __init__(
         self,
         config: Config,
         script_directory: ScriptDirectory,
         command_args: Dict[str, Any],
-        process_revision_directives: Optional[Callable] = None,
+        process_revision_directives: Optional[
+            ProcessRevisionDirectiveFn
+        ] = None,
     ) -> None:
         self.config = config
         self.script_directory = script_directory
index 1f4bcf898b6dd8b6ed9a52299a6766f12ca1e705..9c84cd6c51745ed9ab68412e4a2d9f8bb3f08612 100644 (file)
@@ -245,6 +245,11 @@ def _add_table(autogen_context: AutogenContext, op: ops.CreateTableOp) -> str:
     comment = table.comment
     if comment:
         text += ",\ncomment=%r" % _ident(comment)
+
+    info = table.info
+    if info:
+        text += f",\ninfo={info!r}"
+
     for k in sorted(op.kw):
         text += ",\n%s=%r" % (k.replace(" ", "_"), op.kw[k])
 
index 4209c32149e4e3cbc790eeef353679f8f0c0acaa..68a93dd0ab8b35ab24e4bfee23056ad36bad6a70 100644 (file)
@@ -9,19 +9,19 @@ from typing import Type
 from typing import TYPE_CHECKING
 from typing import Union
 
-from alembic import util
-from alembic.operations import ops
+from .. import util
+from ..operations import ops
 
 if TYPE_CHECKING:
-    from alembic.operations.ops import AddColumnOp
-    from alembic.operations.ops import AlterColumnOp
-    from alembic.operations.ops import CreateTableOp
-    from alembic.operations.ops import MigrateOperation
-    from alembic.operations.ops import MigrationScript
-    from alembic.operations.ops import ModifyTableOps
-    from alembic.operations.ops import OpContainer
-    from alembic.runtime.migration import MigrationContext
-    from alembic.script.revision import Revision
+    from ..operations.ops import AddColumnOp
+    from ..operations.ops import AlterColumnOp
+    from ..operations.ops import CreateTableOp
+    from ..operations.ops import MigrateOperation
+    from ..operations.ops import MigrationScript
+    from ..operations.ops import ModifyTableOps
+    from ..operations.ops import OpContainer
+    from ..runtime.environment import _GetRevArg
+    from ..runtime.migration import MigrationContext
 
 
 class Rewriter:
@@ -119,7 +119,7 @@ class Rewriter:
     def _rewrite(
         self,
         context: MigrationContext,
-        revision: Revision,
+        revision: _GetRevArg,
         directive: MigrateOperation,
     ) -> Iterator[MigrateOperation]:
         try:
@@ -142,7 +142,7 @@ class Rewriter:
     def __call__(
         self,
         context: MigrationContext,
-        revision: Revision,
+        revision: _GetRevArg,
         directives: List[MigrationScript],
     ) -> None:
         self.process_revision_directives(context, revision, directives)
@@ -153,7 +153,7 @@ class Rewriter:
     def _traverse_script(
         self,
         context: MigrationContext,
-        revision: Revision,
+        revision: _GetRevArg,
         directive: MigrationScript,
     ) -> None:
         upgrade_ops_list = []
@@ -180,7 +180,7 @@ class Rewriter:
     def _traverse_op_container(
         self,
         context: MigrationContext,
-        revision: Revision,
+        revision: _GetRevArg,
         directive: OpContainer,
     ) -> None:
         self._traverse_list(context, revision, directive.ops)
@@ -189,7 +189,7 @@ class Rewriter:
     def _traverse_any_directive(
         self,
         context: MigrationContext,
-        revision: Revision,
+        revision: _GetRevArg,
         directive: MigrateOperation,
     ) -> None:
         pass
@@ -197,7 +197,7 @@ class Rewriter:
     def _traverse_for(
         self,
         context: MigrationContext,
-        revision: Revision,
+        revision: _GetRevArg,
         directive: MigrateOperation,
     ) -> Any:
         directives = list(self._rewrite(context, revision, directive))
@@ -209,7 +209,7 @@ class Rewriter:
     def _traverse_list(
         self,
         context: MigrationContext,
-        revision: Revision,
+        revision: _GetRevArg,
         directives: Any,
     ) -> None:
         dest = []
@@ -221,7 +221,7 @@ class Rewriter:
     def process_revision_directives(
         self,
         context: MigrationContext,
-        revision: Revision,
+        revision: _GetRevArg,
         directives: List[MigrationScript],
     ) -> None:
         self._traverse_list(context, revision, directives)
index 85e0cf75a472d4927ca7ed9f933c82282c70c3d8..f37f246183428423380c421537cfee811f596354 100644 (file)
@@ -7,6 +7,7 @@ from typing import Callable
 from typing import Collection
 from typing import ContextManager
 from typing import Dict
+from typing import Iterable
 from typing import List
 from typing import Literal
 from typing import Mapping
@@ -143,7 +144,12 @@ def configure(
     include_schemas: bool = False,
     process_revision_directives: Optional[
         Callable[
-            [MigrationContext, Tuple[str, str], List[MigrationScript]], None
+            [
+                MigrationContext,
+                Union[str, Iterable[Optional[str]], Iterable[str]],
+                List[MigrationScript],
+            ],
+            None,
         ]
     ] = None,
     compare_type: Union[
index a1c0e1b05265bdb3c2a660098fda1fdcc4ddd9b7..7640f563a99a399c952258ca71cfa6191c3a3e54 100644 (file)
@@ -23,6 +23,7 @@ from .migration import _ProxyTransaction
 from .migration import MigrationContext
 from .. import util
 from ..operations import Operations
+from ..script.revision import _GetRevArg
 
 if TYPE_CHECKING:
     from sqlalchemy.engine import URL
@@ -42,7 +43,7 @@ if TYPE_CHECKING:
 _RevNumber = Optional[Union[str, Tuple[str, ...]]]
 
 ProcessRevisionDirectiveFn = Callable[
-    [MigrationContext, Tuple[str, str], List["MigrationScript"]], None
+    [MigrationContext, _GetRevArg, List["MigrationScript"]], None
 ]
 
 RenderItemFn = Callable[
index aa0e9040f6e080781d1607a47a054814283de79d..035026441fad89518cf4dd91e3e9947d0624ef00 100644 (file)
@@ -32,14 +32,8 @@ if TYPE_CHECKING:
 _RevIdType = Union[str, List[str], Tuple[str, ...]]
 _GetRevArg = Union[
     str,
-    List[Optional[str]],
-    Tuple[Optional[str], ...],
-    FrozenSet[Optional[str]],
-    Set[Optional[str]],
-    List[str],
-    Tuple[str, ...],
-    FrozenSet[str],
-    Set[str],
+    Iterable[Optional[str]],
+    Iterable[str],
 ]
 _RevisionIdentifierType = Union[str, Tuple[str, ...], None]
 _RevisionOrStr = Union["Revision", str]
@@ -738,7 +732,7 @@ class RevisionMap:
         )
 
     def _resolve_revision_number(
-        self, id_: Optional[str]
+        self, id_: Optional[_GetRevArg]
     ) -> Tuple[Tuple[str, ...], Optional[str]]:
         branch_label: Optional[str]
         if isinstance(id_, str) and "@" in id_:
diff --git a/docs/build/unreleased/1329.rst b/docs/build/unreleased/1329.rst
new file mode 100644 (file)
index 0000000..b6065d9
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, autogenerate, regression
+    :tickets: 1329
+
+    Fixed regression caused by :ticket:`879` released in 1.7.0 where the
+    ".info" dictionary of ``Table`` would not render in autogenerate create
+    table statements.  This can be useful for custom create table DDL rendering
+    schemes so it is restored.
index 5c200b1fafc2cc05479c257cbf15e398139a93c4..88aa978cc93654f74437ad33431de54aad4a5f02 100644 (file)
@@ -1983,6 +1983,27 @@ class AutogenRenderTest(TestBase):
             ")",
         )
 
+    def test_render_table_with_info(self):
+        m = MetaData()
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer, ForeignKey("address.id")),
+            info={"oracle_partition": "PARTITION BY ..."},
+        )
+        op_obj = ops.CreateTableOp.from_table(t)
+        eq_ignore_whitespace(
+            autogenerate.render_op_text(self.autogen_context, op_obj),
+            "op.create_table('test',"
+            "sa.Column('id', sa.Integer(), nullable=False),"
+            "sa.Column('q', sa.Integer(), nullable=True),"
+            "sa.ForeignKeyConstraint(['q'], ['address.id'], ),"
+            "sa.PrimaryKeyConstraint('id'),"
+            "info={'oracle_partition': 'PARTITION BY ...'}"
+            ")",
+        )
+
     def test_render_add_column_with_comment(self):
         op_obj = ops.AddColumnOp(
             "foo", Column("x", Integer, comment="This is a Column")