]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
finish strict typing for most modules
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Dec 2023 16:05:03 +0000 (11:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Dec 2023 17:13:02 +0000 (12:13 -0500)
Updated pep-484 typing to pass mypy "strict" mode, however including
per-module qualifications for specific typing elements not yet complete.
This allows us to catch specific typing issues that have been ongoing
such as import symbols not properly exported.

Fixes: #1377
Change-Id: I69db4d23460f02161ac771d5d591b2bc802b8ab1

38 files changed:
alembic/autogenerate/__init__.py
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/autogenerate/rewriter.py
alembic/command.py
alembic/config.py
alembic/context.pyi
alembic/ddl/__init__.py
alembic/ddl/_autogen.py
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/mysql.py
alembic/ddl/oracle.py
alembic/ddl/postgresql.py
alembic/ddl/sqlite.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/operations/toimpl.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/script/base.py
alembic/script/revision.py
alembic/script/write_hooks.py
alembic/util/__init__.py
alembic/util/compat.py
alembic/util/langhelpers.py
alembic/util/messaging.py
alembic/util/pyfiles.py
alembic/util/sqla_compat.py
docs/build/unreleased/1377.rst [new file with mode: 0644]
pyproject.toml
setup.cfg
tools/write_pyi.py

index cd2ed1c15e1afc37b335d8e1f262463d600053e6..445ddb25125aa63994052dd4ecea1362dc91656d 100644 (file)
@@ -1,10 +1,10 @@
-from .api import _render_migration_diffs
-from .api import compare_metadata
-from .api import produce_migrations
-from .api import render_python_code
-from .api import RevisionContext
-from .compare import _produce_net_changes
-from .compare import comparators
-from .render import render_op_text
-from .render import renderers
-from .rewriter import Rewriter
+from .api import _render_migration_diffs as _render_migration_diffs
+from .api import compare_metadata as compare_metadata
+from .api import produce_migrations as produce_migrations
+from .api import render_python_code as render_python_code
+from .api import RevisionContext as RevisionContext
+from .compare import _produce_net_changes as _produce_net_changes
+from .compare import comparators as comparators
+from .render import render_op_text as render_op_text
+from .render import renderers as renderers
+from .rewriter import Rewriter as Rewriter
index b7f43b1936886edc09d9a88f55b3de86cb6ffdf7..aa8f32f65359c9c04f41ea24e21131beee2d8d2a 100644 (file)
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
     from sqlalchemy.engine import Inspector
     from sqlalchemy.sql.schema import MetaData
     from sqlalchemy.sql.schema import SchemaItem
+    from sqlalchemy.sql.schema import Table
 
     from ..config import Config
     from ..operations.ops import DowngradeOps
@@ -165,6 +166,7 @@ def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
     """
 
     migration_script = produce_migrations(context, metadata)
+    assert migration_script.upgrade_ops is not None
     return migration_script.upgrade_ops.as_diffs()
 
 
@@ -331,7 +333,7 @@ class AutogenContext:
         self,
         migration_context: MigrationContext,
         metadata: Optional[MetaData] = None,
-        opts: Optional[dict] = None,
+        opts: Optional[Dict[str, Any]] = None,
         autogenerate: bool = True,
     ) -> None:
         if (
@@ -465,7 +467,7 @@ class AutogenContext:
     run_filters = run_object_filters
 
     @util.memoized_property
-    def sorted_tables(self):
+    def sorted_tables(self) -> List[Table]:
         """Return an aggregate of the :attr:`.MetaData.sorted_tables`
         collection(s).
 
@@ -481,7 +483,7 @@ class AutogenContext:
         return result
 
     @util.memoized_property
-    def table_key_to_table(self):
+    def table_key_to_table(self) -> Dict[str, Table]:
         """Return an aggregate  of the :attr:`.MetaData.tables` dictionaries.
 
         The :attr:`.MetaData.tables` collection is a dictionary of table key
@@ -492,7 +494,7 @@ class AutogenContext:
         objects contain the same table key, an exception is raised.
 
         """
-        result = {}
+        result: Dict[str, Table] = {}
         for m in util.to_list(self.metadata):
             intersect = set(result).intersection(set(m.tables))
             if intersect:
index a50d8b8186b7c0c96cd1b81596ae944f7ddde45b..fcef531a544eefdc3baacff66aa5fa70229f9b58 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import contextlib
@@ -577,9 +580,7 @@ def _compare_indexes_and_uniques(
     # 5. index things by name, for those objects that have names
     metadata_names = {
         cast(str, c.md_name_to_sql_name(autogen_context)): c
-        for c in metadata_unique_constraints_sig.union(
-            metadata_indexes_sig  # type:ignore[arg-type]
-        )
+        for c in metadata_unique_constraints_sig.union(metadata_indexes_sig)
         if c.is_named
     }
 
@@ -1240,7 +1241,7 @@ def _compare_foreign_keys(
             obj.const, obj.name, "foreign_key_constraint", False, compare_to
         ):
             modify_table_ops.ops.append(
-                ops.CreateForeignKeyOp.from_constraint(const.const)
+                ops.CreateForeignKeyOp.from_constraint(const.const)  # type: ignore[has-type]  # noqa: E501
             )
 
             log.info(
index 67cc8c33dda0fda1df3576e5810c3d767ee1dd2e..317a6dbed9cf6eb6514d67a82ee3ee853c22254b 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 from io import StringIO
@@ -849,7 +852,7 @@ def _render_Variant_type(
 ) -> str:
     base_type, variant_mapping = sqla_compat._get_variant_mapping(type_)
     base = _repr_type(base_type, autogen_context, _skip_variants=True)
-    assert base is not None and base is not False
+    assert base is not None and base is not False  # type: ignore[comparison-overlap]  # noqa:E501
     for dialect in sorted(variant_mapping):
         typ = variant_mapping[dialect]
         base += ".with_variant(%s, %r)" % (
@@ -946,7 +949,7 @@ def _fk_colspec(
     won't fail if the remote table can't be resolved.
 
     """
-    colspec = fk._get_colspec()  # type:ignore[attr-defined]
+    colspec = fk._get_colspec()
     tokens = colspec.split(".")
     tname, colname = tokens[-2:]
 
@@ -1016,8 +1019,7 @@ def _render_foreign_key(
         % {
             "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
             "cols": ", ".join(
-                "%r" % _ident(cast("Column", f.parent).name)
-                for f in constraint.elements
+                repr(_ident(f.parent.name)) for f in constraint.elements
             ),
             "refcols": ", ".join(
                 repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
@@ -1058,12 +1060,10 @@ def _render_check_constraint(
     # ideally SQLAlchemy would give us more of a first class
     # way to detect this.
     if (
-        constraint._create_rule  # type:ignore[attr-defined]
-        and hasattr(
-            constraint._create_rule, "target"  # type:ignore[attr-defined]
-        )
+        constraint._create_rule
+        and hasattr(constraint._create_rule, "target")
         and isinstance(
-            constraint._create_rule.target,  # type:ignore[attr-defined]
+            constraint._create_rule.target,
             sqltypes.TypeEngine,
         )
     ):
index 68a93dd0ab8b35ab24e4bfee23056ad36bad6a70..02ff431c2a64041d8127be4517122529f0714b08 100644 (file)
@@ -16,12 +16,14 @@ if TYPE_CHECKING:
     from ..operations.ops import AddColumnOp
     from ..operations.ops import AlterColumnOp
     from ..operations.ops import CreateTableOp
+    from ..operations.ops import DowngradeOps
     from ..operations.ops import MigrateOperation
     from ..operations.ops import MigrationScript
     from ..operations.ops import ModifyTableOps
     from ..operations.ops import OpContainer
-    from ..runtime.environment import _GetRevArg
+    from ..operations.ops import UpgradeOps
     from ..runtime.migration import MigrationContext
+    from ..script.revision import _GetRevArg
 
 
 class Rewriter:
@@ -101,7 +103,7 @@ class Rewriter:
             Type[CreateTableOp],
             Type[ModifyTableOps],
         ],
-    ) -> Callable:
+    ) -> Callable[..., Any]:
         """Register a function as rewriter for a given type.
 
         The function should receive three arguments, which are
@@ -156,7 +158,7 @@ class Rewriter:
         revision: _GetRevArg,
         directive: MigrationScript,
     ) -> None:
-        upgrade_ops_list = []
+        upgrade_ops_list: List[UpgradeOps] = []
         for upgrade_ops in directive.upgrade_ops_list:
             ret = self._traverse_for(context, revision, upgrade_ops)
             if len(ret) != 1:
@@ -164,9 +166,10 @@ class Rewriter:
                     "Can only return single object for UpgradeOps traverse"
                 )
             upgrade_ops_list.append(ret[0])
-        directive.upgrade_ops = upgrade_ops_list
 
-        downgrade_ops_list = []
+        directive.upgrade_ops = upgrade_ops_list  # type: ignore
+
+        downgrade_ops_list: List[DowngradeOps] = []
         for downgrade_ops in directive.downgrade_ops_list:
             ret = self._traverse_for(context, revision, downgrade_ops)
             if len(ret) != 1:
@@ -174,7 +177,7 @@ class Rewriter:
                     "Can only return single object for DowngradeOps traverse"
                 )
             downgrade_ops_list.append(ret[0])
-        directive.downgrade_ops = downgrade_ops_list
+        directive.downgrade_ops = downgrade_ops_list  # type: ignore
 
     @_traverse.dispatch_for(ops.OpContainer)
     def _traverse_op_container(
index c5233e72efd27b12f20934d82f649f24923c1fe3..37aa6e67ebf1c82416c96fd40f1dee0ae8c82dcb 100644 (file)
@@ -1,3 +1,5 @@
+# mypy: allow-untyped-defs, allow-untyped-calls
+
 from __future__ import annotations
 
 import os
@@ -18,7 +20,7 @@ if TYPE_CHECKING:
     from .runtime.environment import ProcessRevisionDirectiveFn
 
 
-def list_templates(config: Config):
+def list_templates(config: Config) -> None:
     """List available templates.
 
     :param config: a :class:`.Config` object.
index 55b5811a2c179933047774608340bcd65465720f..4b2263fddacf0cd85e9969221ef9b3d8105cd8f4 100644 (file)
@@ -12,6 +12,7 @@ from typing import Dict
 from typing import Mapping
 from typing import Optional
 from typing import overload
+from typing import Sequence
 from typing import TextIO
 from typing import Union
 
@@ -104,7 +105,7 @@ class Config:
         stdout: TextIO = sys.stdout,
         cmd_opts: Optional[Namespace] = None,
         config_args: Mapping[str, Any] = util.immutabledict(),
-        attributes: Optional[dict] = None,
+        attributes: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Construct a new :class:`.Config`"""
         self.config_file_name = file_
@@ -140,7 +141,7 @@ class Config:
     """
 
     @util.memoized_property
-    def attributes(self):
+    def attributes(self) -> Dict[str, Any]:
         """A Python dictionary for storage of additional state.
 
 
@@ -159,7 +160,7 @@ class Config:
         """
         return {}
 
-    def print_stdout(self, text: str, *arg) -> None:
+    def print_stdout(self, text: str, *arg: Any) -> None:
         """Render a message to standard out.
 
         When :meth:`.Config.print_stdout` is called with additional args
@@ -183,7 +184,7 @@ class Config:
         util.write_outstream(self.stdout, output, "\n", **self.messaging_opts)
 
     @util.memoized_property
-    def file_config(self):
+    def file_config(self) -> ConfigParser:
         """Return the underlying ``ConfigParser`` object.
 
         Direct access to the .ini file is available here,
@@ -321,7 +322,9 @@ class Config:
     ) -> Optional[str]:
         ...
 
-    def get_main_option(self, name, default=None):
+    def get_main_option(
+        self, name: str, default: Optional[str] = None
+    ) -> Optional[str]:
         """Return an option from the 'main' section of the .ini file.
 
         This defaults to being a key from the ``[alembic]``
@@ -351,7 +354,9 @@ class CommandLine:
         self._generate_args(prog)
 
     def _generate_args(self, prog: Optional[str]) -> None:
-        def add_options(fn, parser, positional, kwargs):
+        def add_options(
+            fn: Any, parser: Any, positional: Any, kwargs: Any
+        ) -> None:
             kwargs_opts = {
                 "template": (
                     "-t",
@@ -554,7 +559,9 @@ class CommandLine:
         )
         subparsers = parser.add_subparsers()
 
-        positional_translations = {command.stamp: {"revision": "revisions"}}
+        positional_translations: Dict[Any, Any] = {
+            command.stamp: {"revision": "revisions"}
+        }
 
         for fn in [getattr(command, n) for n in dir(command)]:
             if (
@@ -609,7 +616,7 @@ class CommandLine:
             else:
                 util.err(str(e), **config.messaging_opts)
 
-    def main(self, argv=None):
+    def main(self, argv: Optional[Sequence[str]] = None) -> None:
         options = self.parser.parse_args(argv)
         if not hasattr(options, "cmd"):
             # see http://bugs.python.org/issue9253, argparse
@@ -624,7 +631,11 @@ class CommandLine:
             self.run_cmd(cfg, options)
 
 
-def main(argv=None, prog=None, **kwargs):
+def main(
+    argv: Optional[Sequence[str]] = None,
+    prog: Optional[str] = None,
+    **kwargs: Any,
+) -> None:
     """The console runner function for Alembic."""
 
     CommandLine(prog=prog).main(argv=argv)
index e8d98210d08e17cacbef66bc7ce3352130bfa229..80619fb24f13fadcbaa7fcd4a907b19f231b12e3 100644 (file)
@@ -160,8 +160,8 @@ def configure(
                 MigrationContext,
                 Column[Any],
                 Column[Any],
-                TypeEngine,
-                TypeEngine,
+                TypeEngine[Any],
+                TypeEngine[Any],
             ],
             Optional[bool],
         ],
@@ -636,7 +636,8 @@ def configure(
     """
 
 def execute(
-    sql: Union[Executable, str], execution_options: Optional[dict] = None
+    sql: Union[Executable, str],
+    execution_options: Optional[Dict[str, Any]] = None,
 ) -> None:
     """Execute the given SQL using the current change context.
 
@@ -805,7 +806,7 @@ def is_offline_mode() -> bool:
 
     """
 
-def is_transactional_ddl():
+def is_transactional_ddl() -> bool:
     """Return True if the context is configured to expect a
     transactional DDL capable backend.
 
index cfcc47e029515838acc6b972950f8bfd8454c770..f2f72b3dd8d3748b36cb7acfcda7abf8468b6926 100644 (file)
@@ -3,4 +3,4 @@ from . import mysql
 from . import oracle
 from . import postgresql
 from . import sqlite
-from .impl import DefaultImpl
+from .impl import DefaultImpl as DefaultImpl
index cc1a1fc4293b21bf1faa3317c00b53410126be03..e22153c49c761451c074c11de6c7ea53d20c1149 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 from typing import Any
@@ -19,7 +22,6 @@ from sqlalchemy.sql.schema import Index
 from sqlalchemy.sql.schema import UniqueConstraint
 from typing_extensions import TypeGuard
 
-from alembic.ddl.base import _fk_spec
 from .. import util
 from ..util import sqla_compat
 
@@ -275,7 +277,7 @@ class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
             ondelete,
             deferrable,
             initially,
-        ) = _fk_spec(const)
+        ) = sqla_compat._fk_spec(const)
 
         self._sig: Tuple[Any, ...] = (
             self.source_schema,
index 339db0c4a5d9e78e7e7b608895fb30de8653b22a..7a85a5c198affa8f50fcfe4da126836627ae472c 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import functools
@@ -173,7 +176,7 @@ class ColumnComment(AlterColumn):
         self.comment = comment
 
 
-@compiles(RenameTable)
+@compiles(RenameTable)  # type: ignore[misc]
 def visit_rename_table(
     element: RenameTable, compiler: DDLCompiler, **kw
 ) -> str:
@@ -183,7 +186,7 @@ def visit_rename_table(
     )
 
 
-@compiles(AddColumn)
+@compiles(AddColumn)  # type: ignore[misc]
 def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -191,7 +194,7 @@ def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
     )
 
 
-@compiles(DropColumn)
+@compiles(DropColumn)  # type: ignore[misc]
 def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -199,7 +202,7 @@ def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
     )
 
 
-@compiles(ColumnNullable)
+@compiles(ColumnNullable)  # type: ignore[misc]
 def visit_column_nullable(
     element: ColumnNullable, compiler: DDLCompiler, **kw
 ) -> str:
@@ -210,7 +213,7 @@ def visit_column_nullable(
     )
 
 
-@compiles(ColumnType)
+@compiles(ColumnType)  # type: ignore[misc]
 def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -219,7 +222,7 @@ def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
     )
 
 
-@compiles(ColumnName)
+@compiles(ColumnName)  # type: ignore[misc]
 def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
     return "%s RENAME %s TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -228,7 +231,7 @@ def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
     )
 
 
