]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Various typing related updates
authorCaselIT <cfederico87@gmail.com>
Thu, 21 Apr 2022 21:23:00 +0000 (23:23 +0200)
committerCaselIT <cfederico87@gmail.com>
Sat, 23 Apr 2022 20:04:36 +0000 (22:04 +0200)
Change-Id: I778b63b1c438f31964d841576f0dd54ae1a5fadc

38 files changed:
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/autogenerate/rewriter.py
alembic/command.py
alembic/config.py
alembic/context.pyi
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/mysql.py
alembic/ddl/oracle.py
alembic/ddl/postgresql.py
alembic/ddl/sqlite.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/script/base.py
alembic/script/revision.py
alembic/script/write_hooks.py
alembic/testing/assertions.py
alembic/testing/fixtures.py
alembic/testing/suite/_autogen_fixtures.py
alembic/testing/util.py
alembic/util/__init__.py
alembic/util/compat.py
alembic/util/editor.py
alembic/util/langhelpers.py
alembic/util/messaging.py
alembic/util/pyfiles.py
alembic/util/sqla_compat.py
pyproject.toml
tests/requirements.py
tools/write_pyi.py

index a5528ff766a3aee5536de293f29bb3d7bbf6bba8..4ab8a3593fa9a599e14211a85d042543dbf457ec 100644 (file)
@@ -1,5 +1,4 @@
-"""Provide the 'autogenerate' feature which can produce migration operations
-automatically."""
+from __future__ import annotations
 
 import contextlib
 from typing import Any
@@ -19,6 +18,9 @@ from . import render
 from .. import util
 from ..operations import ops
 
+"""Provide the 'autogenerate' feature which can produce migration operations
+automatically."""
+
 if TYPE_CHECKING:
     from sqlalchemy.engine import Connection
     from sqlalchemy.engine import Dialect
@@ -515,7 +517,7 @@ class RevisionContext:
             branch_labels=migration_script.branch_label,
             version_path=migration_script.version_path,
             depends_on=migration_script.depends_on,
-            **template_args
+            **template_args,
         )
 
     def run_autogenerate(
index 528b17ac4d88e85500b76c6ff35cd7d0db77fd58..693efae9ee494be2594d9ad72bc7e073395cc586 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import logging
 import re
@@ -282,7 +284,7 @@ def _make_index(params: Dict[str, Any], conn_table: "Table") -> "Index":
         params["name"],
         *[conn_table.c[cname] for cname in params["column_names"]],
         unique=params["unique"],
-        _table=conn_table
+        _table=conn_table,
     )
     if "duplicates_constraint" in params:
         ix.info["duplicates_constraint"] = params["duplicates_constraint"]
@@ -294,7 +296,7 @@ def _make_unique_constraint(
 ) -> "UniqueConstraint":
     uq = sa_schema.UniqueConstraint(
         *[conn_table.c[cname] for cname in params["column_names"]],
-        name=params["name"]
+        name=params["name"],
     )
     if "duplicates_index" in params:
         uq.info["duplicates_index"] = params["duplicates_index"]
@@ -1245,7 +1247,7 @@ def _compare_foreign_keys(
         if isinstance(fk, sa_schema.ForeignKeyConstraint)
     )
 
-    conn_fks = [
+    conn_fks_list = [
         fk
         for fk in inspector.get_foreign_keys(tname, schema=schema)
         if autogen_context.run_name_filters(
@@ -1255,9 +1257,13 @@ def _compare_foreign_keys(
         )
     ]
 
-    backend_reflects_fk_options = bool(conn_fks and "options" in conn_fks[0])
+    backend_reflects_fk_options = bool(
+        conn_fks_list and "options" in conn_fks_list[0]
+    )
 
-    conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks)
+    conn_fks = set(
+        _make_foreign_key(const, conn_table) for const in conn_fks_list
+    )
 
     # give the dialect a chance to correct the FKs to match more
     # closely
@@ -1265,24 +1271,24 @@ def _compare_foreign_keys(
         conn_fks, metadata_fks
     )
 
-    metadata_fks = set(
+    metadata_fks_sig = set(
         _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
         for fk in metadata_fks
     )
 
-    conn_fks = set(
+    conn_fks_sig = set(
         _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
         for fk in conn_fks
     )
 
-    conn_fks_by_sig = dict((c.sig, c) for c in conn_fks)
-    metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks)
+    conn_fks_by_sig = dict((c.sig, c) for c in conn_fks_sig)
+    metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks_sig)
 
     metadata_fks_by_name = dict(
-        (c.name, c) for c in metadata_fks if c.name is not None
+        (c.name, c) for c in metadata_fks_sig if c.name is not None
     )
     conn_fks_by_name = dict(
-        (c.name, c) for c in conn_fks if c.name is not None
+        (c.name, c) for c in conn_fks_sig if c.name is not None
     )
 
     def _add_fk(obj, compare_to):
index 77f0a8666edd62d1180138e84cc34befe7609bb6..9c992b47cdfdcf495cdc50b65a90e36e0d67c9e6 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections import OrderedDict
 from io import StringIO
 import re
index 0fdd3982776783154ceca60a735e9199105a3665..79f665a052b41cdb3a17a5e6c31b510669cc3a5e 100644 (file)
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
 from typing import Any
 from typing import Callable
 from typing import Iterator
 from typing import List
+from typing import Optional
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import Union
@@ -49,12 +52,12 @@ class Rewriter:
 
     _traverse = util.Dispatcher()
 
-    _chained = None
+    _chained: Optional[Rewriter] = None
 
     def __init__(self) -> None:
         self.dispatch = util.Dispatcher()
 
-    def chain(self, other: "Rewriter") -> "Rewriter":
+    def chain(self, other: Rewriter) -> Rewriter:
         """Produce a "chain" of this :class:`.Rewriter` to another.
 
         This allows two rewriters to operate serially on a stream,
index c1117244b9b95465a87c14759757eda5fbbd6f1f..d48affc738fa69cf3ec77e5fc452c45d5625f601 100644 (file)
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
 import os
 from typing import Callable
-from typing import cast
 from typing import List
 from typing import Optional
 from typing import TYPE_CHECKING
@@ -86,8 +87,9 @@ def init(
     for file_ in os.listdir(template_dir):
         file_path = os.path.join(template_dir, file_)
         if file_ == "alembic.ini.mako":
-            config_file = os.path.abspath(cast(str, config.config_file_name))
-            if os.access(cast(str, config_file), os.F_OK):
+            assert config.config_file_name is not None
+            config_file = os.path.abspath(config.config_file_name)
+            if os.access(config_file, os.F_OK):
                 util.msg("File %s already exists, skipping" % config_file)
             else:
                 script._generate_template(
@@ -273,7 +275,7 @@ def merge(
         refresh=True,
         head=revisions,
         branch_labels=branch_label,
-        **template_args  # type:ignore[arg-type]
+        **template_args,  # type:ignore[arg-type]
     )
 
 
@@ -642,6 +644,7 @@ def edit(config: "Config", rev: str) -> None:
                 "No revision files indicated by symbol '%s'" % rev
             )
         for sc in revs:
+            assert sc
             util.open_in_editor(sc.path)
 
 
index f868bf7375a918beb55ddc1b8b5f3ecac6666ff9..dcfb928862c46b9591e25abbc5170187012fbb8f 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from argparse import ArgumentParser
 from argparse import Namespace
 from configparser import ConfigParser
@@ -559,7 +561,7 @@ class CommandLine:
             fn(
                 config,
                 *[getattr(options, k, None) for k in positional],
-                **dict((k, getattr(options, k, None)) for k in kwarg)
+                **dict((k, getattr(options, k, None)) for k in kwarg),
             )
         except util.CommandError as e:
             if options.raiseerr:
index 5c29d3aef0dd350df732a4bf795c1e0e5e5c3470..5ec4703733a7c405cc4bcc2d460933f88c373a14 100644 (file)
@@ -1,5 +1,6 @@
 # ### this file stubs are generated by tools/write_pyi.py - do not edit ###
 # ### imports are manually managed
+from __future__ import annotations
 
 from typing import Callable
 from typing import ContextManager
@@ -67,7 +68,7 @@ def begin_transaction() -> Union["_ProxyTransaction", ContextManager]:
 config: Config
 
 def configure(
-    connection: Optional["Connection"] = None,
+    connection: Optional[Connection] = None,
     url: Optional[str] = None,
     dialect_name: Optional[str] = None,
     dialect_opts: Optional[dict] = None,
@@ -78,7 +79,7 @@ def configure(
     tag: Optional[str] = None,
     template_args: Optional[dict] = None,
     render_as_batch: bool = False,
-    target_metadata: Optional["MetaData"] = None,
+    target_metadata: Optional[MetaData] = None,
     include_name: Optional[Callable] = None,
     include_object: Optional[Callable] = None,
     include_schemas: bool = False,
@@ -93,7 +94,7 @@ def configure(
     sqlalchemy_module_prefix: str = "sa.",
     user_module_prefix: Optional[str] = None,
     on_version_apply: Optional[Callable] = None,
-    **kw
+    **kw,
 ) -> None:
     """Configure a :class:`.MigrationContext` within this
     :class:`.EnvironmentContext` which will provide database
@@ -553,7 +554,7 @@ def get_bind():
 
     """
 
-def get_context() -> "MigrationContext":
+def get_context() -> MigrationContext:
     """Return the current :class:`.MigrationContext` object.
 
     If :meth:`.EnvironmentContext.configure` has not been
index 022dc244d3294dad6da6f4d75713f1116aada4f0..7b0f63e811efb6b380e8278bd8b69674abe46ad5 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import functools
 from typing import Optional
 from typing import TYPE_CHECKING
@@ -114,7 +116,7 @@ class ColumnDefault(AlterColumn):
         name: str,
         column_name: str,
         default: Optional[_ServerDefault],
-        **kw
+        **kw,
     ) -> None:
         super(ColumnDefault, self).__init__(name, column_name, **kw)
         self.default = default
