]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Improve typing.
authorFederico Caselli <cfederico87@gmail.com>
Thu, 13 Apr 2023 20:22:14 +0000 (22:22 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 13 Apr 2023 20:23:03 +0000 (22:23 +0200)
Correctly pass previously ignored arguments ``insert_before`` and
``insert_after`` in ``batch_alter_column``

Fixes: #1221
Change-Id: I79c9144f3e521fca00a0c32462ae2a69f9f7a032

14 files changed:
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/context.pyi
alembic/ddl/impl.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/util/sqla_compat.py
docs/build/unreleased/1221.rst [new file with mode: 0644]
tests/test_mysql.py
tools/write_pyi.py

index d7a0913d69f82a7511f371088a3976932e9d505c..9a3b003b7bb5d87f3bff2974a92b7337901838fe 100644 (file)
@@ -9,7 +9,6 @@ from typing import Optional
 from typing import Set
 from typing import Tuple
 from typing import TYPE_CHECKING
-from typing import Union
 
 from sqlalchemy import inspect
 
@@ -25,19 +24,18 @@ if TYPE_CHECKING:
     from sqlalchemy.engine import Connection
     from sqlalchemy.engine import Dialect
     from sqlalchemy.engine import Inspector
-    from sqlalchemy.sql.schema import Column
-    from sqlalchemy.sql.schema import ForeignKeyConstraint
-    from sqlalchemy.sql.schema import Index
     from sqlalchemy.sql.schema import MetaData
-    from sqlalchemy.sql.schema import Table
-    from sqlalchemy.sql.schema import UniqueConstraint
+    from sqlalchemy.sql.schema import SchemaItem
 
-    from alembic.config import Config
-    from alembic.operations.ops import MigrationScript
-    from alembic.operations.ops import UpgradeOps
-    from alembic.runtime.migration import MigrationContext
-    from alembic.script.base import Script
-    from alembic.script.base import ScriptDirectory
+    from ..config import Config
+    from ..operations.ops import MigrationScript
+    from ..operations.ops import UpgradeOps
+    from ..runtime.environment import NameFilterParentNames
+    from ..runtime.environment import NameFilterType
+    from ..runtime.environment import RenderItemFn
+    from ..runtime.migration import MigrationContext
+    from ..script.base import Script
+    from ..script.base import ScriptDirectory
 
 
 def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
@@ -172,7 +170,7 @@ def render_python_code(
     alembic_module_prefix: str = "op.",
     render_as_batch: bool = False,
     imports: Tuple[str, ...] = (),
-    render_item: None = None,
+    render_item: Optional[RenderItemFn] = None,
     migration_context: Optional[MigrationContext] = None,
 ) -> str:
     """Render Python code given an :class:`.UpgradeOps` or
@@ -359,8 +357,8 @@ class AutogenContext:
     def run_name_filters(
         self,
         name: Optional[str],
-        type_: str,
-        parent_names: Dict[str, Optional[str]],
+        type_: NameFilterType,
+        parent_names: NameFilterParentNames,
     ) -> bool:
         """Run the context's name filters and return True if the targets
         should be part of the autogenerate operation.
@@ -396,17 +394,11 @@ class AutogenContext:
 
     def run_object_filters(
         self,
-        object_: Union[
-            Table,
-            Index,
-            Column,
-            UniqueConstraint,
-            ForeignKeyConstraint,
-        ],
+        object_: SchemaItem,
         name: Optional[str],
-        type_: str,
+        type_: NameFilterType,
         reflected: bool,
-        compare_to: Optional[Union[Table, Index, Column, UniqueConstraint]],
+        compare_to: Optional[SchemaItem],
     ) -> bool:
         """Run the context's object filters and return True if the targets
         should be part of the autogenerate operation.
index 85cb426ed1cf742909ba5039e0f84ce6c8d9203f..595631ceb550e64d0857d9246efcb70e3ef6ff07 100644 (file)
@@ -212,7 +212,7 @@ def _compare_tables(
                 (inspector),
                 # fmt: on
             )
-            sqla_compat._reflect_table(inspector, t, None)
+            sqla_compat._reflect_table(inspector, t)
         if autogen_context.run_object_filters(t, tname, "table", True, None):
 
             modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
@@ -243,7 +243,7 @@ def _compare_tables(
                 _compat_autogen_column_reflect(inspector),
                 # fmt: on
             )
-            sqla_compat._reflect_table(inspector, t, None)
+            sqla_compat._reflect_table(inspector, t)
         conn_column_info[(s, tname)] = t
 
     for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
index a9f48b226f30ea5b25b23a3098752ffebe5e5dab..1007a5ef75aadd8bb4394572d8f07a5a2c8628a0 100644 (file)
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 from typing import Any
-from typing import Callable
 from typing import ContextManager
 from typing import Dict
 from typing import List
@@ -22,7 +21,11 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.schema import MetaData
 
     from .config import Config
-    from .operations import MigrateOperation
+    from .runtime.environment import IncludeNameFn
+    from .runtime.environment import IncludeObjectFn
+    from .runtime.environment import OnVersionApplyFn
+    from .runtime.environment import ProcessRevisionDirectiveFn
+    from .runtime.environment import RenderItemFn
     from .runtime.migration import _ProxyTransaction
     from .runtime.migration import MigrationContext
     from .script import ScriptDirectory
@@ -76,7 +79,7 @@ config: Config
 
 def configure(
     connection: Optional[Connection] = None,
-    url: Union[str, URL, None] = None,
+    url: Optional[Union[str, URL]] = None,
     dialect_name: Optional[str] = None,
     dialect_opts: Optional[Dict[str, Any]] = None,
     transactional_ddl: Optional[bool] = None,
@@ -87,24 +90,20 @@ def configure(
     template_args: Optional[Dict[str, Any]] = None,
     render_as_batch: bool = False,
     target_metadata: Optional[MetaData] = None,
-    include_name: Optional[Callable[..., bool]] = None,
-    include_object: Optional[Callable[..., bool]] = None,
+    include_name: Optional[IncludeNameFn] = None,
+    include_object: Optional[IncludeObjectFn] = None,
     include_schemas: bool = False,
-    process_revision_directives: Optional[
-        Callable[
-            [MigrationContext, Tuple[str, str], List[MigrateOperation]], None
-        ]
-    ] = None,
+    process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None,
     compare_type: bool = False,
     compare_server_default: bool = False,
-    render_item: Optional[Callable[..., bool]] = None,
+    render_item: Optional[RenderItemFn] = None,
     literal_binds: bool = False,
     upgrade_token: str = "upgrades",
     downgrade_token: str = "downgrades",
     alembic_module_prefix: str = "op.",
     sqlalchemy_module_prefix: str = "sa.",
     user_module_prefix: Optional[str] = None,
-    on_version_apply: Optional[Callable[..., None]] = None,
+    on_version_apply: Optional[OnVersionApplyFn] = None,
     **kw: Any,
 ) -> None:
     """Configure a :class:`.MigrationContext` within this
@@ -308,7 +307,8 @@ def configure(
        ``"unique_constraint"``, or ``"foreign_key_constraint"``
      * ``parent_names``: a dictionary of "parent" object names, that are
        relative to the name being given.  Keys in this dictionary may
-       include:  ``"schema_name"``, ``"table_name"``.
+       include:  ``"schema_name"``, ``"table_name"`` or
+       ``"schema_qualified_table_name"``.
 
      E.g.::
 
index f11d1edc1a2fe505de59553642f49c53f46d16e6..84f5d86cc4674b1507c0c5713228a31dcb810f47 100644 (file)
@@ -155,9 +155,9 @@ class DefaultImpl(metaclass=ImplMeta):
     def _exec(
         self,
         construct: Union[ClauseElement, str],
-        execution_options: Optional[dict] = None,
+        execution_options: Optional[dict[str, Any]] = None,
         multiparams: Sequence[dict] = (),
-        params: Dict[str, int] = util.immutabledict(),
+        params: Dict[str, Any] = util.immutabledict(),
     ) -> Optional[CursorResult]:
         if isinstance(construct, str):
             construct = text(construct)
@@ -197,7 +197,7 @@ class DefaultImpl(metaclass=ImplMeta):
     def execute(
         self,
         sql: Union[ClauseElement, str],
-        execution_options: None = None,
+        execution_options: Optional[dict[str, Any]] = None,
     ) -> None:
         self._exec(sql, execution_options)
 
index dc94113996666acf42960bca18ef472f9fef05bf..dab58568d842f155754a19fa0ac1142f0a216a50 100644 (file)
@@ -951,7 +951,8 @@ def drop_table_comment(
     """
 
 def execute(
-    sqltext: Union[str, TextClause, Update], execution_options: None = None
+    sqltext: Union[str, TextClause, Update],
+    execution_options: Optional[dict[str, Any]] = None,
 ) -> Optional[Table]:
     r"""Execute the given SQL using the current migration context.
 
@@ -1101,7 +1102,7 @@ def implementation_for(op_cls: Any) -> Callable[..., Any]:
     """
 
 def inline_literal(
-    value: Union[str, int], type_: None = None
+    value: Union[str, int], type_: Optional[TypeEngine] = None
 ) -> _literal_bindparam:
     r"""Produce an 'inline literal' expression, suitable for
     using in an INSERT, UPDATE, or DELETE statement.
index 04b66b55dceffcb4ebcce0c59d78b5d80b3320e1..82d977922b0dd0b0d58752d18ecd0d673c44d0e2 100644 (file)
@@ -33,8 +33,9 @@ NoneType = type(None)
 if TYPE_CHECKING:
     from typing import Literal
 
-    from sqlalchemy import Table  # noqa
+    from sqlalchemy import Table
     from sqlalchemy.engine import Connection
+    from sqlalchemy.types import TypeEngine
 
     from .batch import BatchOperationsImpl
     from .ops import MigrateOperation
@@ -439,7 +440,7 @@ class Operations(util.ModuleClsProxy):
         return conv(name)
 
     def inline_literal(
-        self, value: Union[str, int], type_: None = None
+        self, value: Union[str, int], type_: Optional[TypeEngine[Any]] = None
     ) -> _literal_bindparam:
         r"""Produce an 'inline literal' expression, suitable for
         using in an INSERT, UPDATE, or DELETE statement.
index 00f13a1bea399828c097c7826a69296f52e4315a..f4a058bc9a201f5039d14d938a0f38de43de7644 100644 (file)
@@ -487,7 +487,7 @@ class ApplyBatchImpl:
         server_default: Optional[Union[Function[Any], str, bool]] = False,
         name: Optional[str] = None,
         type_: Optional[TypeEngine] = None,
-        autoincrement: None = None,
+        autoincrement: Optional[Union[bool, Literal["auto"]]] = None,
         comment: Union[str, Literal[False]] = False,
         **kw,
     ) -> None:
index b3ef5bb6a09bb29da231ddb18b1d6b9343c82110..7dd65a1f0fa7394e4503aada0acd9dec82872262 100644 (file)
@@ -673,11 +673,11 @@ class CreateForeignKeyOp(AddConstraintOp):
         local_cols: List[str],
         remote_cols: List[str],
         referent_schema: Optional[str] = None,
-        onupdate: None = None,
-        ondelete: None = None,
-        deferrable: None = None,
-        initially: None = None,
-        match: None = None,
+        onupdate: Optional[str] = None,
+        ondelete: Optional[str] = None,
+        deferrable: Optional[bool] = None,
+        initially: Optional[str] = None,
+        match: Optional[str] = None,
         **dialect_kw: Any,
     ) -> None:
         """Issue a "create foreign key" instruction using the
@@ -1890,10 +1890,10 @@ class AlterColumnOp(AlterTableOp):
         type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
         existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
         existing_server_default: bool = False,
-        existing_nullable: None = None,
-        existing_comment: None = None,
-        insert_before: None = None,
-        insert_after: None = None,
+        existing_nullable: Optional[bool] = None,
+        existing_comment: Optional[str] = None,
+        insert_before: Optional[str] = None,
+        insert_after: Optional[str] = None,
         **kw: Any,
     ) -> Optional[Table]:
         """Issue an "alter column" instruction using the current
@@ -1935,6 +1935,8 @@ class AlterColumnOp(AlterTableOp):
             modify_server_default=server_default,
             modify_nullable=nullable,
             modify_comment=comment,
+            insert_before=insert_before,
+            insert_after=insert_after,
             **kw,
         )
 
@@ -2314,7 +2316,7 @@ class ExecuteSQLOp(MigrateOperation):
     def __init__(
         self,
         sqltext: Union[Update, str, Insert, TextClause],
-        execution_options: None = None,
+        execution_options: Optional[dict[str, Any]] = None,
     ) -> None:
         self.sqltext = sqltext
         self.execution_options = execution_options
@@ -2324,7 +2326,7 @@ class ExecuteSQLOp(MigrateOperation):
         cls,
         operations: Operations,
         sqltext: Union[str, TextClause, Update],
-        execution_options: None = None,
+        execution_options: Optional[dict[str, Any]] = None,
     ) -> Optional[Table]:
         r"""Execute the given SQL using the current migration context.
 
index c2fa11adb2885928a1228a78b76c3584d15cee57..f5c177e8aae1276e9693c36585930b534d3858f9 100644 (file)
@@ -2,9 +2,12 @@ from __future__ import annotations
 
 from typing import Any
 from typing import Callable
+from typing import Collection
 from typing import ContextManager
 from typing import Dict
 from typing import List
+from typing import Mapping
+from typing import MutableMapping
 from typing import Optional
 from typing import overload
 from typing import TextIO
@@ -12,19 +15,23 @@ from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import Union
 
+from typing_extensions import Literal
+
 from .migration import _ProxyTransaction
 from .migration import MigrationContext
 from .. import util
 from ..operations import Operations
 
 if TYPE_CHECKING:
-    from typing import Literal
 
     from sqlalchemy.engine import URL
     from sqlalchemy.engine.base import Connection
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import MetaData
+    from sqlalchemy.sql.schema import SchemaItem
 
+    from .migration import MigrationInfo
+    from ..autogenerate.api import AutogenContext
     from ..config import Config
     from ..ddl import DefaultImpl
     from ..operations.ops import MigrateOperation
@@ -36,6 +43,42 @@ ProcessRevisionDirectiveFn = Callable[
     [MigrationContext, Tuple[str, str], List["MigrateOperation"]], None
 ]
 
+RenderItemFn = Callable[
+    [str, Any, "AutogenContext"], Union[str, Literal[False]]
+]
+
+NameFilterType = Literal[
+    "schema",
+    "table",
+    "column",
+    "index",
+    "unique_constraint",
+    "foreign_key_constraint",
+]
+NameFilterParentNames = MutableMapping[
+    Literal["schema_name", "table_name", "schema_qualified_table_name"],
+    Optional[str],
+]
+IncludeNameFn = Callable[
+    [Optional[str], NameFilterType, NameFilterParentNames], bool
+]
+
+IncludeObjectFn = Callable[
+    [
+        "SchemaItem",
+        Optional[str],
+        NameFilterType,
+        bool,
+        Optional["SchemaItem"],
+    ],
+    bool,
+]
+
+OnVersionApplyFn = Callable[
+    [MigrationContext, "MigrationInfo", Collection[Any], Mapping[str, Any]],
+    None,
+]
+
 
 class EnvironmentContext(util.ModuleClsProxy):
 
@@ -346,22 +389,22 @@ class EnvironmentContext(util.ModuleClsProxy):
         template_args: Optional[Dict[str, Any]] = None,
         render_as_batch: bool = False,
         target_metadata: Optional[MetaData] = None,
-        include_name: Optional[Callable[..., bool]] = None,
-        include_object: Optional[Callable[..., bool]] = None,
+        include_name: Optional[IncludeNameFn] = None,
+        include_object: Optional[IncludeObjectFn] = None,
         include_schemas: bool = False,
         process_revision_directives: Optional[
             ProcessRevisionDirectiveFn
         ] = None,
         compare_type: bool = False,
         compare_server_default: bool = False,
-        render_item: Optional[Callable[..., bool]] = None,
+        render_item: Optional[RenderItemFn] = None,
         literal_binds: bool = False,
         upgrade_token: str = "upgrades",
         downgrade_token: str = "downgrades",
         alembic_module_prefix: str = "op.",
         sqlalchemy_module_prefix: str = "sa.",
         user_module_prefix: Optional[str] = None,
-        on_version_apply: Optional[Callable[..., None]] = None,
+        on_version_apply: Optional[OnVersionApplyFn] = None,
         **kw: Any,
     ) -> None:
         """Configure a :class:`.MigrationContext` within this
@@ -565,7 +608,8 @@ class EnvironmentContext(util.ModuleClsProxy):
            ``"unique_constraint"``, or ``"foreign_key_constraint"``
          * ``parent_names``: a dictionary of "parent" object names, that are
            relative to the name being given.  Keys in this dictionary may
-           include:  ``"schema_name"``, ``"table_name"``.
+           include:  ``"schema_name"``, ``"table_name"`` or
+           ``"schema_qualified_table_name"``.
 
          E.g.::
 
index 4e2d06251083fc1e49a6020c99bc245ecebfb9e3..cfba0e3e494bc22815f3f1eb42c74965d5f0ab09 100644 (file)
@@ -5,10 +5,12 @@ from contextlib import nullcontext
 import logging
 import sys
 from typing import Any
+from typing import Callable
 from typing import cast
 from typing import Collection
 from typing import ContextManager
 from typing import Dict
+from typing import Iterable
 from typing import Iterator
 from typing import List
 from typing import Optional
@@ -74,7 +76,7 @@ class _ProxyTransaction:
     def __enter__(self) -> _ProxyTransaction:
         return self
 
-    def __exit__(self, type_: None, value: None, traceback: None) -> None:
+    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
         if self._proxied_transaction is not None:
             self._proxied_transaction.__exit__(type_, value, traceback)
             self.migration_context._transaction = None
@@ -158,7 +160,9 @@ class MigrationContext:
                 sqla_compat._get_connection_in_transaction(connection)
             )
 
-        self._migrations_fn = opts.get("fn")
+        self._migrations_fn: Optional[
+            Callable[..., Iterable[RevisionStep]]
+        ] = opts.get("fn")
         self.as_sql = as_sql
 
         self.purge = opts.get("purge", False)
@@ -1275,7 +1279,7 @@ class StampStep(MigrationStep):
         self.migration_fn = self.stamp_revision
         self.revision_map = revision_map
 
-    doc: None = None
+    doc: Optional[str] = None
 
     def stamp_revision(self, **kw: Any) -> None:
         return None
index cab99494bddd328b21f42cfd1df5c07ac705b2b3..e2725d6c5c9ea718504315841440c659457ed64f 100644 (file)
@@ -299,9 +299,7 @@ def _columns_for_constraint(constraint):
         return list(constraint.columns)
 
 
-def _reflect_table(
-    inspector: Inspector, table: Table, include_cols: None
-) -> None:
+def _reflect_table(inspector: Inspector, table: Table) -> None:
     if sqla_14:
         return inspector.reflect_table(table, None)
     else:
diff --git a/docs/build/unreleased/1221.rst b/docs/build/unreleased/1221.rst
new file mode 100644 (file)
index 0000000..de14f15
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, batch
+    :tickets: 1221
+
+    Correctly pass previously ignored arguments ``insert_before`` and
+    ``insert_after`` in ``batch_alter_column``
index 2145fd737db6503cdd899156b700a94b49632ce1..92c1819f16a5478f4689d061e5994673c8875cdb 100644 (file)
@@ -627,7 +627,7 @@ class MySQLDefaultCompareTest(TestBase):
         insp = inspect(self.bind)
         cols = insp.get_columns(t1.name)
         refl = Table(t1.name, MetaData())
-        sqla_compat._reflect_table(insp, refl, None)
+        sqla_compat._reflect_table(insp, refl)
         ctx = self.autogen_context["context"]
         return ctx.impl.compare_server_default(
             refl.c[cols[0]["name"]], col, rendered, cols[0]["default"]
index 4fbf36617ae2cfb54b401477d9ab5d736aac3409..fa79c495beec6d00e7b5d5ce059021aeb0a04fa7 100644 (file)
@@ -109,7 +109,8 @@ def generate_pyi_for_proxy(
                 # Do not generate the base implementation to avoid mypy errors
                 overloads = typing.get_overloads(meth)
                 if overloads:
-                    # use enumerate so we can generate docs on the last overload
+                    # use enumerate so we can generate docs on the
+                    # last overload
                     for i, ovl in enumerate(overloads, 1):
                         _generate_stub_for_meth(
                             ovl,