-@compiles(ColumnDefault)
+@compiles(ColumnDefault)  # type: ignore[misc]
 def visit_column_default(
     element: ColumnDefault, compiler: DDLCompiler, **kw
 ) -> str:
@@ -241,7 +244,7 @@ def visit_column_default(
     )
 
 
-@compiles(ComputedColumnDefault)
+@compiles(ComputedColumnDefault)  # type: ignore[misc]
 def visit_computed_column(
     element: ComputedColumnDefault, compiler: DDLCompiler, **kw
 ):
@@ -251,7 +254,7 @@ def visit_computed_column(
     )
 
 
-@compiles(IdentityColumnDefault)
+@compiles(IdentityColumnDefault)  # type: ignore[misc]
 def visit_identity_column(
     element: IdentityColumnDefault, compiler: DDLCompiler, **kw
 ):
index 571a3041cc66ba52f4fe20d87402fb2789547a37..2e4f1ae9405eac6c755f3f4f3957efa717ecd8da 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import logging
@@ -23,8 +26,8 @@ from sqlalchemy import text
 
 from . import _autogen
 from . import base
-from ._autogen import _constraint_sig
-from ._autogen import ComparisonResult
+from ._autogen import _constraint_sig as _constraint_sig
+from ._autogen import ComparisonResult as ComparisonResult
 from .. import util
 from ..util import sqla_compat
 
index 9b0fff885fa19ab65d57d5c27d01deec83b0a6d0..baa43d5e73abb3e40294c0000d0f2694182744eb 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import re
@@ -9,7 +12,6 @@ from typing import TYPE_CHECKING
 from typing import Union
 
 from sqlalchemy import types as sqltypes
-from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.schema import Column
 from sqlalchemy.schema import CreateIndex
 from sqlalchemy.sql.base import Executable
@@ -30,6 +32,7 @@ from .base import RenameTable
 from .impl import DefaultImpl
 from .. import util
 from ..util import sqla_compat
+from ..util.sqla_compat import compiles
 
 if TYPE_CHECKING:
     from typing import Literal
index 5a2af5ce7b773b12295a02afad6d8614137d3a77..f312173e946d117b276e06ed5aa290f18f7db61b 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import re
@@ -8,7 +11,6 @@ from typing import Union
 
 from sqlalchemy import schema
 from sqlalchemy import types as sqltypes
-from sqlalchemy.ext.compiler import compiles
 
 from .base import alter_table
 from .base import AlterColumn
@@ -23,6 +25,7 @@ from .. import util
 from ..util import sqla_compat
 from ..util.sqla_compat import _is_mariadb
 from ..util.sqla_compat import _is_type_bound
+from ..util.sqla_compat import compiles
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -160,8 +163,7 @@ class MySQLImpl(DefaultImpl):
     ) -> bool:
         return (
             type_ is not None
-            and type_._type_affinity  # type:ignore[attr-defined]
-            is sqltypes.DateTime
+            and type_._type_affinity is sqltypes.DateTime
             and server_default is not None
         )
 
index e56bb2102f45d0e807ec7ad908fd86f359efc2ea..54011740723749b50f53beaac6c75ca020e365a3 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import re
@@ -5,7 +8,6 @@ from typing import Any
 from typing import Optional
 from typing import TYPE_CHECKING
 
-from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import sqltypes
 
 from .base import AddColumn
@@ -22,6 +24,7 @@ from .base import format_type
 from .base import IdentityColumnDefault
 from .base import RenameTable
 from .impl import DefaultImpl
+from ..util.sqla_compat import compiles
 
 if TYPE_CHECKING:
     from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
index 68628c8ecfda7f41f967d83deba054e60fbe1cef..6507fcbdd75c82873e3fc4ff5e3030d8638c0474 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import logging
@@ -30,7 +33,6 @@ from .base import alter_column
 from .base import alter_table
 from .base import AlterColumn
 from .base import ColumnComment
-from .base import compiles
 from .base import format_column_name
 from .base import format_table_name
 from .base import format_type
@@ -45,6 +47,7 @@ from ..operations import schemaobj
 from ..operations.base import BatchOperations
 from ..operations.base import Operations
 from ..util import sqla_compat
+from ..util.sqla_compat import compiles
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -136,7 +139,9 @@ class PostgresqlImpl(DefaultImpl):
             metadata_default = literal_column(metadata_default)
 
         # run a real compare against the server