@@ -135,7 +137,7 @@ class IdentityColumnDefault(AlterColumn):
         column_name: str,
         default: Optional["Identity"],
         impl: "DefaultImpl",
-        **kw
+        **kw,
     ) -> None:
         super(IdentityColumnDefault, self).__init__(name, column_name, **kw)
         self.default = default
index 10dcc7344cf4acf5fed16652731ed629b3c0afef..8c9e0b91fb138f21ef405e6f34cbea52910b5d4e 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections import namedtuple
 import re
 from typing import Any
@@ -215,7 +217,7 @@ class DefaultImpl(metaclass=ImplMeta):
         existing_server_default: Optional["_ServerDefault"] = None,
         existing_nullable: Optional[bool] = None,
         existing_autoincrement: Optional[bool] = None,
-        **kw: Any
+        **kw: Any,
     ) -> None:
         if autoincrement is not None or existing_autoincrement is not None:
             util.warn(
@@ -266,7 +268,7 @@ class DefaultImpl(metaclass=ImplMeta):
                     existing_server_default=existing_server_default,
                     existing_nullable=existing_nullable,
                     existing_comment=existing_comment,
-                    **kw
+                    **kw,
                 )
             )
         if type_ is not None:
@@ -324,7 +326,7 @@ class DefaultImpl(metaclass=ImplMeta):
         table_name: str,
         column: "Column",
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> None:
         self._exec(base.DropColumn(table_name, column, schema=schema))
 
index 4ea671801a541ec9aeb2de8bbaf4ee2ffec7f2f3..b48f8ba988ebb0f6181b43786b3fff89454fcf0e 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any
 from typing import List
 from typing import Optional
@@ -96,7 +98,7 @@ class MSSQLImpl(DefaultImpl):
         existing_type: Optional["TypeEngine"] = None,
         existing_server_default: Optional["_ServerDefault"] = None,
         existing_nullable: Optional[bool] = None,
-        **kw: Any
+        **kw: Any,
     ) -> None:
 
         if nullable is not None:
@@ -145,7 +147,7 @@ class MSSQLImpl(DefaultImpl):
             schema=schema,
             existing_type=existing_type,
             existing_nullable=existing_nullable,
-            **kw
+            **kw,
         )
 
         if server_default is not False and used_default is False:
@@ -203,7 +205,7 @@ class MSSQLImpl(DefaultImpl):
         table_name: str,
         column: "Column",
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> None:
         drop_default = kw.pop("mssql_drop_default", False)
         if drop_default:
index c3d66bdf2ac6edc4029c5f79c46c40a4eb7b4495..0c03fbe1121489d459937f6319ec4e364465c714 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import re
 from typing import Any
 from typing import Optional
@@ -60,7 +62,7 @@ class MySQLImpl(DefaultImpl):
         existing_autoincrement: Optional[bool] = None,
         comment: Optional[Union[str, "Literal[False]"]] = False,
         existing_comment: Optional[str] = None,
-        **kw: Any
+        **kw: Any,
     ) -> None:
         if sqla_compat._server_default_is_identity(
             server_default, existing_server_default
@@ -79,7 +81,7 @@ class MySQLImpl(DefaultImpl):
                 existing_nullable=existing_nullable,
                 server_default=server_default,
                 existing_server_default=existing_server_default,
-                **kw
+                **kw,
             )
         if name is not None or self._is_mysql_allowed_functional_default(
             type_ if type_ is not None else existing_type, server_default
index 6dff65145e1a072ef6bf5434784cc804b6b1ce45..0e787fb1cd1483cde00b0a38a53cac432e3ba247 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any
 from typing import Optional
 from typing import TYPE_CHECKING
index 6174f382a05026e872cb82e1c2e68464cf7c3c4a..019eb3c74c6d5bc4de582ef23547c91f5ba2494a 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import re
 from typing import Any
@@ -143,7 +145,7 @@ class PostgresqlImpl(DefaultImpl):
         existing_server_default: Optional["_ServerDefault"] = None,
         existing_nullable: Optional[bool] = None,
         existing_autoincrement: Optional[bool] = None,
-        **kw: Any
+        **kw: Any,
     ) -> None:
 
         using = kw.pop("postgresql_using", None)
@@ -179,7 +181,7 @@ class PostgresqlImpl(DefaultImpl):
             existing_server_default=existing_server_default,
             existing_nullable=existing_nullable,
             existing_autoincrement=existing_autoincrement,
-            **kw
+            **kw,
         )
 
     def autogen_column_reflect(self, inspector, table, column_info):
@@ -417,7 +419,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
         where: Optional[Union["BinaryExpression", str]] = None,
         schema: Optional[str] = None,
         _orig_constraint: Optional["ExcludeConstraint"] = None,
-        **kw
+        **kw,
     ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
@@ -459,7 +461,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
             *self.elements,
             name=self.constraint_name,
             where=self.where,
-            **self.kw
+            **self.kw,
         )
         for (
             expr,
@@ -477,7 +479,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
         constraint_name: str,
         table_name: str,
         *elements: Any,
-        **kw: Any
+        **kw: Any,
     ) -> Optional["Table"]:
         """Issue an alter to create an EXCLUDE constraint using the
         current migration context.
index f916147261d58b71261997dcc504b285769235a6..9b387664c78d9f1a8e2de67dbfcd3f02b73be51b 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import re
 from typing import Any
 from typing import Dict
index d781ac1bcd86da6e0b19761a25485c6e014c1af1..9e3169ad3d77ac2aa8000cd4a1e94d928d42ba54 100644 (file)
@@ -33,8 +33,8 @@ if TYPE_CHECKING:
 ### end imports ###
 
 def add_column(
-    table_name: str, column: "Column", schema: Optional[str] = None
-) -> Optional["Table"]:
+    table_name: str, column: Column, schema: Optional[str] = None
+) -> Optional[Table]:
     """Issue an "add column" instruction using the current
     migration context.
 
@@ -91,16 +91,16 @@ def alter_column(
     comment: Union[str, bool, None] = False,
     server_default: Any = False,
     new_column_name: Optional[str] = None,
-    type_: Union["TypeEngine", Type["TypeEngine"], None] = None,
-    existing_type: Union["TypeEngine", Type["TypeEngine"], None] = None,
+    type_: Union[TypeEngine, Type[TypeEngine], None] = None,
+    existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
     existing_server_default: Union[
-        str, bool, "Identity", "Computed", None
+        str, bool, Identity, Computed, None
     ] = False,
     existing_nullable: Optional[bool] = None,
     existing_comment: Optional[str] = None,
     schema: Optional[str] = None,
     **kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue an "alter column" instruction using the
     current migration context.
 
@@ -340,7 +340,7 @@ def batch_alter_table(
     """
 
 def bulk_insert(
-    table: Union["Table", "TableClause"],
+    table: Union[Table, TableClause],
     rows: List[dict],
     multiinsert: bool = True,
 ) -> None:
@@ -422,10 +422,10 @@ def bulk_insert(
 def create_check_constraint(
     constraint_name: Optional[str],
     table_name: str,
-    condition: Union[str, "BinaryExpression"],
+    condition: Union[str, BinaryExpression],
     schema: Optional[str] = None,
     **kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue a "create check constraint" instruction using the
     current migration context.
 
@@ -469,7 +469,7 @@ def create_check_constraint(
 
 def create_exclude_constraint(
     constraint_name: str, table_name: str, *elements: Any, **kw: Any
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue an alter to create an EXCLUDE constraint using the
     current migration context.
 
@@ -521,7 +521,7 @@ def create_foreign_key(
     source_schema: Optional[str] = None,
     referent_schema: Optional[str] = None,
     **dialect_kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue a "create foreign key" instruction using the
     current migration context.
 
@@ -570,11 +570,11 @@ def create_foreign_key(
 def create_index(
     index_name: str,
     table_name: str,
-    columns: Sequence[Union[str, "TextClause", "Function"]],
+    columns: Sequence[Union[str, TextClause, Function]],
     schema: Optional[str] = None,
     unique: bool = False,
     **kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue a "create index" instruction using the current
     migration context.
 
@@ -622,7 +622,7 @@ def create_primary_key(
     table_name: str,
     columns: List[str],
     schema: Optional[str] = None,
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue a "create primary key" instruction using the current
     migration context.
 
@@ -660,7 +660,7 @@ def create_primary_key(
 
     """
 
-def create_table(table_name: str, *columns, **kw) -> Optional["Table"]:
+def create_table(table_name: str, *columns, **kw) -> Optional[Table]:
     """Issue a "create table" instruction using the current migration
     context.
 
@@ -743,7 +743,7 @@ def create_table_comment(
     comment: Optional[str],
     existing_comment: None = None,
     schema: Optional[str] = None,
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Emit a COMMENT ON operation to set the comment for a table.
 
     .. versionadded:: 1.0.6
@@ -811,7 +811,7 @@ def create_unique_constraint(
 
 def drop_column(
     table_name: str, column_name: str, schema: Optional[str] = None, **kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue a "drop column" instruction using the current
     migration context.
 
@@ -854,7 +854,7 @@ def drop_constraint(
     table_name: str,
     type_: Optional[str] = None,
     schema: Optional[str] = None,
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Drop a constraint of the given name, typically via DROP CONSTRAINT.
 
     :param constraint_name: name of the constraint.
@@ -873,7 +873,7 @@ def drop_index(
     table_name: Optional[str] = None,
     schema: Optional[str] = None,
     **kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue a "drop index" instruction using the current
     migration context.
 
@@ -921,7 +921,7 @@ def drop_table_comment(
     table_name: str,
     existing_comment: Optional[str] = None,
     schema: Optional[str] = None,
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Issue a "drop table comment" operation to
     remove an existing comment set on a table.
 
@@ -940,8 +940,8 @@ def drop_table_comment(
     """
 
 def execute(
-    sqltext: Union[str, "TextClause", "Update"], execution_options: None = None
-) -> Optional["Table"]:
+    sqltext: Union[str, TextClause, Update], execution_options: None = None
+) -> Optional[Table]:
     """Execute the given SQL using the current migration context.
 
     The given SQL can be a plain string, e.g.::
@@ -1024,7 +1024,7 @@ def execute(
      :meth:`sqlalchemy.engine.Connection.execution_options`.
     """
 
-def f(name: str) -> "conv":
+def f(name: str) -> conv:
     """Indicate a string name that has already had a naming convention
     applied to it.
 
@@ -1061,7 +1061,7 @@ def f(name: str) -> "conv":
 
     """
 
-def get_bind() -> "Connection":
+def get_bind() -> Connection:
     """Return the current 'bind'.
 
     Under normal circumstances, this is the
@@ -1134,7 +1134,7 @@ def inline_literal(
 
     """
 
-def invoke(operation: "MigrateOperation") -> Any:
+def invoke(operation: MigrateOperation) -> Any:
     """Given a :class:`.MigrateOperation`, invoke it in terms of
     this :class:`.Operations` instance.
 
@@ -1161,7 +1161,7 @@ def register_operation(
 
 def rename_table(
     old_table_name: str, new_table_name: str, schema: Optional[str] = None
-) -> Optional["Table"]:
+) -> Optional[Table]:
     """Emit an ALTER TABLE to rename a table.
 
     :param old_table_name: old name.
index 07ddd5afd1cd675338aff5948d8eee31958ee019..68b620fad7b9479c078a0be717877d59c49a6e46 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 import re
 import textwrap
index ba1d1967cf55b563194f6d82b129778844b32891..308bc2e8a4afbba1a88114b5ee48ba41cf5f92dd 100644 (file)
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
 from typing import Any
-from typing import cast
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -122,7 +123,7 @@ class BatchOperationsImpl:
                         schema=self.schema,
                         autoload_with=self.operations.get_bind(),
                         *self.reflect_args,
-                        **self.reflect_kwargs
+                        **self.reflect_kwargs,
                     )
                     reflected = True
 
