]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Improve typing
authorCaselIT <cfederico87@gmail.com>
Sun, 11 Sep 2022 20:38:46 +0000 (22:38 +0200)
committerCaselIT <cfederico87@gmail.com>
Mon, 12 Sep 2022 19:00:52 +0000 (21:00 +0200)
Change-Id: I9fc86c4a92e1b76d19c9e891ff08ce8a46ad4e35

18 files changed:
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/context.pyi
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/oracle.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/toimpl.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/util/sqla_compat.py
pyproject.toml
tools/write_pyi.py

index 5f1c7f3554bec57bc4d4db6e438331617830f03b..cbd64e18c387b07ff4ef73a54bc685a88fcd645c 100644 (file)
@@ -523,12 +523,12 @@ class RevisionContext:
 
     def run_autogenerate(
         self, rev: tuple, migration_context: "MigrationContext"
-    ):
+    ) -> None:
         self._run_environment(rev, migration_context, True)
 
     def run_no_autogenerate(
         self, rev: tuple, migration_context: "MigrationContext"
-    ):
+    ) -> None:
         self._run_environment(rev, migration_context, False)
 
     def _run_environment(
@@ -536,7 +536,7 @@ class RevisionContext:
         rev: tuple,
         migration_context: "MigrationContext",
         autogenerate: bool,
-    ):
+    ) -> None:
         if autogenerate:
             if self.command_args["sql"]:
                 raise util.CommandError(
index 5b698151982af8f787b815080b7cd0d14d1699df..c32ab4d9bb9f946af31ab32035888bb1c03922fe 100644 (file)
@@ -616,8 +616,8 @@ def _compare_indexes_and_uniques(
     # we know are either added implicitly by the DB or that the DB
     # can't accurately report on
     autogen_context.migration_context.impl.correct_for_autogen_constraints(
-        conn_uniques,
-        conn_indexes,
+        conn_uniques,  # type: ignore[arg-type]
+        conn_indexes,  # type: ignore[arg-type]
         metadata_unique_constraints,
         metadata_indexes,
     )
@@ -1274,7 +1274,8 @@ def _compare_foreign_keys(
     )
 
     conn_fks = set(
-        _make_foreign_key(const, conn_table) for const in conn_fks_list
+        _make_foreign_key(const, conn_table)  # type: ignore[arg-type]
+        for const in conn_fks_list
     )
 
     # give the dialect a chance to correct the FKs to match more
index 9c992b47cdfdcf495cdc50b65a90e36e0d67c9e6..1ac6753d9f1efa842f626ba71691313d5d2f835e 100644 (file)
@@ -989,7 +989,7 @@ def _fk_colspec(
         if table_fullname in namespace_metadata.tables:
             col = namespace_metadata.tables[table_fullname].c.get(colname)
             if col is not None:
-                colname = _ident(col.name)
+                colname = _ident(col.name)  # type: ignore[assignment]
 
     colspec = "%s.%s" % (table_fullname, colname)
 
index 14e1b5fbc345d02468302b328e2eee372b7cdd96..a2e53994a20f6c33724ce510d8d76434d032a2e5 100644 (file)
@@ -5,6 +5,8 @@ from __future__ import annotations
 from typing import Any
 from typing import Callable
 from typing import ContextManager
+from typing import Dict
+from typing import List
 from typing import Optional
 from typing import TextIO
 from typing import Tuple
@@ -13,6 +15,7 @@ from typing import Union
 
 if TYPE_CHECKING:
     from sqlalchemy.engine.base import Connection
+    from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import MetaData
 
     from .config import Config
@@ -530,7 +533,9 @@ def configure(
 
     """
 
-def execute(sql, execution_options=None):
+def execute(
+    sql: Union[ClauseElement, str], execution_options: Optional[dict] = None
+) -> None:
     """Execute the given SQL using the current change context.
 
     The behavior of :meth:`.execute` is the same
@@ -543,7 +548,7 @@ def execute(sql, execution_options=None):
 
     """
 
-def get_bind():
+def get_bind() -> Connection:
     """Return the current 'bind'.
 
     In "online" mode, this is the
@@ -635,7 +640,9 @@ def get_tag_argument() -> Optional[str]:
 
     """
 
-def get_x_argument(as_dictionary: bool = False):
+def get_x_argument(
+    as_dictionary: bool = False,
+) -> Union[List[str], Dict[str, str]]:
     """Return the value(s) passed for the ``-x`` argument, if any.
 
     The ``-x`` argument is an open ended flag that allows any user-defined
@@ -723,7 +730,7 @@ def run_migrations(**kw: Any) -> None:
 
 script: ScriptDirectory
 
-def static_output(text):
+def static_output(text: str) -> None:
     """Emit text directly to the "offline" SQL stream.
 
     Typically this is for emitting comments that
index 7b0f63e811efb6b380e8278bd8b69674abe46ad5..c9107867d3c6a92901a4461cdf3e11f06e1a22f6 100644 (file)
@@ -294,7 +294,7 @@ def format_table_name(
 def format_column_name(
     compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
 ) -> Union["quoted_name", str]:
-    return compiler.preparer.quote(name)
+    return compiler.preparer.quote(name)  # type: ignore[arg-type]
 
 
 def format_server_default(
index 070c124bdedf435efbc9b33b520ff90e9bf39de8..79d5245e5d30a91d494a4fba8b2f758931c9df5b 100644 (file)
@@ -23,19 +23,16 @@ from .. import util
 from ..util import sqla_compat
 
 if TYPE_CHECKING:
-    from io import StringIO
     from typing import Literal
+    from typing import TextIO
 
     from sqlalchemy.engine import Connection
     from sqlalchemy.engine import Dialect
     from sqlalchemy.engine.cursor import CursorResult
-    from sqlalchemy.engine.cursor import LegacyCursorResult
     from sqlalchemy.engine.reflection import Inspector
-    from sqlalchemy.sql.dml import Update
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.elements import quoted_name
-    from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import Constraint
     from sqlalchemy.sql.schema import ForeignKeyConstraint
@@ -60,11 +57,11 @@ class ImplMeta(type):
     ):
         newtype = type.__init__(cls, classname, bases, dict_)
         if "__dialect__" in dict_:
-            _impls[dict_["__dialect__"]] = cls
+            _impls[dict_["__dialect__"]] = cls  # type: ignore[assignment]
         return newtype
 
 
-_impls: dict = {}
+_impls: Dict[str, Type["DefaultImpl"]] = {}
 
 Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
 
@@ -98,7 +95,7 @@ class DefaultImpl(metaclass=ImplMeta):
         connection: Optional["Connection"],
         as_sql: bool,
         transactional_ddl: Optional[bool],
-        output_buffer: Optional["StringIO"],
+        output_buffer: Optional["TextIO"],
         context_opts: Dict[str, Any],
     ) -> None:
         self.dialect = dialect
@@ -119,7 +116,7 @@ class DefaultImpl(metaclass=ImplMeta):
                 )
 
     @classmethod
-    def get_by_dialect(cls, dialect: "Dialect") -> Any:
+    def get_by_dialect(cls, dialect: "Dialect") -> Type["DefaultImpl"]:
         return _impls[dialect.name]
 
     def static_output(self, text: str) -> None:
@@ -158,10 +155,10 @@ class DefaultImpl(metaclass=ImplMeta):
     def _exec(
         self,
         construct: Union["ClauseElement", str],
-        execution_options: None = None,
+        execution_options: Optional[dict] = None,
         multiparams: Sequence[dict] = (),
         params: Dict[str, int] = util.immutabledict(),
-    ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+    ) -> Optional["CursorResult"]:
         if isinstance(construct, str):
             construct = text(construct)
         if self.as_sql:
@@ -176,10 +173,11 @@ class DefaultImpl(metaclass=ImplMeta):
             else:
                 compile_kw = {}
 
+            compiled = construct.compile(
+                dialect=self.dialect, **compile_kw  # type: ignore[arg-type]
+            )
             self.static_output(
-                str(construct.compile(dialect=self.dialect, **compile_kw))
-                .replace("\t", "    ")
-                .strip()
+                str(compiled).replace("\t", "    ").strip()
                 + self.command_terminator
             )
             return None
@@ -192,11 +190,13 @@ class DefaultImpl(metaclass=ImplMeta):
                 assert isinstance(multiparams, tuple)
                 multiparams += (params,)
 
-            return conn.execute(construct, multiparams)
+            return conn.execute(  # type: ignore[call-overload]
+                construct, multiparams
+            )
 
     def execute(
         self,
-        sql: Union["Update", "TextClause", str],
+        sql: Union["ClauseElement", str],
         execution_options: None = None,
     ) -> None:
         self._exec(sql, execution_options)
@@ -424,9 +424,6 @@ class DefaultImpl(metaclass=ImplMeta):
                     )
                 )
         else:
-            # work around http://www.sqlalchemy.org/trac/ticket/2461
-            if not hasattr(table, "_autoincrement_column"):
-                table._autoincrement_column = None
             if rows:
                 if multiinsert:
                     self._exec(
@@ -572,7 +569,7 @@ class DefaultImpl(metaclass=ImplMeta):
             )
 
     def render_ddl_sql_expr(
-        self, expr: "ClauseElement", is_server_default: bool = False, **kw
+        self, expr: "ClauseElement", is_server_default: bool = False, **kw: Any
     ) -> str:
         """Render a SQL expression that is typically a server default,
         index expression, etc.
@@ -581,10 +578,14 @@ class DefaultImpl(metaclass=ImplMeta):
 
         """
 
-        compile_kw = dict(
-            compile_kwargs={"literal_binds": True, "include_table": False}
+        compile_kw = {
+            "compile_kwargs": {"literal_binds": True, "include_table": False}
+        }
+        return str(
+            expr.compile(
+                dialect=self.dialect, **compile_kw  # type: ignore[arg-type]
+            )
         )
-        return str(expr.compile(dialect=self.dialect, **compile_kw))
 
     def _compat_autogen_column_reflect(
         self, inspector: "Inspector"
index b48f8ba988ebb0f6181b43786b3fff89454fcf0e..28f0678e4167fca780a360fdcd6d343f4db2585a 100644 (file)
@@ -35,7 +35,6 @@ if TYPE_CHECKING:
     from sqlalchemy.dialects.mssql.base import MSDDLCompiler
     from sqlalchemy.dialects.mssql.base import MSSQLCompiler
     from sqlalchemy.engine.cursor import CursorResult
-    from sqlalchemy.engine.cursor import LegacyCursorResult
     from sqlalchemy.sql.schema import Index
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import TableClause
@@ -68,9 +67,7 @@ class MSSQLImpl(DefaultImpl):
             "mssql_batch_separator", self.batch_separator
         )
 
-    def _exec(
-        self, construct: Any, *args, **kw
-    ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+    def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
         result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
@@ -359,7 +356,7 @@ def visit_column_nullable(
     return "%s %s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        format_type(compiler, element.existing_type),
+        format_type(compiler, element.existing_type),  # type: ignore[arg-type]
         "NULL" if element.nullable else "NOT NULL",
     )
 
index 0e787fb1cd1483cde00b0a38a53cac432e3ba247..accd1fcfb2481a44e1a51ba40641a16528b1f136 100644 (file)
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import Any
 from typing import Optional
 from typing import TYPE_CHECKING
-from typing import Union
 
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import sqltypes
@@ -26,7 +25,6 @@ from .impl import DefaultImpl
 if TYPE_CHECKING:
     from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
     from sqlalchemy.engine.cursor import CursorResult
-    from sqlalchemy.engine.cursor import LegacyCursorResult
     from sqlalchemy.sql.schema import Column
 
 
@@ -48,9 +46,7 @@ class OracleImpl(DefaultImpl):
             "oracle_batch_separator", self.batch_separator
         )
 
-    def _exec(
-        self, construct: Any, *args, **kw
-    ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
+    def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
         result = super(OracleImpl, self)._exec(construct, *args, **kw)
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
index 59dfc58927ae2bf2f8a5f0d1cd12b06640d7b484..490d714614fcbd53f9f457f2987f22864779b8a1 100644 (file)
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import Computed
     from sqlalchemy.sql.schema import Identity
+    from sqlalchemy.sql.schema import SchemaItem
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.type_api import TypeEngine
     from sqlalchemy.util import immutabledict
@@ -94,7 +95,7 @@ def alter_column(
     table_name: str,
     column_name: str,
     nullable: Optional[bool] = None,
-    comment: Union[str, bool, None] = False,
+    comment: Union[str, Literal[False], None] = False,
     server_default: Any = False,
     new_column_name: Optional[str] = None,
     type_: Union[TypeEngine, Type[TypeEngine], None] = None,
@@ -202,13 +203,13 @@ def batch_alter_table(
     schema: Optional[str] = None,
     recreate: Literal["auto", "always", "never"] = "auto",
     partial_reordering: Optional[tuple] = None,
-    copy_from: Optional["Table"] = None,
+    copy_from: Optional[Table] = None,
     table_args: Tuple[Any, ...] = (),
     table_kwargs: Mapping[str, Any] = immutabledict({}),
     reflect_args: Tuple[Any, ...] = (),
     reflect_kwargs: Mapping[str, Any] = immutabledict({}),
     naming_convention: Optional[Dict[str, str]] = None,
-) -> Iterator["BatchOperations"]:
+) -> Iterator[BatchOperations]:
     """Invoke a series of per-table migrations in batch.
 
     Batch mode allows a series of operations specific to a table
@@ -667,7 +668,9 @@ def create_primary_key(
 
     """
 
-def create_table(table_name: str, *columns, **kw: Any) -> Optional[Table]:
+def create_table(
+    table_name: str, *columns: SchemaItem, **kw: Any
+) -> Optional[Table]:
     """Issue a "create table" instruction using the current migration
     context.
 
index 9ecf3d4a9ad2c8a38c71f34dd7230e4ddc26c011..535dff0f9cdb22b7b67f718f52477e0471a2f2b5 100644 (file)
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
 
     from .batch import BatchOperationsImpl
     from .ops import MigrateOperation
+    from ..ddl import DefaultImpl
     from ..runtime.migration import MigrationContext
     from ..util.sqla_compat import _literal_bindparam
 
@@ -74,6 +75,7 @@ class Operations(util.ModuleClsProxy):
 
     """
 
+    impl: Union["DefaultImpl", "BatchOperationsImpl"]
     _to_impl = util.Dispatcher()
 
     def __init__(
@@ -492,7 +494,7 @@ class Operations(util.ModuleClsProxy):
         In a SQL script context, this value is ``None``. [TODO: verify this]
 
         """
-        return self.migration_context.impl.bind
+        return self.migration_context.impl.bind  # type: ignore[return-value]
 
 
 class BatchOperations(Operations):
@@ -512,6 +514,8 @@ class BatchOperations(Operations):
 
     """
 
+    impl: "BatchOperationsImpl"
+
     def _noop(self, operation):
         raise NotImplementedError(
             "The %s method does not apply to a batch table alter operation."
index 71d26816f4dab70d556d602fda7591b0716c83e9..f1459e2bd82c5fbc370873e75b690d518d135b48 100644 (file)
@@ -236,7 +236,7 @@ class ApplyBatchImpl:
         self._grab_table_elements()
 
     @classmethod
-    def _calc_temp_name(cls, tablename: "quoted_name") -> str:
+    def _calc_temp_name(cls, tablename: Union["quoted_name", str]) -> str:
         return ("_alembic_tmp_%s" % tablename)[0:50]
 
     def _grab_table_elements(self) -> None:
@@ -280,7 +280,7 @@ class ApplyBatchImpl:
                         self.col_named_constraints[const.name] = (col, const)
 
         for idx in self.table.indexes:
-            self.indexes[idx.name] = idx
+            self.indexes[idx.name] = idx  # type: ignore[index]
 
         for k in self.table.kwargs:
             self.table_kwargs.setdefault(k, self.table.kwargs[k])
@@ -546,7 +546,7 @@ class ApplyBatchImpl:
                 existing.server_default = None
             else:
                 sql_schema.DefaultClause(
-                    server_default
+                    server_default  # type: ignore[arg-type]
                 )._set_parent(  # type:ignore[attr-defined]
                     existing
                 )
@@ -699,11 +699,11 @@ class ApplyBatchImpl:
                     self.columns[col.name].primary_key = False
 
     def create_index(self, idx: "Index") -> None:
-        self.new_indexes[idx.name] = idx
+        self.new_indexes[idx.name] = idx  # type: ignore[index]
 
     def drop_index(self, idx: "Index") -> None:
         try:
-            del self.indexes[idx.name]
+            del self.indexes[idx.name]  # type: ignore[arg-type]
         except KeyError:
             raise ValueError("No such index: '%s'" % idx.name)
 
index 997274d7d38566b7f23679339defe29d97345bd0..85ffe149bb319e968270a94a16b48579efbd6071 100644 (file)
@@ -26,6 +26,8 @@ from .. import util
 from ..util import sqla_compat
 
 if TYPE_CHECKING:
+    from typing import Literal
+
     from sqlalchemy.sql.dml import Insert
     from sqlalchemy.sql.dml import Update
     from sqlalchemy.sql.elements import BinaryExpression
@@ -885,7 +887,7 @@ class CreateIndexOp(MigrateOperation):
     def from_index(cls, index: "Index") -> "CreateIndexOp":
         assert index.table is not None
         return cls(
-            index.name,
+            index.name,  # type: ignore[arg-type]
             index.table.name,
             sqla_compat._get_index_expressions(index),
             schema=index.table.schema,
@@ -1021,7 +1023,7 @@ class DropIndexOp(MigrateOperation):
     def from_index(cls, index: "Index") -> "DropIndexOp":
         assert index.table is not None
         return cls(
-            index.name,
+            index.name,  # type: ignore[arg-type]
             index.table.name,
             schema=index.table.schema,
             _reverse=CreateIndexOp.from_index(index),
@@ -1105,7 +1107,7 @@ class CreateTableOp(MigrateOperation):
     def __init__(
         self,
         table_name: str,
-        columns: Sequence[Union["Column", "Constraint"]],
+        columns: Sequence["SchemaItem"],
         schema: Optional[str] = None,
         _namespace_metadata: Optional["MetaData"] = None,
         _constraints_included: bool = False,
@@ -1172,8 +1174,12 @@ class CreateTableOp(MigrateOperation):
 
     @classmethod
     def create_table(
-        cls, operations: "Operations", table_name: str, *columns, **kw: Any
-    ) -> Optional["Table"]:
+        cls,
+        operations: "Operations",
+        table_name: str,
+        *columns: "SchemaItem",
+        **kw: Any,
+    ) -> "Optional[Table]":
         r"""Issue a "create table" instruction using the current migration
         context.
 
@@ -1603,7 +1609,7 @@ class AlterColumnOp(AlterTableOp):
         existing_nullable: Optional[bool] = None,
         existing_comment: Optional[str] = None,
         modify_nullable: Optional[bool] = None,
-        modify_comment: Optional[Union[str, bool]] = False,
+        modify_comment: Optional[Union[str, "Literal[False]"]] = False,
         modify_server_default: Any = False,
         modify_name: Optional[str] = None,
         modify_type: Optional[Any] = None,
@@ -1757,7 +1763,7 @@ class AlterColumnOp(AlterTableOp):
         table_name: str,
         column_name: str,
         nullable: Optional[bool] = None,
-        comment: Optional[Union[str, bool]] = False,
+        comment: Optional[Union[str, "Literal[False]"]] = False,
         server_default: Any = False,
         new_column_name: Optional[str] = None,
         type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
@@ -1885,7 +1891,7 @@ class AlterColumnOp(AlterTableOp):
         operations: BatchOperations,
         column_name: str,
         nullable: Optional[bool] = None,
-        comment: bool = False,
+        comment: Union[str, "Literal[False]"] = False,
         server_default: Union["Function", bool] = False,
         new_column_name: Optional[str] = None,
         type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
index f97983e66a523e862fb512253d6b490f1a2b85b6..add142de37368870e2b2be382185e21187d98842 100644 (file)
@@ -195,7 +195,7 @@ def drop_constraint(
 def bulk_insert(
     operations: "Operations", operation: "ops.BulkInsertOp"
 ) -> None:
-    operations.impl.bulk_insert(
+    operations.impl.bulk_insert(  # type: ignore[union-attr]
         operation.table, operation.rows, multiinsert=operation.multiinsert
     )
 
index b95e0b5e562d327d7af01042e7eea07382572d06..3cec5b1c32de7fbd497ee20162de383ace92ffd7 100644 (file)
@@ -20,10 +20,12 @@ if TYPE_CHECKING:
     from typing import Literal
 
     from sqlalchemy.engine.base import Connection
+    from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import MetaData
 
     from .migration import _ProxyTransaction
     from ..config import Config
+    from ..ddl import DefaultImpl
     from ..script.base import ScriptDirectory
 
 _RevNumber = Optional[Union[str, Tuple[str, ...]]]
@@ -273,7 +275,9 @@ class EnvironmentContext(util.ModuleClsProxy):
     ) -> Dict[str, str]:
         ...
 
-    def get_x_argument(self, as_dictionary: bool = False):
+    def get_x_argument(
+        self, as_dictionary: bool = False
+    ) -> Union[List[str], Dict[str, str]]:
         """Return the value(s) passed for the ``-x`` argument, if any.
 
         The ``-x`` argument is an open ended flag that allows any user-defined
@@ -853,7 +857,11 @@ class EnvironmentContext(util.ModuleClsProxy):
         with Operations.context(self._migration_context):
             self.get_context().run_migrations(**kw)
 
-    def execute(self, sql, execution_options=None):
+    def execute(
+        self,
+        sql: Union["ClauseElement", str],
+        execution_options: Optional[dict] = None,
+    ) -> None:
         """Execute the given SQL using the current change context.
 
         The behavior of :meth:`.execute` is the same
@@ -867,7 +875,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         """
         self.get_context().execute(sql, execution_options=execution_options)
 
-    def static_output(self, text):
+    def static_output(self, text: str) -> None:
         """Emit text directly to the "offline" SQL stream.
 
         Typically this is for emitting comments that
@@ -938,7 +946,7 @@ class EnvironmentContext(util.ModuleClsProxy):
             raise Exception("No context has been configured yet.")
         return self._migration_context
 
-    def get_bind(self):
+    def get_bind(self) -> "Connection":
         """Return the current 'bind'.
 
         In "online" mode, this is the
@@ -949,7 +957,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         has first been made available via :meth:`.configure`.
 
         """
-        return self.get_context().bind
+        return self.get_context().bind  # type: ignore[return-value]
 
-    def get_impl(self):
+    def get_impl(self) -> "DefaultImpl":
         return self.get_context().impl
index c09c8e416a3599574a00068c2714c913c7c66c3b..677d0c74d0a2b68f5756074392f2aec83892e54e 100644 (file)
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
     from sqlalchemy.engine.base import Connection
     from sqlalchemy.engine.base import Transaction
     from sqlalchemy.engine.mock import MockConnection
+    from sqlalchemy.sql.elements import ClauseElement
 
     from .environment import EnvironmentContext
     from ..config import Config
@@ -539,6 +540,7 @@ class MigrationContext:
 
     def _ensure_version_table(self, purge: bool = False) -> None:
         with sqla_compat._ensure_scope_for_ddl(self.connection):
+            assert self.connection is not None
             self._version.create(self.connection, checkfirst=True)
             if purge:
                 assert self.connection is not None
@@ -568,7 +570,7 @@ class MigrationContext:
         for step in script_directory._stamp_revs(revision, heads):
             head_maintainer.update_to_step(step)
 
-    def run_migrations(self, **kw) -> None:
+    def run_migrations(self, **kw: Any) -> None:
         r"""Run the migration scripts established for this
         :class:`.MigrationContext`, if any.
 
@@ -614,6 +616,7 @@ class MigrationContext:
                 if self.as_sql and not head_maintainer.heads:
                     # for offline mode, include a CREATE TABLE from
                     # the base
+                    assert self.connection is not None
                     self._version.create(self.connection)
                 log.info("Running %s", step)
                 if self.as_sql:
@@ -637,6 +640,7 @@ class MigrationContext:
                     )
 
         if self.as_sql and not head_maintainer.heads:
+            assert self.connection is not None
             self._version.drop(self.connection)
 
     def _in_connection_transaction(self) -> bool:
@@ -647,7 +651,11 @@ class MigrationContext:
         else:
             return meth()
 
-    def execute(self, sql: str, execution_options: None = None) -> None:
+    def execute(
+        self,
+        sql: Union["ClauseElement", str],
+        execution_options: Optional[dict] = None,
+    ) -> None:
         """Execute a SQL construct or string statement.
 
         The underlying execution mechanics are used, that is
@@ -771,9 +779,11 @@ class HeadMaintainer:
                 == literal_column("'%s'" % version)
             )
         )
+
         if (
             not self.context.as_sql
             and self.context.dialect.supports_sane_rowcount
+            and ret is not None
             and ret.rowcount != 1
         ):
             raise util.CommandError(
@@ -796,9 +806,11 @@ class HeadMaintainer:
                 == literal_column("'%s'" % from_)
             )
         )
+
         if (
             not self.context.as_sql
             and self.context.dialect.supports_sane_rowcount
+            and ret is not None
             and ret.rowcount != 1
         ):
             raise util.CommandError(
@@ -1269,7 +1281,7 @@ class StampStep(MigrationStep):
 
     doc: None = None
 
-    def stamp_revision(self, **kw) -> None:
+    def stamp_revision(self, **kw: Any) -> None:
         return None
 
     def __eq__(self, other):
index 179e5482394859a355cdd7413385a4d5a908d7a2..7e98bb5a645827156564b1dc105539f56b6d8a7c 100644 (file)
@@ -241,7 +241,7 @@ def _table_for_constraint(constraint: "Constraint") -> "Table":
     if isinstance(constraint, ForeignKeyConstraint):
         table = constraint.parent
         assert table is not None
-        return table
+        return table  # type: ignore[return-value]
     else:
         return constraint.table
 
@@ -261,7 +261,9 @@ def _reflect_table(
     if sqla_14:
         return inspector.reflect_table(table, None)
     else:
-        return inspector.reflecttable(table, None)
+        return inspector.reflecttable(  # type: ignore[attr-defined]
+            table, None
+        )
 
 
 def _resolve_for_variant(type_, dialect):
@@ -391,7 +393,9 @@ def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
         else:
             return None
 
-    return visitors.replacement_traverse(expression, {}, replace)
+    return visitors.replacement_traverse(  # type: ignore[call-overload]
+        expression, {}, replace
+    )
 
 
 class _textual_index_element(sql.ColumnElement):
@@ -487,7 +491,7 @@ def _get_constraint_final_name(
 
         if isinstance(constraint, schema.Index):
             # name should not be quoted.
-            d = dialect.ddl_compiler(dialect, None)
+            d = dialect.ddl_compiler(dialect, None)  # type: ignore[arg-type]
             return d._prepared_index_name(  # type: ignore[attr-defined]
                 constraint
             )
@@ -529,7 +533,7 @@ def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
     if sqla_14:
         return table.insert().inline()
     else:
-        return table.insert(inline=True)
+        return table.insert(inline=True)  # type: ignore[call-arg]
 
 
 if sqla_14:
@@ -543,5 +547,5 @@ else:
             "postgresql://", strategy="mock", executor=executor
         )
 
-    def _select(*columns, **kw) -> "Select":
-        return sql.select(list(columns), **kw)
+    def _select(*columns, **kw) -> "Select":  # type: ignore[no-redef]
+        return sql.select(list(columns), **kw)  # type: ignore[call-overload]
index 2a8de061dae89e94732aff580278d9b26c3d382e..f66269af6beb2b3b98965f24e8977103dc4fd935 100644 (file)
@@ -15,6 +15,17 @@ exclude = [
 ]
 show_error_codes = true
 
+[[tool.mypy.overrides]]
+module = [
+    'alembic.operations.ops',
+    'alembic.op',
+    'alembic.context',
+    'alembic.autogenerate.api',
+    'alembic.runtime.*',
+]
+
+disallow_incomplete_defs = true
+
 [[tool.mypy.overrides]]
 module = [
     'mako.*',
index cf42d1b1f9ceba92e4bddd5a60f792a16bdf34c2..52fac3c1282efdce72e11f4c1d1534a162081836 100644 (file)
@@ -29,6 +29,7 @@ IGNORE_ITEMS = {
 }
 TRIM_MODULE = [
     "alembic.runtime.migration.",
+    "alembic.operations.base.",
     "alembic.operations.ops.",
     "sqlalchemy.engine.base.",
     "sqlalchemy.sql.schema.",
@@ -85,6 +86,8 @@ def generate_pyi_for_proxy(
 
         module = sys.modules[cls.__module__]
         env = {
+            **typing.__dict__,
+            **sa.sql.schema.__dict__,
             **sa.__dict__,
             **sa.types.__dict__,
             **ops.__dict__,
@@ -141,7 +144,7 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
         annotations = typing.get_type_hints(fn, env)
         spec.annotations.update(annotations)
     except NameError as e:
-        pass
+        print(f"{cls.__name__}.{name} NameError: {e}", file=sys.stderr)
 
     name_args = spec[0]
     assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"]