-        return not self.connection.scalar(
+        conn = self.connection
+        assert conn is not None
+        return not conn.scalar(
             sqla_compat._select(
                 literal_column(conn_col_default) == metadata_default
             )
@@ -623,9 +628,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
         return cls(
             constraint.name,
             constraint_table.name,
-            [
-                (expr, op)
-                for expr, name, op in constraint._render_exprs  # type:ignore[attr-defined] # noqa
+            [  # type: ignore
+                (expr, op) for expr, name, op in constraint._render_exprs
             ],
             where=cast("ColumnElement[bool] | None", constraint.where),
             schema=constraint_table.schema,
@@ -652,7 +656,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
             expr,
             name,
             oper,
-        ) in excl._render_exprs:  # type:ignore[attr-defined]
+        ) in excl._render_exprs:
             t.append_column(Column(name, NULLTYPE))
         t.append_constraint(excl)
         return excl
@@ -710,7 +714,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
         constraint_name: str,
         *elements: Any,
         **kw: Any,
-    ):
+    ) -> Optional[Table]:
         """Issue a "create exclude constraint" instruction using the
         current batch migration context.
 
@@ -782,10 +786,13 @@ def _exclude_constraint(
         args = [
             "(%s, %r)"
             % (
-                _render_potential_column(sqltext, autogen_context),
+                _render_potential_column(
+                    sqltext,  # type:ignore[arg-type]
+                    autogen_context,
+                ),
                 opstring,
             )
-            for sqltext, name, opstring in constraint._render_exprs  # type:ignore[attr-defined] # noqa
+            for sqltext, name, opstring in constraint._render_exprs
         ]
         if constraint.where is not None:
             args.append(
index c6186c60a91892dedeb1eeecf3c3d0337ba9c9a8..762e8ca198a6af4d001afd362ff15ac9c43a2821 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import re
@@ -11,13 +14,13 @@ from sqlalchemy import cast
 from sqlalchemy import JSON
 from sqlalchemy import schema
 from sqlalchemy import sql
-from sqlalchemy.ext.compiler import compiles
 
 from .base import alter_table
 from .base import format_table_name
 from .base import RenameTable
 from .impl import DefaultImpl
 from .. import util
+from ..util.sqla_compat import compiles
 
 if TYPE_CHECKING:
     from sqlalchemy.engine.reflection import Inspector
@@ -71,13 +74,13 @@ class SQLiteImpl(DefaultImpl):
     def add_constraint(self, const: Constraint):
         # attempt to distinguish between an
         # auto-gen constraint and an explicit one
-        if const._create_rule is None:  # type:ignore[attr-defined]
+        if const._create_rule is None:
             raise NotImplementedError(
                 "No support for ALTER of constraints in SQLite dialect. "
                 "Please refer to the batch mode feature which allows for "
                 "SQLite migrations using a copy-and-move strategy."
             )
-        elif const._create_rule(self):  # type:ignore[attr-defined]
+        elif const._create_rule(self):
             util.warn(
                 "Skipping unsupported ALTER for "
                 "creation of implicit constraint. "
@@ -86,7 +89,7 @@ class SQLiteImpl(DefaultImpl):
             )
 
     def drop_constraint(self, const: Constraint):
-        if const._create_rule is None:  # type:ignore[attr-defined]
+        if const._create_rule is None:
             raise NotImplementedError(
                 "No support for ALTER of constraints in SQLite dialect. "
                 "Please refer to the batch mode feature which allows for "
@@ -177,8 +180,7 @@ class SQLiteImpl(DefaultImpl):
         new_type: TypeEngine,
     ) -> None:
         if (
-            existing.type._type_affinity  # type:ignore[attr-defined]
-            is not new_type._type_affinity  # type:ignore[attr-defined]
+            existing.type._type_affinity is not new_type._type_affinity
             and not isinstance(new_type, JSON)
         ):
             existing_transfer["expr"] = cast(
index 944b5ae16a64be6670f32b811e897027de2210e8..83deac1eb0154050362c9411291bcfe8d64e97c3 100644 (file)
@@ -12,6 +12,7 @@ from typing import List
 from typing import Literal
 from typing import Mapping
 from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Tuple
 from typing import Type
@@ -35,12 +36,28 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.type_api import TypeEngine
     from sqlalchemy.util import immutabledict
 
-    from .operations.ops import BatchOperations
+    from .operations.base import BatchOperations
+    from .operations.ops import AddColumnOp
+    from .operations.ops import AddConstraintOp
+    from .operations.ops import AlterColumnOp
+    from .operations.ops import AlterTableOp
+    from .operations.ops import BulkInsertOp
+    from .operations.ops import CreateIndexOp
+    from .operations.ops import CreateTableCommentOp
+    from .operations.ops import CreateTableOp
+    from .operations.ops import DropColumnOp
+    from .operations.ops import DropConstraintOp
+    from .operations.ops import DropIndexOp
+    from .operations.ops import DropTableCommentOp
+    from .operations.ops import DropTableOp
+    from .operations.ops import ExecuteSQLOp
     from .operations.ops import MigrateOperation
     from .runtime.migration import MigrationContext
     from .util.sqla_compat import _literal_bindparam
 
 _T = TypeVar("_T")
+_C = TypeVar("_C", bound=Callable[..., Any])
+
 ### end imports ###
 
 def add_column(
@@ -132,8 +149,8 @@ def alter_column(
     comment: Union[str, Literal[False], None] = False,
     server_default: Any = False,
     new_column_name: Optional[str] = None,
-    type_: Union[TypeEngine, Type[TypeEngine], None] = None,
-    existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
+    type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
+    existing_type: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
     existing_server_default: Union[
         str, bool, Identity, Computed, None
     ] = False,
@@ -230,7 +247,7 @@ def batch_alter_table(
     table_name: str,
     schema: Optional[str] = None,
     recreate: Literal["auto", "always", "never"] = "auto",
-    partial_reordering: Optional[tuple] = None,
+    partial_reordering: Optional[Tuple[Any, ...]] = None,
     copy_from: Optional[Table] = None,
     table_args: Tuple[Any, ...] = (),
     table_kwargs: Mapping[str, Any] = immutabledict({}),
@@ -377,7 +394,7 @@ def batch_alter_table(
 
 def bulk_insert(
     table: Union[Table, TableClause],
-    rows: List[dict],
+    rows: List[Dict[str, Any]],
     *,
     multiinsert: bool = True,
 ) -> None:
@@ -1162,7 +1179,7 @@ def get_context() -> MigrationContext:
 
     """
 
-def implementation_for(op_cls: Any) -> Callable[..., Any]:
+def implementation_for(op_cls: Any) -> Callable[[_C], _C]:
     """Register an implementation for a given :class:`.MigrateOperation`.
 
     This is part of the operation extensibility API.
@@ -1174,7 +1191,7 @@ def implementation_for(op_cls: Any) -> Callable[..., Any]:
     """
 
 def inline_literal(
-    value: Union[str, int], type_: Optional[TypeEngine] = None
+    value: Union[str, int], type_: Optional[TypeEngine[Any]] = None
 ) -> _literal_bindparam:
     r"""Produce an 'inline literal' expression, suitable for
     using in an INSERT, UPDATE, or DELETE statement.
@@ -1218,6 +1235,27 @@ def inline_literal(
 
     """
 
+@overload
+def invoke(operation: CreateTableOp) -> Table: ...
+@overload
+def invoke(
+    operation: Union[
+        AddConstraintOp,
+        DropConstraintOp,
+        CreateIndexOp,
+        DropIndexOp,
+        AddColumnOp,
+        AlterColumnOp,
+        AlterTableOp,
+        CreateTableCommentOp,
+        DropTableCommentOp,
+        DropColumnOp,
+        BulkInsertOp,
+        DropTableOp,
+        ExecuteSQLOp,
+    ]
+) -> None: ...
+@overload
 def invoke(operation: MigrateOperation) -> Any:
     """Given a :class:`.MigrateOperation`, invoke it in terms of
     this :class:`.Operations` instance.
@@ -1226,7 +1264,7 @@ def invoke(operation: MigrateOperation) -> Any:
 
 def register_operation(
     name: str, sourcename: Optional[str] = None
-) -> Callable[[_T], _T]:
+) -> Callable[[Type[_T]], Type[_T]]:
     """Register a new operation for this class.
 
     This method is normally used to add new operations
index e3207be765f0fc6c9fdc5949e79bd7c10cb8d6f0..bafe441a69ceb2bcd13f5f3f3fad1382b589e99f 100644 (file)
@@ -1,3 +1,5 @@
+# mypy: allow-untyped-calls
+
 from __future__ import annotations
 
 from contextlib import contextmanager
@@ -10,7 +12,9 @@ from typing import Dict
 from typing import Iterator
 from typing import List  # noqa
 from typing import Mapping
+from typing import NoReturn
 from typing import Optional
+from typing import overload
 from typing import Sequence  # noqa
 from typing import Tuple
 from typing import Type  # noqa
@@ -47,12 +51,28 @@ if TYPE_CHECKING:
     from sqlalchemy.types import TypeEngine
 
     from .batch import BatchOperationsImpl
+    from .ops import AddColumnOp
+    from .ops import AddConstraintOp
+    from .ops import AlterColumnOp
+    from .ops import AlterTableOp
+    from .ops import BulkInsertOp
+    from .ops import CreateIndexOp
+    from .ops import CreateTableCommentOp
+    from .ops import CreateTableOp
+    from .ops import DropColumnOp
+    from .ops import DropConstraintOp
+    from .ops import DropIndexOp
+    from .ops import DropTableCommentOp
+    from .ops import DropTableOp
+    from .ops import ExecuteSQLOp
     from .ops import MigrateOperation
     from ..ddl import DefaultImpl
     from ..runtime.migration import MigrationContext
 __all__ = ("Operations", "BatchOperations")
 _T = TypeVar("_T")
 
+_C = TypeVar("_C", bound=Callable[..., Any])
+
 
 class AbstractOperations(util.ModuleClsProxy):
     """Base class for Operations and BatchOperations.
@@ -86,7 +106,7 @@ class AbstractOperations(util.ModuleClsProxy):
     @classmethod
     def register_operation(
         cls, name: str, sourcename: Optional[str] = None
-    ) -> Callable[[_T], _T]:
+    ) -> Callable[[Type[_T]], Type[_T]]:
         """Register a new operation for this class.
 
         This method is normally used to add new operations