@@ -311,7 +312,7 @@ class ApplyBatchImpl:
             m,
             *(list(self.columns.values()) + list(self.table_args)),
             schema=schema,
-            **self.table_kwargs
+            **self.table_kwargs,
         )
 
         for const in (
@@ -360,7 +361,7 @@ class ApplyBatchImpl:
                     index.name,
                     unique=index.unique,
                     *[self.new_table.c[col] for col in index.columns.keys()],
-                    **index.kwargs
+                    **index.kwargs,
                 )
             )
         return idx
@@ -401,7 +402,7 @@ class ApplyBatchImpl:
                             for elem in constraint.elements
                         ]
                     ],
-                    schema=referent_schema
+                    schema=referent_schema,
                 )
 
     def _create(self, op_impl: "DefaultImpl") -> None:
@@ -453,7 +454,7 @@ class ApplyBatchImpl:
         type_: Optional["TypeEngine"] = None,
         autoincrement: None = None,
         comment: Union[str, "Literal[False]"] = False,
-        **kw
+        **kw,
     ) -> None:
         existing = self.columns[column_name]
         existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
@@ -574,7 +575,7 @@ class ApplyBatchImpl:
         column: "Column",
         insert_before: Optional[str] = None,
         insert_after: Optional[str] = None,
-        **kw
+        **kw,
     ) -> None:
         self._setup_dependencies_for_add_column(
             column.name, insert_before, insert_after
@@ -647,7 +648,8 @@ class ApplyBatchImpl:
                     if col_const.name == const.name:
                         self.columns[col.name].constraints.remove(col_const)
             else:
-                const = self.named_constraints.pop(cast(str, const.name))
+                assert const.name
+                const = self.named_constraints.pop(const.name)
         except KeyError:
             if _is_type_bound(const):
                 # type-bound constraints are only included in the new
index 99132dd661994465a12cfdea7775bfbf3e337af4..176c6ba6d181a7d5d728a145b7ba2eae6dc0d3a7 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import abstractmethod
 import re
 from typing import Any
@@ -258,7 +260,7 @@ class CreatePrimaryKeyOp(AddConstraintOp):
         table_name: str,
         columns: Sequence[str],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
@@ -383,7 +385,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
         table_name: str,
         columns: Sequence[str],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
@@ -434,7 +436,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
         table_name: str,
         columns: Sequence[str],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> Any:
         """Issue a "create unique constraint" instruction using the
         current migration context.
@@ -483,7 +485,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
         operations: "BatchOperations",
         constraint_name: str,
         columns: Sequence[str],
-        **kw
+        **kw,
     ) -> Any:
         """Issue a "create unique constraint" instruction using the
         current batch migration context.
@@ -518,7 +520,7 @@ class CreateForeignKeyOp(AddConstraintOp):
         referent_table: str,
         local_cols: List[str],
         remote_cols: List[str],
-        **kw
+        **kw,
     ) -> None:
         self.constraint_name = constraint_name
         self.source_table = source_table
@@ -600,7 +602,7 @@ class CreateForeignKeyOp(AddConstraintOp):
         match: Optional[str] = None,
         source_schema: Optional[str] = None,
         referent_schema: Optional[str] = None,
-        **dialect_kw
+        **dialect_kw,
     ) -> Optional["Table"]:
         """Issue a "create foreign key" instruction using the
         current migration context.
@@ -678,7 +680,7 @@ class CreateForeignKeyOp(AddConstraintOp):
         deferrable: None = None,
         initially: None = None,
         match: None = None,
-        **dialect_kw
+        **dialect_kw,
     ) -> None:
         """Issue a "create foreign key" instruction using the
         current batch migration context.
@@ -734,7 +736,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
         table_name: str,
         condition: Union[str, "TextClause", "ColumnElement[Any]"],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
@@ -753,9 +755,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
         return cls(
             ck_constraint.name,
             constraint_table.name,
-            cast(
-                "Union[TextClause, ColumnElement[Any]]", ck_constraint.sqltext
-            ),
+            cast("ColumnElement[Any]", ck_constraint.sqltext),
             schema=constraint_table.schema,
             **ck_constraint.dialect_kwargs,
         )
@@ -780,7 +780,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
         table_name: str,
         condition: Union[str, "BinaryExpression"],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> Optional["Table"]:
         """Issue a "create check constraint" instruction using the
         current migration context.
@@ -831,7 +831,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
         operations: "BatchOperations",
         constraint_name: str,
         condition: "TextClause",
-        **kw
+        **kw,
     ) -> Optional["Table"]:
         """Issue a "create check constraint" instruction using the
         current batch migration context.
@@ -866,7 +866,7 @@ class CreateIndexOp(MigrateOperation):
         columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
         schema: Optional[str] = None,
         unique: bool = False,
-        **kw
+        **kw,
     ) -> None:
         self.index_name = index_name
         self.table_name = table_name
@@ -917,7 +917,7 @@ class CreateIndexOp(MigrateOperation):
         columns: Sequence[Union[str, "TextClause", "Function"]],
         schema: Optional[str] = None,
         unique: bool = False,
-        **kw
+        **kw,
     ) -> Optional["Table"]:
         r"""Issue a "create index" instruction using the current
         migration context.
@@ -971,7 +971,7 @@ class CreateIndexOp(MigrateOperation):
         operations: "BatchOperations",
         index_name: str,
         columns: List[str],
-        **kw
+        **kw,
     ) -> Optional["Table"]:
         """Issue a "create index" instruction using the
         current batch migration context.
@@ -1003,7 +1003,7 @@ class DropIndexOp(MigrateOperation):
         table_name: Optional[str] = None,
         schema: Optional[str] = None,
         _reverse: Optional["CreateIndexOp"] = None,
-        **kw
+        **kw,
     ) -> None:
         self.index_name = index_name
         self.table_name = table_name
@@ -1050,7 +1050,7 @@ class DropIndexOp(MigrateOperation):
         index_name: str,
         table_name: Optional[str] = None,
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> Optional["Table"]:
         r"""Issue a "drop index" instruction using the current
         migration context.
@@ -1109,7 +1109,7 @@ class CreateTableOp(MigrateOperation):
         schema: Optional[str] = None,
         _namespace_metadata: Optional["MetaData"] = None,
         _constraints_included: bool = False,
-        **kw
+        **kw,
     ) -> None:
         self.table_name = table_name
         self.columns = columns
@@ -1326,7 +1326,7 @@ class DropTableOp(MigrateOperation):
         operations: "Operations",
         table_name: str,
         schema: Optional[str] = None,
-        **kw: Any
+        **kw: Any,
     ) -> None:
         r"""Issue a "drop table" instruction using the current
         migration context.
@@ -1607,7 +1607,7 @@ class AlterColumnOp(AlterTableOp):
         modify_server_default: Any = False,
         modify_name: Optional[str] = None,
         modify_type: Optional[Any] = None,
-        **kw
+        **kw,
     ) -> None:
         super(AlterColumnOp, self).__init__(table_name, schema=schema)
         self.column_name = column_name
@@ -1770,7 +1770,7 @@ class AlterColumnOp(AlterTableOp):
         existing_nullable: Optional[bool] = None,
         existing_comment: Optional[str] = None,
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> Optional["Table"]:
         r"""Issue an "alter column" instruction using the
         current migration context.
@@ -1897,7 +1897,7 @@ class AlterColumnOp(AlterTableOp):
         existing_comment: None = None,
         insert_before: None = None,
         insert_after: None = None,
-        **kw
+        **kw,
     ) -> Optional["Table"]:
         """Issue an "alter column" instruction using the current
         batch migration context.
@@ -1954,7 +1954,7 @@ class AddColumnOp(AlterTableOp):
         table_name: str,
         column: "Column",
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> None:
         super(AddColumnOp, self).__init__(table_name, schema=schema)
         self.column = column
@@ -2089,7 +2089,7 @@ class DropColumnOp(AlterTableOp):
         column_name: str,
         schema: Optional[str] = None,
         _reverse: Optional["AddColumnOp"] = None,
-        **kw
+        **kw,
     ) -> None:
         super(DropColumnOp, self).__init__(table_name, schema=schema)
         self.column_name = column_name
@@ -2146,7 +2146,7 @@ class DropColumnOp(AlterTableOp):
         table_name: str,
         column_name: str,
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> Optional["Table"]:
         """Issue a "drop column" instruction using the current
         migration context.
index 0a27920b285b72f08bdb4fa84eb9f6fbac99b600..6c6f9714f6924c8e904f4fc0d5e9748cc347d75e 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any
 from typing import Dict
 from typing import List
@@ -44,7 +46,7 @@ class SchemaObjects:
         table_name: str,
         cols: Sequence[str],
         schema: Optional[str] = None,
-        **dialect_kw
+        **dialect_kw,
     ) -> "PrimaryKeyConstraint":
         m = self.metadata()
         columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
@@ -68,7 +70,7 @@ class SchemaObjects:
         referent_schema: Optional[str] = None,
         initially: Optional[str] = None,
         match: Optional[str] = None,
-        **dialect_kw
+        **dialect_kw,
     ) -> "ForeignKeyConstraint":
         m = self.metadata()
         if source == referent and source_schema == referent_schema:
@@ -79,14 +81,14 @@ class SchemaObjects:
                 referent,
                 m,
                 *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
-                schema=referent_schema
+                schema=referent_schema,
             )
 
         t1 = sa_schema.Table(
             source,
             m,
             *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
-            schema=source_schema
+            schema=source_schema,
         )
 
         tname = (
@@ -105,7 +107,7 @@ class SchemaObjects:
             ondelete=ondelete,
             deferrable=deferrable,
             initially=initially,
-            **dialect_kw
+            **dialect_kw,
         )
         t1.append_constraint(f)
 
@@ -117,13 +119,13 @@ class SchemaObjects:
         source: str,
         local_cols: Sequence[str],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> "UniqueConstraint":
         t = sa_schema.Table(
             source,
             self.metadata(),
             *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
-            schema=schema
+            schema=schema,
         )
         kw["name"] = name
         uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
@@ -138,7 +140,7 @@ class SchemaObjects:
         source: str,
         condition: Union[str, "TextClause", "ColumnElement[Any]"],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> Union["CheckConstraint"]:
         t = sa_schema.Table(
             source,
@@ -156,7 +158,7 @@ class SchemaObjects:
         table_name: str,
         type_: Optional[str],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> Any:
         t = self.table(table_name, schema=schema)
         types: Dict[Optional[str], Any] = {
@@ -237,7 +239,7 @@ class SchemaObjects:
         tablename: Optional[str],
         columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
         schema: Optional[str] = None,
-        **kw
+        **kw,
     ) -> "Index":
         t = sa_schema.Table(
             tablename or "no_table",
@@ -248,7 +250,7 @@ class SchemaObjects:
         idx = sa_schema.Index(
             name,
             *[util.sqla_compat._textual_index_column(t, n) for n in columns],
-            **kw
+            **kw,
         )
         return idx
 
index f3473de40c273626745f250c067220f8e543ec51..5fcee573575edd76b70e41cbd8d3548fce09300b 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Callable
 from typing import ContextManager
 from typing import Dict
@@ -345,7 +347,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         sqlalchemy_module_prefix: str = "sa.",
         user_module_prefix: Optional[str] = None,
         on_version_apply: Optional[Callable] = None,
-        **kw
+        **kw,
     ) -> None:
         """Configure a :class:`.MigrationContext` within this
         :class:`.EnvironmentContext` which will provide database
index 466264efcfc706e98f06bf28865a51b173dc0856..c09c8e416a3599574a00068c2714c913c7c66c3b 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 import logging
 import sys
@@ -39,6 +41,7 @@ if TYPE_CHECKING:
     from ..config import Config
     from ..script.base import Script
     from ..script.base import ScriptDirectory
+    from ..script.revision import _RevisionOrBase
     from ..script.revision import Revision
     from ..script.revision import RevisionMap
 
@@ -516,7 +519,7 @@ class MigrationContext:
             elif start_from_rev is not None and self.script:
 
                 start_from_rev = [
-                    self.script.get_revision(sfr).revision
+                    cast("Script", self.script.get_revision(sfr)).revision
                     for sfr in util.to_list(start_from_rev)
                     if sfr not in (None, "base")
                 ]
@@ -860,15 +863,15 @@ class MigrationInfo:
 
     """
 
-    is_upgrade: bool = None  # type:ignore[assignment]
+    is_upgrade: bool
     """True/False: indicates whether this operation ascends or descends the
     version tree."""
 
-    is_stamp: bool = None  # type:ignore[assignment]
+    is_stamp: bool
     """True/False: indicates whether this operation is a stamp (i.e. whether
     it results in any actual database operations)."""
 
-    up_revision_id: Optional[str] = None
+    up_revision_id: Optional[str]
     """Version string corresponding to :attr:`.Revision.revision`.
 
     In the case of a stamp operation, it is advised to use the
@@ -882,10 +885,10 @@ class MigrationInfo:
 
     """
 
-    up_revision_ids: Tuple[str, ...] = None  # type:ignore[assignment]
+    up_revision_ids: Tuple[str, ...]
     """Tuple of version strings corresponding to :attr:`.Revision.revision`.
 
-    In the majority of cases, this tuple will be a single value, synonomous
+    In the majority of cases, this tuple will be a single value, synonymous
     with the scalar value of :attr:`.MigrationInfo.up_revision_id`.
     It can be multiple revision identifiers only in the case of an
     ``alembic stamp`` operation which is moving downwards from multiple
@@ -893,7 +896,7 @@ class MigrationInfo:
 
     """
 
-    down_revision_ids: Tuple[str, ...] = None  # type:ignore[assignment]
+    down_revision_ids: Tuple[str, ...]
     """Tuple of strings representing the base revisions of this migration step.
 
     If empty, this represents a root revision; otherwise, the first item
@@ -901,7 +904,7 @@ class MigrationInfo:
     from dependencies.
     """
 
-    revision_map: "RevisionMap" = None  # type:ignore[assignment]
+    revision_map: "RevisionMap"
     """The revision map inside of which this operation occurs."""
 
     def __init__(
@@ -950,7 +953,7 @@ class MigrationInfo:
         )
 
     @property
-    def up_revision(self) -> "Revision":
+    def up_revision(self) -> Optional[Revision]:
         """Get :attr:`~.MigrationInfo.up_revision_id` as
         a :class:`.Revision`.
 
@@ -958,25 +961,25 @@ class MigrationInfo:
         return self.revision_map.get_revision(self.up_revision_id)
 
     @property
-    def up_revisions(self) -> Tuple["Revision", ...]:
+    def up_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
         """Get :attr:`~.MigrationInfo.up_revision_ids` as a
         :class:`.Revision`."""
         return self.revision_map.get_revisions(self.up_revision_ids)
 
     @property
-    def down_revisions(self) -> Tuple["Revision", ...]:
+    def down_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
         """Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of
         :class:`Revisions <.Revision>`."""
         return self.revision_map.get_revisions(self.down_revision_ids)
 
     @property
-    def source_revisions(self) -> Tuple["Revision", ...]:
+    def source_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
         """Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of
         :class:`Revisions <.Revision>`."""
         return self.revision_map.get_revisions(self.source_revision_ids)
 
     @property
-    def destination_revisions(self) -> Tuple["Revision", ...]:
+    def destination_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
         """Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of
         :class:`Revisions <.Revision>`."""
         return self.revision_map.get_revisions(self.destination_revision_ids)
@@ -1059,7 +1062,7 @@ class RevisionStep(MigrationStep):
         )
 
     @property
-    def doc(self):
+    def doc(self) -> str:
         return self.revision.doc
 
     @property
@@ -1264,7 +1267,7 @@ class StampStep(MigrationStep):
         self.migration_fn = self.stamp_revision
         self.revision_map = revision_map
 
-    doc = None
+    doc: None = None
 
     def stamp_revision(self, **kw) -> None:
         return None
index ef0fd52a1f87f9f0ca718d6e9a20cc080e2ad78f..ccbf86c97b0c452d99e553bc1e1285791fc5a1eb 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 import datetime
 import os