@@ -103,7 +123,7 @@ class AbstractOperations(util.ModuleClsProxy):
 
         """
 
-        def register(op_cls):
+        def register(op_cls: Type[_T]) -> Type[_T]:
             if sourcename is None:
                 fn = getattr(op_cls, name)
                 source_name = fn.__name__
@@ -122,8 +142,11 @@ class AbstractOperations(util.ModuleClsProxy):
                 *spec, formatannotation=formatannotation_fwdref
             )
             num_defaults = len(spec[3]) if spec[3] else 0
+
+            defaulted_vals: Tuple[Any, ...]
+
             if num_defaults:
-                defaulted_vals = name_args[0 - num_defaults :]
+                defaulted_vals = tuple(name_args[0 - num_defaults :])
             else:
                 defaulted_vals = ()
 
@@ -164,7 +187,7 @@ class AbstractOperations(util.ModuleClsProxy):
 
             globals_ = dict(globals())
             globals_.update({"op_cls": op_cls})
-            lcl = {}
+            lcl: Dict[str, Any] = {}
 
             exec(func_text, globals_, lcl)
             setattr(cls, name, lcl[name])
@@ -180,7 +203,7 @@ class AbstractOperations(util.ModuleClsProxy):
         return register
 
     @classmethod
-    def implementation_for(cls, op_cls: Any) -> Callable[..., Any]:
+    def implementation_for(cls, op_cls: Any) -> Callable[[_C], _C]:
         """Register an implementation for a given :class:`.MigrateOperation`.
 
         This is part of the operation extensibility API.
@@ -191,7 +214,7 @@ class AbstractOperations(util.ModuleClsProxy):
 
         """
 
-        def decorate(fn):
+        def decorate(fn: _C) -> _C:
             cls._to_impl.dispatch_for(op_cls)(fn)
             return fn
 
@@ -213,7 +236,7 @@ class AbstractOperations(util.ModuleClsProxy):
         table_name: str,
         schema: Optional[str] = None,
         recreate: Literal["auto", "always", "never"] = "auto",
-        partial_reordering: Optional[tuple] = None,
+        partial_reordering: Optional[Tuple[Any, ...]] = None,
         copy_from: Optional[Table] = None,
         table_args: Tuple[Any, ...] = (),
         table_kwargs: Mapping[str, Any] = util.immutabledict(),
@@ -382,6 +405,35 @@ class AbstractOperations(util.ModuleClsProxy):
 
         return self.migration_context
 
+    @overload
+    def invoke(self, operation: CreateTableOp) -> Table:
+        ...
+
+    @overload
+    def invoke(
+        self,
+        operation: Union[
+            AddConstraintOp,
+            DropConstraintOp,
+            CreateIndexOp,
+            DropIndexOp,
+            AddColumnOp,
+            AlterColumnOp,
+            AlterTableOp,
+            CreateTableCommentOp,
+            DropTableCommentOp,
+            DropColumnOp,
+            BulkInsertOp,
+            DropTableOp,
+            ExecuteSQLOp,
+        ],
+    ) -> None:
+        ...
+
+    @overload
+    def invoke(self, operation: MigrateOperation) -> Any:
+        ...
+
     def invoke(self, operation: MigrateOperation) -> Any:
         """Given a :class:`.MigrateOperation`, invoke it in terms of
         this :class:`.Operations` instance.
@@ -659,8 +711,10 @@ class Operations(AbstractOperations):
             comment: Union[str, Literal[False], None] = False,
             server_default: Any = False,
             new_column_name: Optional[str] = None,
-            type_: Union[TypeEngine, Type[TypeEngine], None] = None,
-            existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
+            type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
+            existing_type: Union[
+                TypeEngine[Any], Type[TypeEngine[Any]], None
+            ] = None,
             existing_server_default: Union[
                 str, bool, Identity, Computed, None
             ] = False,
@@ -756,7 +810,7 @@ class Operations(AbstractOperations):
         def bulk_insert(
             self,
             table: Union[Table, TableClause],
-            rows: List[dict],
+            rows: List[Dict[str, Any]],
             *,
             multiinsert: bool = True,
         ) -> None:
@@ -1560,7 +1614,7 @@ class BatchOperations(AbstractOperations):
 
     impl: BatchOperationsImpl
 
-    def _noop(self, operation):
+    def _noop(self, operation: Any) -> NoReturn:
         raise NotImplementedError(
             "The %s method does not apply to a batch table alter operation."
             % operation
@@ -1596,8 +1650,10 @@ class BatchOperations(AbstractOperations):
             comment: Union[str, Literal[False], None] = False,
             server_default: Any = False,
             new_column_name: Optional[str] = None,
-            type_: Union[TypeEngine, Type[TypeEngine], None] = None,
-            existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
+            type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
+            existing_type: Union[
+                TypeEngine[Any], Type[TypeEngine[Any]], None
+            ] = None,
             existing_server_default: Union[
                 str, bool, Identity, Computed, None
             ] = False,
@@ -1652,7 +1708,7 @@ class BatchOperations(AbstractOperations):
 
         def create_exclude_constraint(
             self, constraint_name: str, *elements: Any, **kw: Any
-        ):
+        ) -> Optional[Table]:
             """Issue a "create exclude constraint" instruction using the
             current batch migration context.
 
index 8c88e885acfdb54bfb506848c5a91b2327ae3703..fd7ab990306aff0501b39ea96fde31b81d20ff94 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 from typing import Any
@@ -17,7 +20,7 @@ from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import schema as sql_schema
 from sqlalchemy import Table
 from sqlalchemy import types as sqltypes
-from sqlalchemy.events import SchemaEventTarget
+from sqlalchemy.sql.schema import SchemaEventTarget
 from sqlalchemy.util import OrderedDict
 from sqlalchemy.util import topological
 
@@ -374,7 +377,7 @@ class ApplyBatchImpl:
         for idx_existing in self.indexes.values():
             # this is a lift-and-move from Table.to_metadata
 
-            if idx_existing._column_flag:  # type: ignore
+            if idx_existing._column_flag:
                 continue
 
             idx_copy = Index(
@@ -403,9 +406,7 @@ class ApplyBatchImpl:
     def _setup_referent(
         self, metadata: MetaData, constraint: ForeignKeyConstraint
     ) -> None:
-        spec = constraint.elements[
-            0
-        ]._get_colspec()  # type:ignore[attr-defined]
+        spec = constraint.elements[0]._get_colspec()
         parts = spec.split(".")
         tname = parts[-2]
         if len(parts) == 3:
@@ -546,9 +547,7 @@ class ApplyBatchImpl:
             else:
                 sql_schema.DefaultClause(
                     server_default  # type: ignore[arg-type]
-                )._set_parent(  # type:ignore[attr-defined]
-                    existing
-                )
+                )._set_parent(existing)
         if autoincrement is not None:
             existing.autoincrement = bool(autoincrement)
 
index 07b3e5749bd2a06d14c40935d5aac152e67b6618..7b65191cf20fa5bc1be08c646247dee611f4f4fe 100644 (file)
@@ -5,6 +5,7 @@ import re
 from typing import Any
 from typing import Callable
 from typing import cast
+from typing import Dict
 from typing import FrozenSet
 from typing import Iterator
 from typing import List
@@ -15,6 +16,7 @@ from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from sqlalchemy.types import NULLTYPE
@@ -53,6 +55,9 @@ if TYPE_CHECKING:
     from ..runtime.migration import MigrationContext
     from ..script.revision import _RevIdType
 
+_T = TypeVar("_T", bound=Any)
+_AC = TypeVar("_AC", bound="AddConstraintOp")
+
 
 class MigrateOperation:
     """base class for migration command and organization objects.
@@ -70,7 +75,7 @@ class MigrateOperation:
     """
 
     @util.memoized_property
-    def info(self):
+    def info(self) -> Dict[Any, Any]:
         """A dictionary that may be used to store arbitrary information
         along with this :class:`.MigrateOperation` object.
 
@@ -92,12 +97,14 @@ class AddConstraintOp(MigrateOperation):
     add_constraint_ops = util.Dispatcher()
 
     @property
-    def constraint_type(self):
+    def constraint_type(self) -> str:
         raise NotImplementedError()
 
     @classmethod
-    def register_add_constraint(cls, type_: str) -> Callable:
-        def go(klass):
+    def register_add_constraint(
+        cls, type_: str
+    ) -> Callable[[Type[_AC]], Type[_AC]]:
+        def go(klass: Type[_AC]) -> Type[_AC]:
             cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint)
             return klass
 
@@ -105,7 +112,7 @@ class AddConstraintOp(MigrateOperation):
 
     @classmethod
     def from_constraint(cls, constraint: Constraint) -> AddConstraintOp:
-        return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
+        return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(  # type: ignore[no-any-return]  # noqa: E501
             constraint
         )
 
@@ -398,7 +405,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
 
         uq_constraint = cast("UniqueConstraint", constraint)
 
-        kw: dict = {}
+        kw: Dict[str, Any] = {}
         if uq_constraint.deferrable:
             kw["deferrable"] = uq_constraint.deferrable
         if uq_constraint.initially:
@@ -532,7 +539,7 @@ class CreateForeignKeyOp(AddConstraintOp):
     @classmethod
     def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp:
         fk_constraint = cast("ForeignKeyConstraint", constraint)
-        kw: dict = {}
+        kw: Dict[str, Any] = {}
         if fk_constraint.onupdate:
             kw["onupdate"] = fk_constraint.onupdate
         if fk_constraint.ondelete:
@@ -897,7 +904,7 @@ class CreateIndexOp(MigrateOperation):
     def from_index(cls, index: Index) -> CreateIndexOp:
         assert index.table is not None
         return cls(
-            index.name,  # type: ignore[arg-type]
+            index.name,
             index.table.name,
             index.expressions,
             schema=index.table.schema,
@@ -1183,7 +1190,7 @@ class CreateTableOp(MigrateOperation):
 
         return cls(
             table.name,
-            list(table.c) + list(table.constraints),  # type:ignore[arg-type]
+            list(table.c) + list(table.constraints),
             schema=table.schema,
             _namespace_metadata=_namespace_metadata,
             # given a Table() object, this Table will contain full Index()
@@ -1535,7 +1542,7 @@ class CreateTableCommentOp(AlterTableOp):
         )
         return operations.invoke(op)
 
-    def reverse(self):
+    def reverse(self) -> Union[CreateTableCommentOp, DropTableCommentOp]:
         """Reverses the COMMENT ON operation against a table."""
         if self.existing_comment is None:
             return DropTableCommentOp(
@@ -1551,14 +1558,16 @@ class CreateTableCommentOp(AlterTableOp):
                 schema=self.schema,
             )
 
-    def to_table(self, migration_context=None):
+    def to_table(
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> Table:
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         return schema_obj.table(
             self.table_name, schema=self.schema, comment=self.comment
         )
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Tuple[Any, ...]:
         return ("add_table_comment", self.to_table(), self.existing_comment)
 
 
@@ -1630,18 +1639,20 @@ class DropTableCommentOp(AlterTableOp):
         )
         return operations.invoke(op)
 
-    def reverse(self):
+    def reverse(self) -> CreateTableCommentOp:
         """Reverses the COMMENT ON operation against a table."""
         return CreateTableCommentOp(
             self.table_name, self.existing_comment, schema=self.schema
         )
 
-    def to_table(self, migration_context=None):
+    def to_table(
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> Table:
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         return schema_obj.table(self.table_name, schema=self.schema)
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Tuple[Any, ...]:
         return ("remove_table_comment", self.to_table())
 
 
@@ -1818,8 +1829,10 @@ class AlterColumnOp(AlterTableOp):
         comment: Optional[Union[str, Literal[False]]] = False,
         server_default: Any = False,
         new_column_name: Optional[str] = None,
-        type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
-        existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
+        type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None,
+        existing_type: Optional[
+            Union[TypeEngine[Any], Type[TypeEngine[Any]]]
+        ] = None,
         existing_server_default: Optional[
             Union[str, bool, Identity, Computed]
         ] = False,
@@ -1939,8 +1952,10 @@ class AlterColumnOp(AlterTableOp):
         comment: Optional[Union[str, Literal[False]]] = False,
         server_default: Any = False,
         new_column_name: Optional[str] = None,
-        type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
-        existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
+        type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None,
+        existing_type: Optional[
+            Union[TypeEngine[Any], Type[TypeEngine[Any]]]
+        ] = None,
         existing_server_default: Optional[
             Union[str, bool, Identity, Computed]
         ] = False,
@@ -2020,11 +2035,11 @@ class AddColumnOp(AlterTableOp):
     ) -> Tuple[str, Optional[str], str, Column[Any]]:
         return ("add_column", self.schema, self.table_name, self.column)
 
-    def to_column(self) -> Column:
+    def to_column(self) -> Column[Any]:
         return self.column
 
     @classmethod
-    def from_column(cls, col: Column) -> AddColumnOp:
+    def from_column(cls, col: Column[Any]) -> AddColumnOp:
         return cls(col.table.name, col, schema=col.table.schema)
 
     @classmethod
@@ -2215,7 +2230,7 @@ class DropColumnOp(AlterTableOp):
 
     def to_column(
         self, migration_context: Optional[MigrationContext] = None
-    ) -> Column:
+    ) -> Column[Any]:
         if self._reverse is not None:
             return self._reverse.column
         schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -2299,7 +2314,7 @@ class BulkInsertOp(MigrateOperation):
     def __init__(
         self,
         table: Union[Table, TableClause],
-        rows: List[dict],
+        rows: List[Dict[str, Any]],
         *,
         multiinsert: bool = True,
     ) -> None:
@@ -2312,7 +2327,7 @@ class BulkInsertOp(MigrateOperation):
         cls,
         operations: Operations,
         table: Union[Table, TableClause],
-        rows: List[dict],
+        rows: List[Dict[str, Any]],
         *,
         multiinsert: bool = True,
     ) -> None:
@@ -2608,7 +2623,7 @@ class UpgradeOps(OpContainer):
         self.upgrade_token = upgrade_token
 
     def reverse_into(self, downgrade_ops: DowngradeOps) -> DowngradeOps:
-        downgrade_ops.ops[:] = list(  # type:ignore[index]
+        downgrade_ops.ops[:] = list(
             reversed([op.reverse() for op in self.ops])
         )
         return downgrade_ops
@@ -2635,7 +2650,7 @@ class DowngradeOps(OpContainer):
         super().__init__(ops=ops)
         self.downgrade_token = downgrade_token
 
-    def reverse(self):
+    def reverse(self) -> UpgradeOps:
         return UpgradeOps(
             ops=list(reversed([op.reverse() for op in self.ops]))
         )
@@ -2666,6 +2681,8 @@ class MigrationScript(MigrateOperation):
     """
 
     _needs_render: Optional[bool]
+    _upgrade_ops: List[UpgradeOps]
+    _downgrade_ops: List[DowngradeOps]
 
     def __init__(
         self,
@@ -2693,7 +2710,7 @@ class MigrationScript(MigrateOperation):
         self.downgrade_ops = downgrade_ops
 
     @property
-    def upgrade_ops(self):
+    def upgrade_ops(self) -> Optional[UpgradeOps]:
         """An instance of :class:`.UpgradeOps`.
 
         .. seealso::
@@ -2712,13 +2729,15 @@ class MigrationScript(MigrateOperation):
             return self._upgrade_ops[0]
 
     @upgrade_ops.setter
-    def upgrade_ops(self, upgrade_ops):
+    def upgrade_ops(
+        self, upgrade_ops: Union[UpgradeOps, List[UpgradeOps]]
+    ) -> None:
         self._upgrade_ops = util.to_list(upgrade_ops)
         for elem in self._upgrade_ops:
             assert isinstance(elem, UpgradeOps)
 
     @property
-    def downgrade_ops(self):
+    def downgrade_ops(self) -> Optional[DowngradeOps]:
         """An instance of :class:`.DowngradeOps`.
 
         .. seealso::
@@ -2737,7 +2756,9 @@ class MigrationScript(MigrateOperation):
             return self._downgrade_ops[0]
 
     @downgrade_ops.setter
-    def downgrade_ops(self, downgrade_ops):
+    def downgrade_ops(
+        self, downgrade_ops: Union[DowngradeOps, List[DowngradeOps]]
+    ) -> None:
         self._downgrade_ops = util.to_list(downgrade_ops)
         for elem in self._downgrade_ops:
             assert isinstance(elem, DowngradeOps)
index 799f1139d93b30867d9f471ef987e4f19d6a9b4f..32b26e9b9d6471c7c663e732a2cfeb35e9eb4bd6 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 from typing import Any
@@ -274,10 +277,8 @@ class SchemaObjects:
         ForeignKey.
 
         """
-        if isinstance(fk._colspec, str):  # type:ignore[attr-defined]
-            table_key, cname = fk._colspec.rsplit(  # type:ignore[attr-defined]
-                ".", 1
-            )
+        if isinstance(fk._colspec, str):
+            table_key, cname = fk._colspec.rsplit(".", 1)
             sname, tname = self._parse_table_key(table_key)
             if table_key not in metadata.tables:
                 rel_t = sa_schema.Table(tname, metadata, schema=sname)
index ff77ab75e98a585918fa36661c71a7238737275b..4759f7fd2aa7d118ba0e811d5cb207e8b28d173a 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from typing import TYPE_CHECKING
 
 from sqlalchemy import schema as sa_schema
index 34ae1847711d27c944947fa8cb98e30551595b47..d64b2adc279761b40724b2c7c7c7f53da1e77019 100644 (file)
@@ -228,9 +228,9 @@ class EnvironmentContext(util.ModuleClsProxy):
         has been configured.
 
         """
-        return self.context_opts.get("as_sql", False)
+        return self.context_opts.get("as_sql", False)  # type: ignore[no-any-return]  # noqa: E501
 
-    def is_transactional_ddl(self):
+    def is_transactional_ddl(self) -> bool:
         """Return True if the context is configured to expect a
         transactional DDL capable backend.
 
@@ -339,7 +339,7 @@ class EnvironmentContext(util.ModuleClsProxy):
             line.
 
         """
-        return self.context_opts.get("tag", None)
+        return self.context_opts.get("tag", None)  # type: ignore[no-any-return]  # noqa: E501
 
     @overload
     def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]:
@@ -950,7 +950,7 @@ class EnvironmentContext(util.ModuleClsProxy):
     def execute(
         self,
         sql: Union[Executable, str],
-        execution_options: Optional[dict] = None,
+        execution_options: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Execute the given SQL using the current change context.
 
index 24e3d6449f4e650c6ab78e8abcc27a118023eba1..10a632bb52c92641b787cd89fc0518c720236c5f 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 from contextlib import contextmanager
@@ -521,7 +524,7 @@ class MigrationContext:
                 start_from_rev = None
             elif start_from_rev is not None and self.script:
                 start_from_rev = [
-                    cast("Script", self.script.get_revision(sfr)).revision
+                    self.script.get_revision(sfr).revision
                     for sfr in util.to_list(start_from_rev)
                     if sfr not in (None, "base")
                 ]
@@ -652,7 +655,7 @@ class MigrationContext:
     def execute(
         self,
         sql: Union[Executable, str],
-        execution_options: Optional[dict] = None,
+        execution_options: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Execute a SQL construct or string statement.
 
@@ -1000,6 +1003,12 @@ class MigrationStep:
     is_upgrade: bool
     migration_fn: Any
 
+    if TYPE_CHECKING:
+
+        @property
+        def doc(self) -> Optional[str]:
+            ...
+
     @property
     def name(self) -> str:
         return self.migration_fn.__name__
@@ -1048,13 +1057,9 @@ class RevisionStep(MigrationStep):
         self.revision = revision
         self.is_upgrade = is_upgrade
         if is_upgrade:
-            self.migration_fn = (
-                revision.module.upgrade  # type:ignore[attr-defined]
-            )
+            self.migration_fn = revision.module.upgrade
         else:
-            self.migration_fn = (
-                revision.module.downgrade  # type:ignore[attr-defined]
-            )
+            self.migration_fn = revision.module.downgrade
 
     def __repr__(self):
         return "RevisionStep(%r, is_upgrade=%r)" % (
@@ -1070,7 +1075,7 @@ class RevisionStep(MigrationStep):
         )
 
     @property
-    def doc(self) -> str:
+    def doc(self) -> Optional[str]:
         return self.revision.doc
 
     @property
@@ -1283,7 +1288,7 @@ class StampStep(MigrationStep):
     def __eq__(self, other):
         return (
             isinstance(other, StampStep)
-            and other.from_revisions == self.revisions
+            and other.from_revisions == self.from_revisions
             and other.to_revisions == self.to_revisions
             and other.branch_move == self.branch_move
             and self.is_upgrade == other.is_upgrade
index 5766d838721b8d783b9a6cd010a65d133fa3b0de..5945ca591c221279b05b07833591faa4ad4cd628 100644 (file)
@@ -41,7 +41,7 @@ try:
         from zoneinfo import ZoneInfoNotFoundError
     else:
         from backports.zoneinfo import ZoneInfo  # type: ignore[import-not-found,no-redef] # noqa: E501
-        from backports.zoneinfo import ZoneInfoNotFoundError  # type: ignore[import-not-found,no-redef] # noqa: E501
+        from backports.zoneinfo import ZoneInfoNotFoundError  # type: ignore[no-redef] # noqa: E501
 except ImportError:
     ZoneInfo = None  # type: ignore[assignment, misc]
 
@@ -119,7 +119,7 @@ class ScriptDirectory:
             return loc[0]
 
     @util.memoized_property
-    def _version_locations(self):
+    def _version_locations(self) -> Sequence[str]:
         if self.version_locations:
             return [
                 os.path.abspath(util.coerce_resource_to_filename(location))
@@ -303,24 +303,22 @@ class ScriptDirectory:
             ):
                 yield cast(Script, rev)
 
-    def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]:
+    def get_revisions(self, id_: _GetRevArg) -> Tuple[Script, ...]:
         """Return the :class:`.Script` instance with the given rev identifier,
         symbolic name, or sequence of identifiers.
 
         """
         with self._catch_revision_errors():
             return cast(
-                Tuple[Optional[Script], ...],
+                Tuple[Script, ...],
                 self.revision_map.get_revisions(id_),
             )
 
-    def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]:
+    def get_all_current(self, id_: Tuple[str, ...]) -> Set[Script]:
         with self._catch_revision_errors():
-            return cast(
-                Set[Optional[Script]], self.revision_map._get_all_current(id_)
-            )
+            return cast(Set[Script], self.revision_map._get_all_current(id_))
 
-    def get_revision(self, id_: str) -> Optional[Script]:
+    def get_revision(self, id_: str) -> Script:
         """Return the :class:`.Script` instance with the given rev id.
 
         .. seealso::
@@ -330,7 +328,7 @@ class ScriptDirectory:
         """
 
         with self._catch_revision_errors():
-            return cast(Optional[Script], self.revision_map.get_revision(id_))
+            return cast(Script, self.revision_map.get_revision(id_))
 
     def as_revision_number(
         self, id_: Optional[str]
@@ -585,7 +583,7 @@ class ScriptDirectory:
         util.load_python_file(self.dir, "env.py")
 
     @property
-    def env_py_location(self):
+    def env_py_location(self) -> str:
         return os.path.abspath(os.path.join(self.dir, "env.py"))
 
     def _generate_template(self, src: str, dest: str, **kw: Any) -> None:
@@ -684,7 +682,7 @@ class ScriptDirectory:
                 self.revision_map.get_revisions(head),
             )
             for h in heads:
-                assert h != "base"
+                assert h != "base"  # type: ignore[comparison-overlap]
 
         if len(set(heads)) != len(heads):
             raise util.CommandError("Duplicate head revisions specified")
@@ -823,7 +821,7 @@ class Script(revision.Revision):
         self.path = path
         super().__init__(
             rev_id,
-            module.down_revision,  # type: ignore[attr-defined]
+            module.down_revision,
             branch_labels=util.to_tuple(
                 getattr(module, "branch_labels", None), default=()
             ),
@@ -856,7 +854,7 @@ class Script(revision.Revision):
         if doc:
             if hasattr(self.module, "_alembic_source_encoding"):
                 doc = doc.decode(  # type: ignore[attr-defined]
-                    self.module._alembic_source_encoding  # type: ignore[attr-defined] # noqa
+                    self.module._alembic_source_encoding
                 )
             return doc.strip()  # type: ignore[union-attr]
         else:
@@ -898,7 +896,7 @@ class Script(revision.Revision):
         )
         return entry
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "%s -> %s%s%s%s, %s" % (
             self._format_down_revision(),
             self.revision,
index 035026441fad89518cf4dd91e3e9947d0624ef00..77a802cdcadf9c59049fdb5db1c2be95d305a1ae 100644 (file)
@@ -14,6 +14,7 @@ from typing import Iterator
 from typing import List
 from typing import Optional
 from typing import overload
+from typing import Protocol
 from typing import Sequence
 from typing import Set
 from typing import Tuple
@@ -47,6 +48,18 @@ _relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
 _revision_illegal_chars = ["@", "-", "+"]
 
 
+class _CollectRevisionsProtocol(Protocol):
+    def __call__(
+        self,
+        upper: _RevisionIdentifierType,
+        lower: _RevisionIdentifierType,
+        inclusive: bool,
+        implicit_base: bool,
+        assert_relative_length: bool,
+    ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]:
+        ...
+
+
 class RevisionError(Exception):
     pass
 
@@ -396,7 +409,7 @@ class RevisionMap:
                 for rev in self._get_ancestor_nodes(
                     [revision],
                     include_dependencies=False,
-                    map_=cast(_RevisionMapType, map_),
+                    map_=map_,
                 ):
                     if rev is revision:
                         continue
@@ -791,7 +804,7 @@ class RevisionMap:
         The iterator yields :class:`.Revision` objects.
 
         """
-        fn: Callable
+        fn: _CollectRevisionsProtocol
         if select_for_downgrade:
             fn = self._collect_downgrade_revisions
         else:
@@ -818,7 +831,7 @@ class RevisionMap:
     ) -> Iterator[Any]:
         if omit_immediate_dependencies:
 