@@ -21,11 +23,13 @@ from . import revision
 from . import write_hooks
 from .. import util
 from ..runtime import migration
+from ..util import not_none
 
 if TYPE_CHECKING:
     from ..config import Config
     from ..runtime.migration import RevisionStep
     from ..runtime.migration import StampStep
+    from ..script.revision import Revision
 
 try:
     from dateutil import tz
@@ -112,7 +116,7 @@ class ScriptDirectory:
         else:
             return (os.path.abspath(os.path.join(self.dir, "versions")),)
 
-    def _load_revisions(self) -> Iterator["Script"]:
+    def _load_revisions(self) -> Iterator[Script]:
         if self.version_locations:
             paths = [
                 vers
@@ -139,7 +143,7 @@ class ScriptDirectory:
                 yield script
 
     @classmethod
-    def from_config(cls, config: "Config") -> "ScriptDirectory":
+    def from_config(cls, config: Config) -> ScriptDirectory:
         """Produce a new :class:`.ScriptDirectory` given a :class:`.Config`
         instance.
 
@@ -152,14 +156,16 @@ class ScriptDirectory:
             raise util.CommandError(
                 "No 'script_location' key " "found in configuration."
             )
-        truncate_slug_length = cast(
-            Optional[int], config.get_main_option("truncate_slug_length")
-        )
-        if truncate_slug_length is not None:
-            truncate_slug_length = int(truncate_slug_length)
+        truncate_slug_length: Optional[int]
+        tsl = config.get_main_option("truncate_slug_length")
+        if tsl is not None:
+            truncate_slug_length = int(tsl)
+        else:
+            truncate_slug_length = None
 
-        version_locations = config.get_main_option("version_locations")
-        if version_locations:
+        version_locations_str = config.get_main_option("version_locations")
+        version_locations: Optional[List[str]]
+        if version_locations_str:
             version_path_separator = config.get_main_option(
                 "version_path_separator"
             )
@@ -173,7 +179,9 @@ class ScriptDirectory:
             }
 
             try:
-                split_char = split_on_path[version_path_separator]
+                split_char: Optional[str] = split_on_path[
+                    version_path_separator
+                ]
             except KeyError as ke:
                 raise ValueError(
                     "'%s' is not a valid value for "
@@ -183,17 +191,15 @@ class ScriptDirectory:
             else:
                 if split_char is None:
                     # legacy behaviour for backwards compatibility
-                    vl = _split_on_space_comma.split(
-                        cast(str, version_locations)
+                    version_locations = _split_on_space_comma.split(
+                        version_locations_str
                     )
-                    version_locations: List[str] = vl  # type: ignore[no-redef]
                 else:
-                    vl = [
-                        x
-                        for x in cast(str, version_locations).split(split_char)
-                        if x
+                    version_locations = [
+                        x for x in version_locations_str.split(split_char) if x
                     ]
-                    version_locations: List[str] = vl  # type: ignore[no-redef]
+        else:
+            version_locations = None
 
         prepend_sys_path = config.get_main_option("prepend_sys_path")
         if prepend_sys_path:
@@ -209,7 +215,7 @@ class ScriptDirectory:
             truncate_slug_length=truncate_slug_length,
             sourceless=config.get_main_option("sourceless") == "true",
             output_encoding=config.get_main_option("output_encoding", "utf-8"),
-            version_locations=cast("Optional[List[str]]", version_locations),
+            version_locations=version_locations,
             timezone=config.get_main_option("timezone"),
             hook_config=config.get_section("post_write_hooks", {}),
         )
@@ -262,7 +268,7 @@ class ScriptDirectory:
 
     def walk_revisions(
         self, base: str = "base", head: str = "heads"
-    ) -> Iterator["Script"]:
+    ) -> Iterator[Script]:
         """Iterate through all revisions.
 
         :param base: the base revision, or "base" to start from the
@@ -279,25 +285,26 @@ class ScriptDirectory:
             ):
                 yield cast(Script, rev)
 
-    def get_revisions(self, id_: _RevIdType) -> Tuple["Script", ...]:
+    def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]:
         """Return the :class:`.Script` instance with the given rev identifier,
         symbolic name, or sequence of identifiers.
 
         """
         with self._catch_revision_errors():
             return cast(
-                "Tuple[Script, ...]", self.revision_map.get_revisions(id_)
+                Tuple[Optional[Script], ...],
+                self.revision_map.get_revisions(id_),
             )
 
-    def get_all_current(self, id_: Tuple[str, ...]) -> Set["Script"]:
+    def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]:
         with self._catch_revision_errors():
             top_revs = cast(
-                "Set[Script]",
+                Set[Optional[Script]],
                 set(self.revision_map.get_revisions(id_)),
             )
             top_revs.update(
                 cast(
-                    "Iterator[Script]",
+                    Iterator[Script],
                     self.revision_map._get_ancestor_nodes(
                         list(top_revs), include_dependencies=True
                     ),
@@ -306,7 +313,7 @@ class ScriptDirectory:
             top_revs = self.revision_map._filter_into_branch_heads(top_revs)
             return top_revs
 
-    def get_revision(self, id_: str) -> "Script":
+    def get_revision(self, id_: str) -> Optional[Script]:
         """Return the :class:`.Script` instance with the given rev id.
 
         .. seealso::
@@ -316,7 +323,7 @@ class ScriptDirectory:
         """
 
         with self._catch_revision_errors():
-            return cast(Script, self.revision_map.get_revision(id_))
+            return cast(Optional[Script], self.revision_map.get_revision(id_))
 
     def as_revision_number(
         self, id_: Optional[str]
@@ -335,7 +342,12 @@ class ScriptDirectory:
         else:
             return rev[0]
 
-    def iterate_revisions(self, upper, lower):
+    def iterate_revisions(
+        self,
+        upper: Union[str, Tuple[str, ...], None],
+        lower: Union[str, Tuple[str, ...], None],
+        **kw: Any,
+    ) -> Iterator[Script]:
         """Iterate through script revisions, starting at the given
         upper revision identifier and ending at the lower.
 
@@ -351,9 +363,12 @@ class ScriptDirectory:
             :meth:`.RevisionMap.iterate_revisions`
 
         """
-        return self.revision_map.iterate_revisions(upper, lower)
+        return cast(
+            Iterator[Script],
+            self.revision_map.iterate_revisions(upper, lower, **kw),
+        )
 
-    def get_current_head(self):
+    def get_current_head(self) -> Optional[str]:
         """Return the current head revision.
 
         If the script directory has multiple heads
@@ -423,36 +438,36 @@ class ScriptDirectory:
 
     def _upgrade_revs(
         self, destination: str, current_rev: str
-    ) -> List["RevisionStep"]:
+    ) -> List[RevisionStep]:
         with self._catch_revision_errors(
             ancestor="Destination %(end)s is not a valid upgrade "
             "target from current head(s)",
             end=destination,
         ):
-            revs = self.revision_map.iterate_revisions(
+            revs = self.iterate_revisions(
                 destination, current_rev, implicit_base=True
             )
             return [
                 migration.MigrationStep.upgrade_from_script(
-                    self.revision_map, cast(Script, script)
+                    self.revision_map, script
                 )
                 for script in reversed(list(revs))
             ]
 
     def _downgrade_revs(
         self, destination: str, current_rev: Optional[str]
-    ) -> List["RevisionStep"]:
+    ) -> List[RevisionStep]:
         with self._catch_revision_errors(
             ancestor="Destination %(end)s is not a valid downgrade "
             "target from current head(s)",
             end=destination,
         ):
-            revs = self.revision_map.iterate_revisions(
+            revs = self.iterate_revisions(
                 current_rev, destination, select_for_downgrade=True
             )
             return [
                 migration.MigrationStep.downgrade_from_script(
-                    self.revision_map, cast(Script, script)
+                    self.revision_map, script
                 )
                 for script in revs
             ]
@@ -472,12 +487,14 @@ class ScriptDirectory:
             if not revision:
                 revision = "base"
 
-            filtered_heads: List["Script"] = []
+            filtered_heads: List[Script] = []
             for rev in util.to_tuple(revision):
                 if rev:
                     filtered_heads.extend(
                         self.revision_map.filter_for_lineage(
-                            heads_revs, rev, include_dependencies=True
+                            cast(Sequence[Script], heads_revs),
+                            rev,
+                            include_dependencies=True,
                         )
                     )
             filtered_heads = util.unique_list(filtered_heads)
@@ -573,7 +590,7 @@ class ScriptDirectory:
             src,
             dest,
             self.output_encoding,
-            **kw
+            **kw,
         )
 
     def _copy_file(self, src: str, dest: str) -> None:
@@ -621,8 +638,8 @@ class ScriptDirectory:
         branch_labels: Optional[str] = None,
         version_path: Optional[str] = None,
         depends_on: Optional[_RevIdType] = None,
-        **kw: Any
-    ) -> Optional["Script"]:
+        **kw: Any,
+    ) -> Optional[Script]:
         """Generate a new revision file.
 
         This runs the ``script.py.mako`` template, given
@@ -656,7 +673,12 @@ class ScriptDirectory:
                 "or perform a merge."
             )
         ):
-            heads = self.revision_map.get_revisions(head)
+            heads = cast(
+                Tuple[Optional["Revision"], ...],
+                self.revision_map.get_revisions(head),
+            )
+            for h in heads:
+                assert h != "base"
 
         if len(set(heads)) != len(heads):
             raise util.CommandError("Duplicate head revisions specified")
@@ -702,17 +724,20 @@ class ScriptDirectory:
                         % head_.revision
                     )
 
+        resolved_depends_on: Optional[List[str]]
         if depends_on:
             with self._catch_revision_errors():
-                depends_on = [
+                resolved_depends_on = [
                     dep
                     if dep in rev.branch_labels  # maintain branch labels
                     else rev.revision  # resolve partial revision identifiers
                     for rev, dep in [
-                        (self.revision_map.get_revision(dep), dep)
+                        (not_none(self.revision_map.get_revision(dep)), dep)
                         for dep in util.to_list(depends_on)
                     ]
                 ]
+        else:
+            resolved_depends_on = None
 
         self._generate_template(
             os.path.join(self.dir, "script.py.mako"),
@@ -722,13 +747,11 @@ class ScriptDirectory:
                 tuple(h.revision if h is not None else None for h in heads)
             ),
             branch_labels=util.to_tuple(branch_labels),
-            depends_on=revision.tuple_rev_as_scalar(
-                cast("Optional[List[str]]", depends_on)
-            ),
+            depends_on=revision.tuple_rev_as_scalar(resolved_depends_on),
             create_date=create_date,
             comma=util.format_as_comma,
             message=message if message is not None else ("empty message"),
-            **kw
+            **kw,
         )
 
         post_write_hooks = self.hook_config
@@ -801,13 +824,13 @@ class Script(revision.Revision):
             ),
         )
 
-    module: ModuleType = None  # type: ignore[assignment]
+    module: ModuleType
     """The Python module representing the actual script itself."""
 
-    path: str = None  # type: ignore[assignment]
+    path: str
     """Filesystem path of the script."""
 
-    _db_current_indicator = None
+    _db_current_indicator: Optional[bool] = None
     """Utility variable which when set will cause string output to indicate
     this is a "current" version in some database"""
 
@@ -939,7 +962,7 @@ class Script(revision.Revision):
     @classmethod
     def _from_path(
         cls, scriptdir: ScriptDirectory, path: str
-    ) -> Optional["Script"]:
+    ) -> Optional[Script]:
         dir_, filename = os.path.split(path)
         return cls._from_filename(scriptdir, dir_, filename)
 
@@ -969,7 +992,7 @@ class Script(revision.Revision):
     @classmethod
     def _from_filename(
         cls, scriptdir: ScriptDirectory, dir_: str, filename: str
-    ) -> Optional["Script"]:
+    ) -> Optional[Script]:
         if scriptdir.sourceless:
             py_match = _sourceless_rev_file.match(filename)
         else:
index 2bfb7f9d44326489ee3341374f15164b20b3076e..335314f9c6ba2657ba6b2f20b431a99043755255 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import collections
 import re
 from typing import Any
@@ -11,6 +13,7 @@ from typing import Iterable
 from typing import Iterator
 from typing import List
 from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Set
 from typing import Tuple
@@ -21,6 +24,7 @@ from typing import Union
 from sqlalchemy import util as sqlautil
 
 from .. import util
+from ..util import not_none
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -439,7 +443,7 @@ class RevisionMap:
                     "Revision %s referenced from %s is not present"
                     % (downrev, revision)
                 )
-            cast("Revision", map_[downrev]).add_nextrev(revision)
+            not_none(map_[downrev]).add_nextrev(revision)
 
         self._normalize_depends_on(revisions, map_)
 
@@ -502,8 +506,8 @@ class RevisionMap:
         return self.filter_for_lineage(self.bases, identifier)
 
     def get_revisions(
-        self, id_: Union[str, Collection[str], None]
-    ) -> Tuple["Revision", ...]:
+        self, id_: Union[str, Collection[Optional[str]], None]
+    ) -> Tuple[Optional[_RevisionOrBase], ...]:
         """Return the :class:`.Revision` instances with the given rev id
         or identifiers.
 
@@ -537,7 +541,8 @@ class RevisionMap:
                             select_heads = tuple(
                                 head
                                 for head in select_heads
-                                if branch_label in head.branch_labels
+                                if branch_label
+                                in is_revision(head).branch_labels
                             )
                         return tuple(
                             self._walk(head, steps=rint)
@@ -551,7 +556,7 @@ class RevisionMap:
                 for rev_id in resolved_id
             )
 
-    def get_revision(self, id_: Optional[str]) -> "Revision":
+    def get_revision(self, id_: Optional[str]) -> Optional[Revision]:
         """Return the :class:`.Revision` instance with the given rev id.
 
         If a symbolic name such as "head" or "base" is given, resolves
@@ -568,12 +573,11 @@ class RevisionMap:
         resolved_id, branch_label = self._resolve_revision_number(id_)
         if len(resolved_id) > 1:
             raise MultipleHeads(resolved_id, id_)
-        elif resolved_id:
-            resolved_id = resolved_id[0]  # type:ignore[assignment]
 
-        return self._revision_for_ident(cast(str, resolved_id), branch_label)
+        resolved: Union[str, Tuple[()]] = resolved_id[0] if resolved_id else ()
+        return self._revision_for_ident(resolved, branch_label)
 
-    def _resolve_branch(self, branch_label: str) -> "Revision":
+    def _resolve_branch(self, branch_label: str) -> Optional[Revision]:
         try:
             branch_rev = self._revision_map[branch_label]
         except KeyError:
@@ -587,25 +591,28 @@ class RevisionMap:
             else:
                 return nonbranch_rev
         else:
-            return cast("Revision", branch_rev)
+            return branch_rev
 
     def _revision_for_ident(
-        self, resolved_id: str, check_branch: Optional[str] = None
-    ) -> "Revision":
-        branch_rev: Optional["Revision"]
+        self,
+        resolved_id: Union[str, Tuple[()]],
+        check_branch: Optional[str] = None,
+    ) -> Optional[Revision]:
+        branch_rev: Optional[Revision]
         if check_branch:
             branch_rev = self._resolve_branch(check_branch)
         else:
             branch_rev = None
 
-        revision: Union["Revision", "Literal[False]"]
+        revision: Union[Optional[Revision], "Literal[False]"]
         try:
-            revision = cast("Revision", self._revision_map[resolved_id])
+            revision = self._revision_map[resolved_id]
         except KeyError:
             # break out to avoid misleading py3k stack traces
             revision = False
         revs: Sequence[str]
         if revision is False:
+            assert resolved_id
             # do a partial lookup
             revs = [
                 x
@@ -637,11 +644,11 @@ class RevisionMap:
                     resolved_id,
                 )
             else:
-                revision = cast("Revision", self._revision_map[revs[0]])
+                revision = self._revision_map[revs[0]]
 
-        revision = cast("Revision", revision)
         if check_branch and revision is not None:
             assert branch_rev is not None
+            assert resolved_id
             if not self._shares_lineage(
                 revision.revision, branch_rev.revision
             ):
@@ -653,11 +660,12 @@ class RevisionMap:
         return revision
 
     def _filter_into_branch_heads(
-        self, targets: Set["Script"]
-    ) -> Set["Script"]:
+        self, targets: Set[Optional[Script]]
+    ) -> Set[Optional[Script]]:
         targets = set(targets)
 
         for rev in list(targets):
+            assert rev
             if targets.intersection(
                 self._get_descendant_nodes([rev], include_dependencies=False)
             ).difference([rev]):
@@ -695,9 +703,11 @@ class RevisionMap:
         if not test_against_revs:
             return True
         if not isinstance(target, Revision):
-            target = self._revision_for_ident(target)
+            resolved_target = not_none(self._revision_for_ident(target))
+        else:
+            resolved_target = target
 