-            def fn(rev):
+            def fn(rev: Revision) -> Iterable[str]:
                 if rev not in targets:
                     return rev._all_nextrev
                 else:
@@ -826,12 +839,12 @@ class RevisionMap:
 
         elif include_dependencies:
 
-            def fn(rev):
+            def fn(rev: Revision) -> Iterable[str]:
                 return rev._all_nextrev
 
         else:
 
-            def fn(rev):
+            def fn(rev: Revision) -> Iterable[str]:
                 return rev.nextrev
 
         return self._iterate_related_revisions(
@@ -847,12 +860,12 @@ class RevisionMap:
     ) -> Iterator[Revision]:
         if include_dependencies:
 
-            def fn(rev):
+            def fn(rev: Revision) -> Iterable[str]:
                 return rev._normalized_down_revisions
 
         else:
 
-            def fn(rev):
+            def fn(rev: Revision) -> Iterable[str]:
                 return rev._versioned_down_revisions
 
         return self._iterate_related_revisions(
@@ -861,7 +874,7 @@ class RevisionMap:
 
     def _iterate_related_revisions(
         self,
-        fn: Callable,
+        fn: Callable[[Revision], Iterable[str]],
         targets: Collection[Optional[_RevisionOrBase]],
         map_: Optional[_RevisionMapType],
         check: bool = False,
@@ -923,7 +936,7 @@ class RevisionMap:
 
         id_to_rev = self._revision_map
 
-        def get_ancestors(rev_id):
+        def get_ancestors(rev_id: str) -> Set[str]:
             return {
                 r.revision
                 for r in self._get_ancestor_nodes([id_to_rev[rev_id]])
@@ -1041,7 +1054,7 @@ class RevisionMap:
         children: Sequence[Optional[_RevisionOrBase]]
         for _ in range(abs(steps)):
             if steps > 0:
-                assert initial != "base"
+                assert initial != "base"  # type: ignore[comparison-overlap]
                 # Walk up
                 walk_up = [
                     is_revision(rev)
@@ -1055,7 +1068,7 @@ class RevisionMap:
                     children = walk_up
             else:
                 # Walk down
-                if initial == "base":
+                if initial == "base":  # type: ignore[comparison-overlap]
                     children = ()
                 else:
                     children = self.get_revisions(
@@ -1189,7 +1202,7 @@ class RevisionMap:
         # No relative destination given, revision specified is absolute.
         branch_label, _, symbol = target.rpartition("@")
         if not branch_label:
-            branch_label = None  # type:ignore[assignment]
+            branch_label = None
         return branch_label, self.get_revision(symbol)
 
     def _parse_upgrade_target(
@@ -1301,11 +1314,11 @@ class RevisionMap:
     def _collect_downgrade_revisions(
         self,
         upper: _RevisionIdentifierType,
-        target: _RevisionIdentifierType,
+        lower: _RevisionIdentifierType,
         inclusive: bool,
         implicit_base: bool,
         assert_relative_length: bool,
-    ) -> Any:
+    ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]:
         """
         Compute the set of current revisions specified by :upper, and the
         downgrade target specified by :target. Return all dependents of target
@@ -1316,7 +1329,7 @@ class RevisionMap:
 
         branch_label, target_revision = self._parse_downgrade_target(
             current_revisions=upper,
-            target=target,
+            target=lower,
             assert_relative_length=assert_relative_length,
         )
         if target_revision == "base":
@@ -1408,7 +1421,7 @@ class RevisionMap:
         inclusive: bool,
         implicit_base: bool,
         assert_relative_length: bool,
-    ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]:
+    ) -> Tuple[Set[Revision], Tuple[Revision, ...]]:
         """
         Compute the set of required revisions specified by :upper, and the
         current set of active revisions specified by :lower. Find the
@@ -1500,7 +1513,7 @@ class RevisionMap:
             )
             needs.intersection_update(lower_descendents)
 
-        return needs, tuple(targets)  # type:ignore[return-value]
+        return needs, tuple(targets)
 
     def _get_all_current(
         self, id_: Tuple[str, ...]
index b44ce644deff5817f888d5b8f25eed5e3c08f856..9977147921055b2f5540993188ca495455a2ca7d 100644 (file)
@@ -1,3 +1,6 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import shlex
index 3c1e27ca4a3158aa388ee47caaf49233c1d7007e..4724e1f0847c4b8fe942e85285970c117acfa915 100644 (file)
@@ -1,34 +1,34 @@
-from .editor import open_in_editor
-from .exc import AutogenerateDiffsDetected
-from .exc import CommandError
-from .langhelpers import _with_legacy_names
-from .langhelpers import asbool
-from .langhelpers import dedupe_tuple
-from .langhelpers import Dispatcher
-from .langhelpers import EMPTY_DICT
-from .langhelpers import immutabledict
-from .langhelpers import memoized_property
-from .langhelpers import ModuleClsProxy
-from .langhelpers import not_none
-from .langhelpers import rev_id
-from .langhelpers import to_list
-from .langhelpers import to_tuple
-from .langhelpers import unique_list
-from .messaging import err
-from .messaging import format_as_comma
-from .messaging import msg
-from .messaging import obfuscate_url_pw
-from .messaging import status
-from .messaging import warn
-from .messaging import write_outstream
-from .pyfiles import coerce_resource_to_filename
-from .pyfiles import load_python_file
-from .pyfiles import pyc_file_from_path
-from .pyfiles import template_to_file
-from .sqla_compat import has_computed
-from .sqla_compat import sqla_13
-from .sqla_compat import sqla_14
-from .sqla_compat import sqla_2
+from .editor import open_in_editor as open_in_editor
+from .exc import AutogenerateDiffsDetected as AutogenerateDiffsDetected
+from .exc import CommandError as CommandError
+from .langhelpers import _with_legacy_names as _with_legacy_names
+from .langhelpers import asbool as asbool
+from .langhelpers import dedupe_tuple as dedupe_tuple
+from .langhelpers import Dispatcher as Dispatcher
+from .langhelpers import EMPTY_DICT as EMPTY_DICT
+from .langhelpers import immutabledict as immutabledict
+from .langhelpers import memoized_property as memoized_property
+from .langhelpers import ModuleClsProxy as ModuleClsProxy
+from .langhelpers import not_none as not_none
+from .langhelpers import rev_id as rev_id
+from .langhelpers import to_list as to_list
+from .langhelpers import to_tuple as to_tuple
+from .langhelpers import unique_list as unique_list
+from .messaging import err as err
+from .messaging import format_as_comma as format_as_comma
+from .messaging import msg as msg
+from .messaging import obfuscate_url_pw as obfuscate_url_pw
+from .messaging import status as status
+from .messaging import warn as warn
+from .messaging import write_outstream as write_outstream
+from .pyfiles import coerce_resource_to_filename as coerce_resource_to_filename
+from .pyfiles import load_python_file as load_python_file
+from .pyfiles import pyc_file_from_path as pyc_file_from_path
+from .pyfiles import template_to_file as template_to_file
+from .sqla_compat import has_computed as has_computed
+from .sqla_compat import sqla_13 as sqla_13
+from .sqla_compat import sqla_14 as sqla_14
+from .sqla_compat import sqla_2 as sqla_2
 
 
 if not sqla_13:
index 5b8f3d952fa05263f1b726ac915c469eb98a26aa..e185cc417204295070406b9a77231c48b3d6c38e 100644 (file)
@@ -1,3 +1,5 @@
+# mypy: no-warn-unused-ignores
+
 from __future__ import annotations
 
 from configparser import ConfigParser
@@ -5,11 +7,20 @@ import io
 import os
 import sys
 import typing
+from typing import Any
+from typing import List
+from typing import Optional
 from typing import Sequence
 from typing import Union
 
-from sqlalchemy.util import inspect_getfullargspec  # noqa
-from sqlalchemy.util.compat import inspect_formatargspec  # noqa
+if True:
+    # zimports hack for too-long names
+    from sqlalchemy.util import (  # noqa: F401
+        inspect_getfullargspec as inspect_getfullargspec,
+    )
+    from sqlalchemy.util.compat import (  # noqa: F401
+        inspect_formatargspec as inspect_formatargspec,
+    )
 
 is_posix = os.name == "posix"
 
@@ -27,9 +38,13 @@ class EncodedIO(io.TextIOWrapper):
 
 
 if py39:
-    from importlib import resources as importlib_resources
-    from importlib import metadata as importlib_metadata
-    from importlib.metadata import EntryPoint
+    from importlib import resources as _resources
+
+    importlib_resources = _resources
+    from importlib import metadata as _metadata
+
+    importlib_metadata = _metadata
+    from importlib.metadata import EntryPoint as EntryPoint
 else:
     import importlib_resources  # type:ignore # noqa
     import importlib_metadata  # type:ignore # noqa
@@ -39,12 +54,14 @@ else:
 def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
     ep = importlib_metadata.entry_points()
     if hasattr(ep, "select"):
-        return ep.select(group=group)  # type: ignore
+        return ep.select(group=group)
     else:
         return ep.get(group, ())  # type: ignore
 
 
-def formatannotation_fwdref(annotation, base_module=None):
+def formatannotation_fwdref(
+    annotation: Any, base_module: Optional[Any] = None
+) -> str:
     """vendored from python 3.7"""
     # copied over _formatannotation from sqlalchemy 2.0
 
@@ -65,7 +82,7 @@ def formatannotation_fwdref(annotation, base_module=None):
 def read_config_parser(
     file_config: ConfigParser,
     file_argument: Sequence[Union[str, os.PathLike[str]]],
-) -> list[str]:
+) -> List[str]:
     if py310:
         return file_config.read(file_argument, encoding="locale")
     else:
index 34d48bc6c77262f0119f8a5d3bf88211aee88f16..4a5bf09a98bba393e5d61a8abf67ba011ab52b71 100644 (file)
@@ -5,33 +5,46 @@ from collections.abc import Iterable
 import textwrap
 from typing import Any
 from typing import Callable
+from typing import cast
 from typing import Dict
 from typing import List
 from typing import Mapping
+from typing import MutableMapping
+from typing import NoReturn
 from typing import Optional
 from typing import overload
 from typing import Sequence
+from typing import Set
 from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 import uuid
 import warnings
 
-from sqlalchemy.util import asbool  # noqa
-from sqlalchemy.util import immutabledict  # noqa
-from sqlalchemy.util import memoized_property  # noqa
-from sqlalchemy.util import to_list  # noqa
-from sqlalchemy.util import unique_list  # noqa
+from sqlalchemy.util import asbool as asbool  # noqa: F401
+from sqlalchemy.util import immutabledict as immutabledict  # noqa: F401
+from sqlalchemy.util import to_list as to_list  # noqa: F401
+from sqlalchemy.util import unique_list as unique_list
 
 from .compat import inspect_getfullargspec
 
+if True:
+    # zimports workaround :(
+    from sqlalchemy.util import (  # noqa: F401
+        memoized_property as memoized_property,
+    )
+
 
 EMPTY_DICT: Mapping[Any, Any] = immutabledict()
-_T = TypeVar("_T")
+_T = TypeVar("_T", bound=Any)
+
+_C = TypeVar("_C", bound=Callable[..., Any])
 
 
 class _ModuleClsMeta(type):
-    def __setattr__(cls, key: str, value: Callable) -> None:
+    def __setattr__(cls, key: str, value: Callable[..., Any]) -> None:
         super().__setattr__(key, value)
         cls._update_module_proxies(key)  # type: ignore
 
@@ -45,9 +58,13 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
 
     """
 
-    _setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
-        lambda: (set(), [])
-    )
+    _setups: Dict[
+        Type[Any],
+        Tuple[
+            Set[str],
+            List[Tuple[MutableMapping[str, Any], MutableMapping[str, Any]]],
+        ],
+    ] = collections.defaultdict(lambda: (set(), []))
 
     @classmethod
     def _update_module_proxies(cls, name: str) -> None:
@@ -70,18 +87,33 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
                 del globals_[attr_name]
 
     @classmethod
-    def create_module_class_proxy(cls, globals_, locals_):
+    def create_module_class_proxy(
+        cls,
+        globals_: MutableMapping[str, Any],
+        locals_: MutableMapping[str, Any],
+    ) -> None:
         attr_names, modules = cls._setups[cls]
         modules.append((globals_, locals_))
         cls._setup_proxy(globals_, locals_, attr_names)
 
     @classmethod
-    def _setup_proxy(cls, globals_, locals_, attr_names):
+    def _setup_proxy(
+        cls,
+        globals_: MutableMapping[str, Any],
+        locals_: MutableMapping[str, Any],
+        attr_names: Set[str],
+    ) -> None:
         for methname in dir(cls):
             cls._add_proxied_attribute(methname, globals_, locals_, attr_names)
 
     @classmethod
-    def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names):
+    def _add_proxied_attribute(
+        cls,
+        methname: str,
+        globals_: MutableMapping[str, Any],
+        locals_: MutableMapping[str, Any],
+        attr_names: Set[str],
+    ) -> None:
         if not methname.startswith("_"):
             meth = getattr(cls, methname)
             if callable(meth):
@@ -92,10 +124,15 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
                 attr_names.add(methname)
 
     @classmethod
-    def _create_method_proxy(cls, name, globals_, locals_):
+    def _create_method_proxy(
+        cls,
+        name: str,
+        globals_: MutableMapping[str, Any],
+        locals_: MutableMapping[str, Any],
+    ) -> Callable[..., Any]:
         fn = getattr(cls, name)
 
-        def _name_error(name, from_):
+        def _name_error(name: str, from_: Exception) -> NoReturn:
             raise NameError(
                 "Can't invoke function '%s', as the proxy object has "
                 "not yet been "
@@ -119,7 +156,9 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
                 translations,
             )
 
-            def translate(fn_name, spec, translations, args, kw):
+            def translate(
+                fn_name: str, spec: Any, translations: Any, args: Any, kw: Any
+            ) -> Any:
                 return_kw = {}
                 return_args = []
 
@@ -176,15 +215,15 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
                 "doc": fn.__doc__,
             }
         )
-        lcl = {}
+        lcl: MutableMapping[str, Any] = {}
 
-        exec(func_text, globals_, lcl)
-        return lcl[name]
+        exec(func_text, cast("Dict[str, Any]", globals_), lcl)
+        return cast("Callable[..., Any]", lcl[name])
 
 
-def _with_legacy_names(translations):
-    def decorate(fn):
-        fn._legacy_translations = translations
+def _with_legacy_names(translations: Any) -> Any:
+    def decorate(fn: _C) -> _C:
+        fn._legacy_translations = translations  # type: ignore[attr-defined]
         return fn
 
     return decorate
@@ -195,21 +234,25 @@ def rev_id() -> str:
 
 
 @overload
-def to_tuple(x: Any, default: tuple) -> tuple:
+def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]:
     ...
 
 
 @overload
-def to_tuple(x: None, default: Optional[_T] = None) -> _T:
+def to_tuple(x: None, default: Optional[_T] = ...) -> _T:
     ...
 
 
 @overload
-def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
+def to_tuple(
+    x: Any, default: Optional[Tuple[Any, ...]] = None
+) -> Tuple[Any, ...]:
     ...
 
 
-def to_tuple(x, default=None):
+def to_tuple(
+    x: Any, default: Optional[Tuple[Any, ...]] = None
+) -> Optional[Tuple[Any, ...]]:
     if x is None:
         return default
     elif isinstance(x, str):
@@ -226,13 +269,13 @@ def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
 
 class Dispatcher:
     def __init__(self, uselist: bool = False) -> None:
-        self._registry: Dict[tuple, Any] = {}
+        self._registry: Dict[Tuple[Any, ...], Any] = {}
         self.uselist = uselist
 
     def dispatch_for(
         self, target: Any, qualifier: str = "default"
-    ) -> Callable:
-        def decorate(fn):
+    ) -> Callable[[_C], _C]:
+        def decorate(fn: _C) -> _C:
             if self.uselist:
                 self._registry.setdefault((target, qualifier), []).append(fn)
             else:
@@ -244,7 +287,7 @@ class Dispatcher:
 
     def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
         if isinstance(obj, str):
-            targets: Sequence = [obj]
+            targets: Sequence[Any] = [obj]
         elif isinstance(obj, type):
             targets = obj.__mro__
         else:
@@ -259,11 +302,13 @@ class Dispatcher:
             raise ValueError("no dispatch function for object: %s" % obj)
 
     def _fn_or_list(
-        self, fn_or_list: Union[List[Callable], Callable]
-    ) -> Callable:
+        self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]]
+    ) -> Callable[..., Any]:
         if self.uselist:
 
-            def go(*arg, **kw):
+            def go(*arg: Any, **kw: Any) -> None:
+                if TYPE_CHECKING:
+                    assert isinstance(fn_or_list, Sequence)
                 for fn in fn_or_list:
                     fn(*arg, **kw)
 
index 35592c0ec9a83f327661ac7435ba908be108b969..5f14d597554e281ee098afbbb0f2864cec0fd2e8 100644 (file)
@@ -5,6 +5,7 @@ from contextlib import contextmanager
 import logging
 import sys
 import textwrap
+from typing import Iterator
 from typing import Optional
 from typing import TextIO
 from typing import Union
@@ -53,7 +54,9 @@ def write_outstream(
 
 
 @contextmanager
-def status(status_msg: str, newline: bool = False, quiet: bool = False):
+def status(
+    status_msg: str, newline: bool = False, quiet: bool = False
+) -> Iterator[None]:
     msg(status_msg + " ...", newline, flush=True, quiet=quiet)
     try:
         yield
@@ -66,7 +69,7 @@ def status(status_msg: str, newline: bool = False, quiet: bool = False):
             write_outstream(sys.stdout, "  done\n")
 
 
-def err(message: str, quiet: bool = False):
+def err(message: str, quiet: bool = False) -> None:
     log.error(message)
     msg(f"FAILED: {message}", quiet=quiet)
     sys.exit(-1)
@@ -74,7 +77,7 @@ def err(message: str, quiet: bool = False):
 
 def obfuscate_url_pw(input_url: str) -> str:
     u = url.make_url(input_url)
-    return sqla_compat.url_render_as_string(u, hide_password=True)
+    return sqla_compat.url_render_as_string(u, hide_password=True)  # type: ignore  # noqa: E501
 
 
 def warn(msg: str, stacklevel: int = 2) -> None:
index e7576731e124a972157b3bdc377a662672784111..973bd458e5ce615a2b31aa3eeaeec61d6c2f709e 100644 (file)
@@ -8,6 +8,8 @@ import importlib.util
 import os
 import re
 import tempfile
+from types import ModuleType
+from typing import Any
 from typing import Optional
 
 from mako import exceptions
@@ -18,7 +20,7 @@ from .exc import CommandError
 
 
 def template_to_file(
-    template_file: str, dest: str, output_encoding: str, **kw
+    template_file: str, dest: str, output_encoding: str, **kw: Any
 ) -> None:
     template = Template(filename=template_file)
     try:
@@ -82,7 +84,7 @@ def pyc_file_from_path(path: str) -> Optional[str]:
         return None
 
 
-def load_python_file(dir_: str, filename: str):
+def load_python_file(dir_: str, filename: str) -> ModuleType:
     """Load a file from the given path as a Python module."""
 
     module_id = re.sub(r"\W", "_", filename)
@@ -99,10 +101,12 @@ def load_python_file(dir_: str, filename: str):
                 module = load_module_py(module_id, pyc_path)
     elif ext in (".pyc", ".pyo"):
         module = load_module_py(module_id, path)
+    else:
+        assert False
     return module
 
 
-def load_module_py(module_id: str, path: str):
+def load_module_py(module_id: str, path: str) -> ModuleType:
     spec = importlib.util.spec_from_file_location(module_id, path)
     assert spec
     module = importlib.util.module_from_spec(spec)
index 9332a062563c89a634257966f3427d372f5f22e4..8489c19fac7c163dc9053d2f52606855117d60a6 100644 (file)
@@ -1,13 +1,20 @@
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
 from __future__ import annotations
 
 import contextlib
 import re
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import Iterable
 from typing import Iterator
 from typing import Mapping
 from typing import Optional
+from typing import Protocol
+from typing import Set
+from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
@@ -18,7 +25,6 @@ from sqlalchemy import schema
 from sqlalchemy import sql
 from sqlalchemy import types as sqltypes
 from sqlalchemy.engine import url
-from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.schema import CheckConstraint
 from sqlalchemy.schema import Column
 from sqlalchemy.schema import ForeignKeyConstraint
@@ -33,6 +39,7 @@ from sqlalchemy.sql.visitors import traverse
 from typing_extensions import TypeGuard
 
 if TYPE_CHECKING:
+    from sqlalchemy import ClauseElement
     from sqlalchemy import Index
     from sqlalchemy import Table
     from sqlalchemy.engine import Connection
@@ -51,6 +58,11 @@ if TYPE_CHECKING:
 _CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
 
 
+class _CompilerProtocol(Protocol):
+    def __call__(self, element: Any, compiler: Any, **kw: Any) -> str:
+        ...
+
+
 def _safe_int(value: str) -> Union[int, str]:
     try:
         return int(value)
@@ -70,7 +82,7 @@ sqla_2 = _vers >= (2,)
 sqlalchemy_version = __version__
 
 try:
-    from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME
+    from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME  # type: ignore[attr-defined]  # noqa: E501
 except ImportError:
     from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME  # type: ignore  # noqa: E501
 
@@ -79,8 +91,18 @@ class _Unsupported:
     "Placeholder for unsupported SQLAlchemy classes"
 
 
+if TYPE_CHECKING:
+
+    def compiles(
+        element: Type[ClauseElement], *dialects: str
+    ) -> Callable[[_CompilerProtocol], _CompilerProtocol]:
+        ...
+
+else:
+    from sqlalchemy.ext.compiler import compiles
+
 try:
-    from sqlalchemy import Computed
+    from sqlalchemy import Computed as Computed
 except ImportError:
     if not TYPE_CHECKING:
 
@@ -94,7 +116,7 @@ else:
     has_computed_reflection = _vers >= (1, 3, 16)
 
 try:
-    from sqlalchemy import Identity
+    from sqlalchemy import Identity as Identity
 except ImportError:
     if not TYPE_CHECKING:
 
@@ -250,7 +272,7 @@ def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
 
 def _copy(schema_item: _CE, **kw) -> _CE:
     if hasattr(schema_item, "_copy"):
-        return schema_item._copy(**kw)  # type: ignore[union-attr]
+        return schema_item._copy(**kw)
     else:
         return schema_item.copy(**kw)  # type: ignore[union-attr]
 
@@ -368,7 +390,12 @@ else:
         return type_.impl, type_.mapping
 
 
-def _fk_spec(constraint):
+def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
+    if TYPE_CHECKING:
+        assert constraint.columns is not None
+        assert constraint.elements is not None
+        assert isinstance(constraint.parent, Table)
+
     source_columns = [
         constraint.columns[key].name for key in constraint.column_keys
     ]
@@ -397,7 +424,7 @@ def _fk_spec(constraint):
 
 
 def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
-    spec = constraint.elements[0]._get_colspec()  # type: ignore[attr-defined]
+    spec = constraint.elements[0]._get_colspec()
     tokens = spec.split(".")
     tokens.pop(-1)  # colname
     tablekey = ".".join(tokens)
@@ -409,13 +436,13 @@ def _is_type_bound(constraint: Constraint) -> bool:
     # this deals with SQLAlchemy #3260, don't copy CHECK constraints
     # that will be generated by the type.
     # new feature added for #3260
-    return constraint._type_bound  # type: ignore[attr-defined]
+    return constraint._type_bound
 
 
 def _find_columns(clause):
     """locate Column objects within the given expression."""
 
-    cols = set()
+    cols: Set[ColumnElement[Any]] = set()
     traverse(clause, {}, {"column": cols.add})
     return cols
 
@@ -562,9 +589,7 @@ def _get_constraint_final_name(
         if isinstance(constraint, schema.Index):
             # name should not be quoted.
             d = dialect.ddl_compiler(dialect, None)  # type: ignore[arg-type]
-            return d._prepared_index_name(  # type: ignore[attr-defined]
-                constraint
-            )
+            return d._prepared_index_name(constraint)
         else:
             # name should not be quoted.
             return dialect.identifier_preparer.format_constraint(constraint)
@@ -608,7 +633,11 @@ def _insert_inline(table: Union[TableClause, Table]) -> Insert:
 
 if sqla_14:
     from sqlalchemy import create_mock_engine
-    from sqlalchemy import select as _select
+
+    # weird mypy workaround
+    from sqlalchemy import select as _sa_select
+
+    _select = _sa_select
 else:
     from sqlalchemy import create_engine
 
@@ -617,7 +646,7 @@ else:
             "postgresql://", strategy="mock", executor=executor
         )
 
-    def _select(*columns, **kw) -> Select:  # type: ignore[no-redef]
+    def _select(*columns, **kw) -> Select:
         return sql.select(list(columns), **kw)  # type: ignore[call-overload]
 
 
diff --git a/docs/build/unreleased/1377.rst b/docs/build/unreleased/1377.rst
new file mode 100644 (file)
index 0000000..a8bb6c1
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 1377
+
+    Updated pep-484 typing to pass mypy "strict" mode, however including
+    per-module qualifications for specific typing elements not yet complete.
+    This allows us to catch specific typing issues that have been ongoing
+    such as import symbols not properly exported.
+
index f66269af6beb2b3b98965f24e8977103dc4fd935..b9b1f44a672fdc02359f90aab7d09f747ab62e7c 100644 (file)
@@ -16,15 +16,15 @@ exclude = [
 show_error_codes = true
 
 [[tool.mypy.overrides]]
+
 module = [
-    'alembic.operations.ops',
-    'alembic.op',
-    'alembic.context',
-    'alembic.autogenerate.api',
-    'alembic.runtime.*',
+    "alembic.*"
 ]
 
-disallow_incomplete_defs = true
+warn_unused_ignores = true
+strict = true
+
+
 
 [[tool.mypy.overrides]]
 module = [
index 5c3303830197799043885f23f60e6af696494432..fa957ecac63eddcf4e744613daddf0893a2d8386 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -125,18 +125,3 @@ python_files=tests/test_*.py
 markers =
     backend: tests that should run on all backends; typically dialect-sensitive
 
-[mypy]
-show_error_codes = True
-allow_redefinition = True
-
-[mypy-mako.*]
-ignore_missing_imports = True
-
-[mypy-sqlalchemy.testing.*]
-ignore_missing_imports = True
-
-[mypy-importlib_resources.*]
-ignore_missing_imports = True
-
-[mypy-importlib_metadata.*]
-ignore_missing_imports = True
index 5abb26ef1eb30f422153ba5e3a0b97d3d85feee9..363d727ec98ba5d83bf3fb69a4da67fd0b4de22f 100644 (file)
@@ -127,9 +127,7 @@ def generate_pyi_for_proxy(
         {"entrypoint": "zimports", "options": "-e"},
         ignore_output=ignore_output,
     )
-    # note that we do not distribute pyproject.toml with the distribution
-    # right now due to user complaints, so we can't refer to it here because
-    # this all has to run as part of the test suite
+
     console_scripts(
         str(destination_path),
         {"entrypoint": "black", "options": "-l79"},
@@ -190,6 +188,8 @@ def _generate_stub_for_meth(
         else:
             retval = annotation
 
+        retval = re.sub(r"TypeEngine\b", "TypeEngine[Any]", retval)
+
         retval = retval.replace("~", "")  # typevar repr as "~T"
         for trim in TRIM_MODULE:
             retval = retval.replace(trim, "")