-        test_against_revs = [
+        resolved_test_against_revs = [
             self._revision_for_ident(test_against_rev)
             if not isinstance(test_against_rev, Revision)
             else test_against_rev
@@ -709,15 +719,17 @@ class RevisionMap:
         return bool(
             set(
                 self._get_descendant_nodes(
-                    [target], include_dependencies=include_dependencies
+                    [resolved_target],
+                    include_dependencies=include_dependencies,
                 )
             )
             .union(
                 self._get_ancestor_nodes(
-                    [target], include_dependencies=include_dependencies
+                    [resolved_target],
+                    include_dependencies=include_dependencies,
                 )
             )
-            .intersection(test_against_revs)
+            .intersection(resolved_test_against_revs)
         )
 
     def _resolve_revision_number(
@@ -768,7 +780,7 @@ class RevisionMap:
         inclusive: bool = False,
         assert_relative_length: bool = True,
         select_for_downgrade: bool = False,
-    ) -> Iterator["Revision"]:
+    ) -> Iterator[Revision]:
         """Iterate through script revisions, starting at the given
         upper revision identifier and ending at the lower.
 
@@ -795,11 +807,11 @@ class RevisionMap:
         )
 
         for node in self._topological_sort(revisions, heads):
-            yield self.get_revision(node)
+            yield not_none(self.get_revision(node))
 
     def _get_descendant_nodes(
         self,
-        targets: Collection["Revision"],
+        targets: Collection[Revision],
         map_: Optional[_RevisionMapType] = None,
         check: bool = False,
         omit_immediate_dependencies: bool = False,
@@ -830,11 +842,11 @@ class RevisionMap:
 
     def _get_ancestor_nodes(
         self,
-        targets: Collection["Revision"],
+        targets: Collection[Optional[_RevisionOrBase]],
         map_: Optional[_RevisionMapType] = None,
         check: bool = False,
         include_dependencies: bool = True,
-    ) -> Iterator["Revision"]:
+    ) -> Iterator[Revision]:
 
         if include_dependencies:
 
@@ -853,17 +865,17 @@ class RevisionMap:
     def _iterate_related_revisions(
         self,
         fn: Callable,
-        targets: Collection["Revision"],
+        targets: Collection[Optional[_RevisionOrBase]],
         map_: Optional[_RevisionMapType],
         check: bool = False,
-    ) -> Iterator["Revision"]:
+    ) -> Iterator[Revision]:
         if map_ is None:
             map_ = self._revision_map
 
         seen = set()
-        todo: Deque["Revision"] = collections.deque()
-        for target in targets:
-
+        todo: Deque[Revision] = collections.deque()
+        for target_for in targets:
+            target = is_revision(target_for)
             todo.append(target)
             if check:
                 per_target = set()
@@ -902,7 +914,7 @@ class RevisionMap:
 
     def _topological_sort(
         self,
-        revisions: Collection["Revision"],
+        revisions: Collection[Revision],
         heads: Any,
     ) -> List[str]:
         """Yield revision ids of a collection of Revision objects in
@@ -1007,11 +1019,11 @@ class RevisionMap:
 
     def _walk(
         self,
-        start: Optional[Union[str, "Revision"]],
+        start: Optional[Union[str, Revision]],
         steps: int,
         branch_label: Optional[str] = None,
         no_overwalk: bool = True,
-    ) -> "Revision":
+    ) -> Optional[_RevisionOrBase]:
         """
         Walk the requested number of :steps up (steps > 0) or down (steps < 0)
         the revision tree.
@@ -1030,20 +1042,21 @@ class RevisionMap:
         else:
             initial = start
 
-        children: Sequence[_RevisionOrBase]
+        children: Sequence[Optional[_RevisionOrBase]]
         for _ in range(abs(steps)):
             if steps > 0:
+                assert initial != "base"
                 # Walk up
-                children = [
-                    rev
+                walk_up = [
+                    is_revision(rev)
                     for rev in self.get_revisions(
-                        self.bases
-                        if initial is None
-                        else cast("Revision", initial).nextrev
+                        self.bases if initial is None else initial.nextrev
                     )
                 ]
                 if branch_label:
-                    children = self.filter_for_lineage(children, branch_label)
+                    children = self.filter_for_lineage(walk_up, branch_label)
+                else:
+                    children = walk_up
             else:
                 # Walk down
                 if initial == "base":
@@ -1055,17 +1068,17 @@ class RevisionMap:
                         else initial.down_revision
                     )
                     if not children:
-                        children = cast("Tuple[Literal['base']]", ("base",))
+                        children = ("base",)
             if not children:
                 # This will return an invalid result if no_overwalk, otherwise
                 # further steps will stay where we are.
                 ret = None if no_overwalk else initial
-                return ret  # type:ignore[return-value]
+                return ret
             elif len(children) > 1:
                 raise RevisionError("Ambiguous walk")
             initial = children[0]
 
-        return cast("Revision", initial)
+        return initial
 
     def _parse_downgrade_target(
         self,
@@ -1170,7 +1183,7 @@ class RevisionMap:
         current_revisions: _RevisionIdentifierType,
         target: _RevisionIdentifierType,
         assert_relative_length: bool,
-    ) -> Tuple["Revision", ...]:
+    ) -> Tuple[Optional[_RevisionOrBase], ...]:
         """
         Parse upgrade command syntax :target to retrieve the target revision
         and given the :current_revisons stamp of the database.
@@ -1188,26 +1201,27 @@ class RevisionMap:
             # No relative destination, target is absolute.
             return self.get_revisions(target)
 
-        current_revisions = util.to_tuple(current_revisions)
+        current_revisions_tup: Union[str, Collection[Optional[str]], None]
+        current_revisions_tup = util.to_tuple(current_revisions)
 
         branch_label, symbol, relative_str = match.groups()
         relative = int(relative_str)
         if relative > 0:
             if symbol is None:
-                if not current_revisions:
-                    current_revisions = (None,)
+                if not current_revisions_tup:
+                    current_revisions_tup = (None,)
                 # Try to filter to a single target (avoid ambiguous branches).
-                start_revs = current_revisions
+                start_revs = current_revisions_tup
                 if branch_label:
                     start_revs = self.filter_for_lineage(
-                        self.get_revisions(current_revisions), branch_label
+                        self.get_revisions(current_revisions_tup), branch_label
                     )
                     if not start_revs:
                         # The requested branch is not a head, so we need to
                         # backtrack to find a branchpoint.
                         active_on_branch = self.filter_for_lineage(
                             self._get_ancestor_nodes(
-                                self.get_revisions(current_revisions)
+                                self.get_revisions(current_revisions_tup)
                             ),
                             branch_label,
                         )
@@ -1294,6 +1308,7 @@ class RevisionMap:
             target_revision = None
         assert target_revision is None or isinstance(target_revision, Revision)
 
+        roots: List[Revision]
         # Find candidates to drop.
         if target_revision is None:
             # Downgrading back to base: find all tree roots.
@@ -1307,7 +1322,10 @@ class RevisionMap:
             roots = [target_revision]
         else:
             # Downgrading to fixed target: find all direct children.
-            roots = list(self.get_revisions(target_revision.nextrev))
+            roots = [
+                is_revision(rev)
+                for rev in self.get_revisions(target_revision.nextrev)
+            ]
 
         if branch_label and len(roots) > 1:
             # Need to filter roots.
@@ -1320,11 +1338,12 @@ class RevisionMap:
             }
             # Intersection gives the root revisions we are trying to
             # rollback with the downgrade.
-            roots = list(
-                self.get_revisions(
+            roots = [
+                is_revision(rev)
+                for rev in self.get_revisions(
                     {rev.revision for rev in roots}.intersection(ancestors)
                 )
-            )
+            ]
 
             # Ensure we didn't throw everything away when filtering branches.
             if len(roots) == 0:
@@ -1374,7 +1393,7 @@ class RevisionMap:
         inclusive: bool,
         implicit_base: bool,
         assert_relative_length: bool,
-    ) -> Tuple[Set["Revision"], Tuple[Optional[_RevisionOrBase]]]:
+    ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]:
         """
         Compute the set of required revisions specified by :upper, and the
         current set of active revisions specified by :lower. Find the
@@ -1386,11 +1405,14 @@ class RevisionMap:
         of the current/lower revisions. Dependencies from branches with
         different bases will not be included.
         """
-        targets: Collection["Revision"] = self._parse_upgrade_target(
-            current_revisions=lower,
-            target=upper,
-            assert_relative_length=assert_relative_length,
-        )
+        targets: Collection[Revision] = [
+            is_revision(rev)
+            for rev in self._parse_upgrade_target(
+                current_revisions=lower,
+                target=upper,
+                assert_relative_length=assert_relative_length,
+            )
+        ]
 
         # assert type(targets) is tuple, "targets should be a tuple"
 
@@ -1432,6 +1454,7 @@ class RevisionMap:
                 target=lower,
                 assert_relative_length=assert_relative_length,
             )
+            assert rev
             if rev == "base":
                 current_revisions = tuple()
                 lower = None
@@ -1449,14 +1472,16 @@ class RevisionMap:
 
         # Include the lower revision (=current_revisions?) in the iteration
         if inclusive:
-            needs.update(self.get_revisions(lower))
+            needs.update(is_revision(rev) for rev in self.get_revisions(lower))
         # By default, base is implicit as we want all dependencies returned.
         # Base is also implicit if lower = base
         # implicit_base=False -> only return direct downstreams of
         # current_revisions
         if current_revisions and not implicit_base:
             lower_descendents = self._get_descendant_nodes(
-                current_revisions, check=True, include_dependencies=False
+                [is_revision(rev) for rev in current_revisions],
+                check=True,
+                include_dependencies=False,
             )
             needs.intersection_update(lower_descendents)
 
@@ -1545,7 +1570,7 @@ class Revision:
             args.append("branch_labels=%r" % (self.branch_labels,))
         return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
 
-    def add_nextrev(self, revision: "Revision") -> None:
+    def add_nextrev(self, revision: Revision) -> None:
         self._all_nextrev = self._all_nextrev.union([revision.revision])
         if self.revision in revision._versioned_down_revisions:
             self.nextrev = self.nextrev.union([revision.revision])
@@ -1630,12 +1655,29 @@ class Revision:
         return len(self._versioned_down_revisions) > 1
 
 
+@overload
 def tuple_rev_as_scalar(
     rev: Optional[Sequence[str]],
 ) -> Optional[Union[str, Sequence[str]]]:
+    ...
+
+
+@overload
+def tuple_rev_as_scalar(
+    rev: Optional[Sequence[Optional[str]]],
+) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]:
+    ...
+
+
+def tuple_rev_as_scalar(rev):
     if not rev:
         return None
     elif len(rev) == 1:
         return rev[0]
     else:
         return rev
+
+
+def is_revision(rev: Any) -> Revision:
+    assert isinstance(rev, Revision)
+    return rev
index 0cc9bb8eee6c79f70db504ae25643461fea82f43..8bc7ac1c172a14b8084e3d5e71ba92e463af01e8 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import shlex
 import subprocess
 import sys
@@ -14,7 +16,7 @@ from ..util import compat
 
 REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
 
-_registry = {}
+_registry: dict = {}
 
 
 def register(name: str) -> Callable:
index e7a12c65ae02729f4d70e3c3e3eb96b04b4b7ef6..1c24066b808f57a51f57b3933e506dad4cefa9fc 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import re
 import sys
@@ -69,7 +71,7 @@ def _assert_raises(
 
 
 class _ErrorContainer:
-    error = None
+    error: Any = None
 
 
 @contextlib.contextmanager
index 849bc83089b31861264737ff03efcdb8dba87e96..26427507eddeeb9209082fd33809f891a8d508c8 100644 (file)
@@ -1,4 +1,6 @@
 # coding: utf-8
+from __future__ import annotations
+
 import configparser
 from contextlib import contextmanager
 import io
index c47cc10ef83d0996df8ee65b48cf0049d48add28..f97dd753204cb1d06c30b33caa70948aad449146 100644 (file)
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 from typing import Any
 from typing import Dict
+from typing import Set
 
 from sqlalchemy import CHAR
 from sqlalchemy import CheckConstraint
@@ -28,7 +31,7 @@ from ...testing import eq_
 from ...testing.env import clear_staging_env
 from ...testing.env import staging_env
 
-names_in_this_test = set()
+names_in_this_test: Set[Any] = set()
 
 
 @event.listens_for(Table, "after_parent_attach")
@@ -43,15 +46,15 @@ def _default_include_object(obj, name, type_, reflected, compare_to):
         return True
 
 
-_default_object_filters = _default_include_object
+_default_object_filters: Any = _default_include_object
 
-_default_name_filters = None
+_default_name_filters: Any = None
 
 
 class ModelOne:
     __requires__ = ("unique_constraint_reflection",)
 
-    schema = None
+    schema: Any = None
 
     @classmethod
     def _get_db_schema(cls):
index 9d24d0fe931f85fc44afd3990881f7d3e263ad0f..a82690d12655bd815560883cf0b438b945a87763 100644 (file)
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
 
 import re
 import types
@@ -55,7 +56,7 @@ def flag_combinations(*combinations):
             for d in combinations
         ],
         id_="i" + ("a" * len(keys)),
-        argnames=",".join(keys)
+        argnames=",".join(keys),
     )
 
 
index 49bee432cdb016c50951500c431defc833da06bf..d5fa4d32550ddecd0d978bdbf1f1578dee00ef57 100644 (file)
@@ -7,6 +7,7 @@ from .langhelpers import Dispatcher
 from .langhelpers import immutabledict
 from .langhelpers import memoized_property
 from .langhelpers import ModuleClsProxy
+from .langhelpers import not_none
 from .langhelpers import rev_id
 from .langhelpers import to_list
 from .langhelpers import to_tuple
index 54420cbc95480a3040cb739d8761fb1bdee66fe6..e6a8f6e0acfcd1d1e932fea605f155b9851b63f0 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import io
 import os
 import sys
@@ -26,8 +28,8 @@ if py39:
     from importlib import metadata as importlib_metadata
     from importlib.metadata import EntryPoint
 else:
-    import importlib_resources  # type:ignore[no-redef] # noqa
-    import importlib_metadata  # type:ignore[no-redef] # noqa
+    import importlib_resources  # type:ignore # noqa
+    import importlib_metadata  # type:ignore # noqa
     from importlib_metadata import EntryPoint  # type:ignore # noqa
 
 
index ba376c0793068302e8ff1572b7637ddbcde5372f..f1d1557f74c8977efa0b22535f45f44a2c9e2564 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from os.path import exists
 from os.path import join
index fd7ccb8fd1e3151810797d80ef356cb51bc89f9f..b6ceb0cd953bdb4454fb6893dc4583912c3f7096 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import collections
 from collections.abc import Iterable
 import textwrap
@@ -280,3 +282,8 @@ class Dispatcher:
         else:
             d._registry.update(self._registry)
         return d
+
+
+def not_none(value: Optional[_T]) -> _T:
+    assert value is not None
+    return value
index 66f8cc256a3938c486f12abdfee03abe8e7bed77..dad222f630d7528b7f79e7b0180e91127e95e0f0 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Iterable
 import logging
 import sys
index b6662b716a2c1f11658206c36da8ea01e0f4460f..7535004767a537a794a8db2977713147e2fbd310 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import atexit
 from contextlib import ExitStack
 import importlib
index 787b77c263a4282d646c9d6d00cf6a865a095fc8..21a9f7f26706239166d4919407ac3609f7cdfbce 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import re
 from typing import Iterator
@@ -56,6 +58,7 @@ _vers = tuple(
 sqla_13 = _vers >= (1, 3)
 sqla_14 = _vers >= (1, 4)
 sqla_14_26 = _vers >= (1, 4, 26)
+sqla_2 = _vers >= (1, 5)
 
 
 if sqla_14:
index a8f43fefdf149ecbf2415952cba59074a3027744..721b0db0157a4bda0524440c4090275f156a3f23 100644 (file)
@@ -1,2 +1,16 @@
 [tool.black]
 line-length = 79
+
+[tool.mypy]
+
+exclude = [
+    'alembic/template',
+    'alembic.testing.*',
+]
+
+[[tool.mypy.overrides]]
+module = [
+    'mako.*',
+    'sqlalchemy.testing.*'
+]
+ignore_missing_imports = true
index 1060e25ec357fd3212caa3856e021e3551a2db8f..aa88f66d22a194133c522382a64ca77ba1a70839 100644 (file)
@@ -404,4 +404,9 @@ class DefaultRequirements(SuiteRequirements):
         version = exclusions.only_if(
             lambda _: compat.py39, "python 3.9 is required"
         )
-        return imports + version
+
+        sqlalchemy = exclusions.only_if(
+            lambda _: sqla_compat.sqla_2, "sqlalchemy 2 is required"
+        )
+
+        return imports + version + sqlalchemy
index 60728a8d72d6fd90576bfbd32d0b6690da616940..ec928cc114cd754837042a67f3ce520fd6939de8 100644 (file)
@@ -4,6 +4,7 @@ import re
 import sys
 from tempfile import NamedTemporaryFile
 import textwrap
+import typing
 
 from mako.pygen import PythonPrinter
 
@@ -15,6 +16,8 @@ if True:  # avoid flake/zimports messing with the order
     from alembic.script.write_hooks import console_scripts
     from alembic.util.compat import inspect_formatargspec
     from alembic.util.compat import inspect_getfullargspec
+    from alembic.operations import ops
+    import sqlalchemy as sa
 
 IGNORE_ITEMS = {
     "op": {"context", "create_module_class_proxy"},
@@ -24,6 +27,17 @@ IGNORE_ITEMS = {
         "requires_connection",
     },
 }
+TRIM_MODULE = [
+    "alembic.runtime.migration.",
+    "alembic.operations.ops.",
+    "sqlalchemy.engine.base.",
+    "sqlalchemy.sql.schema.",
+    "sqlalchemy.sql.selectable.",
+    "sqlalchemy.sql.elements.",
+    "sqlalchemy.sql.type_api.",
+    "sqlalchemy.sql.functions.",
+    "sqlalchemy.sql.dml.",
+]
 
 
 def generate_pyi_for_proxy(
@@ -66,14 +80,22 @@ def generate_pyi_for_proxy(
         printer.writeline("### end imports ###")
         buf.write("\n\n")
 
+        module = sys.modules[cls.__module__]
+        env = {
+            **sa.__dict__,
+            **sa.types.__dict__,
+            **ops.__dict__,
+            **module.__dict__,
+        }
+
         for name in dir(cls):
             if name.startswith("_") or name in ignore_items:
                 continue
-            meth = getattr(cls, name)
+            meth = getattr(cls, name, None)
             if callable(meth):
-                _generate_stub_for_meth(cls, name, printer)
+                _generate_stub_for_meth(cls, name, printer, env)
             else:
-                _generate_stub_for_attr(cls, name, printer)
+                _generate_stub_for_attr(cls, name, printer, env)
 
         printer.close()
 
@@ -92,18 +114,29 @@ def generate_pyi_for_proxy(
     )
 
 
-def _generate_stub_for_attr(cls, name, printer):
-    type_ = cls.__annotations__.get(name, "Any")
+def _generate_stub_for_attr(cls, name, printer, env):
+    try:
+        annotations = typing.get_type_hints(cls, env)
+    except NameError as e:
+        annotations = cls.__annotations__
+    type_ = annotations.get(name, "Any")
+    if isinstance(type_, str) and type_[0] in "'\"":
+        type_ = type_[1:-1]
     printer.writeline(f"{name}: {type_}")
 
 
-def _generate_stub_for_meth(cls, name, printer):
+def _generate_stub_for_meth(cls, name, printer, env):
 
     fn = getattr(cls, name)
     while hasattr(fn, "__wrapped__"):
         fn = fn.__wrapped__
 
     spec = inspect_getfullargspec(fn)
+    try:
+        annotations = typing.get_type_hints(fn, env)
+        spec.annotations.update(annotations)
+    except NameError as e:
+        pass
 
     name_args = spec[0]
     assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"]
@@ -119,7 +152,10 @@ def _generate_stub_for_meth(cls, name, printer):
             else:
                 retval = annotation.__module__ + "." + annotation.__qualname__
         else:
-            retval = repr(annotation)
+            retval = annotation
+
+        for trim in TRIM_MODULE:
+            retval = retval.replace(trim, "")
 
         retval = re.sub(
             r'ForwardRef\(([\'"].+?[\'"])\)', lambda m: m.group(1), retval
@@ -127,7 +163,11 @@ def _generate_stub_for_meth(cls, name, printer):
         retval = re.sub("NoneType", "None", retval)
         return retval
 
-    argspec = inspect_formatargspec(*spec, formatannotation=_formatannotation)
+    argspec = inspect_formatargspec(
+        *spec,
+        formatannotation=_formatannotation,
+        formatreturns=lambda val: f"-> {_formatannotation(val)}",
+    )
 
     func_text = textwrap.dedent(
         """\