]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Add pep-484 type annotations
authorCaselIT <cfederico87@gmail.com>
Sun, 18 Apr 2021 13:44:50 +0000 (15:44 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Aug 2021 19:04:56 +0000 (15:04 -0400)
pep-484 type annotations have been added throughout the library. This
should be helpful in providing Mypy and IDE support, however there is not
full support for Alembic's dynamically modified "op" namespace as of yet; a
future release will likely modify the approach used for importing this
namespace to be better compatible with pep-484 capabilities.

Types were originally created using MonkeyType

Add types extracted with the MonkeyType library (https://github.com/instagram/MonkeyType)
by running the unit tests using ``monkeytype run -m pytest tests``, then
``monkeytype apply <module>`` (see below for further details).
Used MonkeyType version 20.5 on Python 3.8, since newer versions have issues.

After applying the types, the new imports are placed in a ``TYPE_CHECKING`` guard
and all type definitions of non-base types are deferred by using the string notation.

NOTE: since MonkeyType needs to import the modules (including the test ones) in
order to apply the types, the patch below mocks the setup normally done by pytest
so that the tests can be imported correctly.

diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py
index bdd1746..b1090c7 100644

Change-Id: Iff93628f4b43c740848871ce077a118db5e75d41
--- a/alembic/testing/__init__.py
+++ b/alembic/testing/__init__.py
@@ -9,6 +9,12 @@ from sqlalchemy.testing.config import combinations
 from sqlalchemy.testing.config import fixture
 from sqlalchemy.testing.config import requirements as requires

+from sqlalchemy.testing.plugin.pytestplugin import PytestFixtureFunctions
+from sqlalchemy.testing.plugin.plugin_base import _setup_requirements
+
+config._fixture_functions = PytestFixtureFunctions()
+_setup_requirements("tests.requirements:DefaultRequirements")
+
 from alembic import util
 from .assertions import assert_raises
 from .assertions import assert_raises_message

Currently I'm using this branch of the sqlalchemy stubs:
https://github.com/sqlalchemy/sqlalchemy2-stubs/tree/alembic_updates

Change-Id: I8fd0700aab1913f395302626b8b84fea60334abd

40 files changed:
alembic/__init__.py
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/autogenerate/rewriter.py
alembic/command.py
alembic/config.py
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/mysql.py
alembic/ddl/oracle.py
alembic/ddl/postgresql.py
alembic/ddl/sqlite.py
alembic/environment.py [new file with mode: 0644]
alembic/migration.py [new file with mode: 0644]
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/operations/toimpl.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/script/base.py
alembic/script/revision.py
alembic/script/write_hooks.py
alembic/testing/assertions.py
alembic/testing/fixtures.py
alembic/testing/requirements.py
alembic/testing/suite/_autogen_fixtures.py
alembic/util/compat.py
alembic/util/editor.py
alembic/util/langhelpers.py
alembic/util/messaging.py
alembic/util/pyfiles.py
alembic/util/sqla_compat.py
docs/build/unreleased/py3_typing.rst [new file with mode: 0644]
setup.cfg
tests/test_revision.py
tox.ini

index 0820de065abdfc011f422320ad0cbbb98a445365..023fd068661214b039af0b2be6afcbba69717127 100644 (file)
@@ -2,10 +2,5 @@ import sys
 
 from . import context
 from . import op
-from .runtime import environment
-from .runtime import migration
 
 __version__ = "1.7.0"
-
-sys.modules["alembic.migration"] = migration
-sys.modules["alembic.environment"] = environment
index 4c156c4ee90126dd66bb0afe1ffb7a0f72245a87..3b23dcd21ba8328e6b5d46919b3c3dd1a69d1d10 100644 (file)
@@ -2,6 +2,15 @@
 automatically."""
 
 import contextlib
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterator
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy import inspect
 
@@ -10,8 +19,26 @@ from . import render
 from .. import util
 from ..operations import ops
 
-
-def compare_metadata(context, metadata):
+if TYPE_CHECKING:
+    from sqlalchemy.engine import Connection
+    from sqlalchemy.engine import Dialect
+    from sqlalchemy.engine import Inspector
+    from sqlalchemy.sql.schema import Column
+    from sqlalchemy.sql.schema import ForeignKeyConstraint
+    from sqlalchemy.sql.schema import Index
+    from sqlalchemy.sql.schema import MetaData
+    from sqlalchemy.sql.schema import Table
+    from sqlalchemy.sql.schema import UniqueConstraint
+
+    from alembic.config import Config
+    from alembic.operations.ops import MigrationScript
+    from alembic.operations.ops import UpgradeOps
+    from alembic.runtime.migration import MigrationContext
+    from alembic.script.base import Script
+    from alembic.script.base import ScriptDirectory
+
+
+def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any:
     """Compare a database schema to that given in a
     :class:`~sqlalchemy.schema.MetaData` instance.
 
@@ -106,7 +133,9 @@ def compare_metadata(context, metadata):
     return migration_script.upgrade_ops.as_diffs()
 
 
-def produce_migrations(context, metadata):
+def produce_migrations(
+    context: "MigrationContext", metadata: "MetaData"
+) -> "MigrationScript":
     """Produce a :class:`.MigrationScript` structure based on schema
     comparison.
 
@@ -136,14 +165,14 @@ def produce_migrations(context, metadata):
 
 
 def render_python_code(
-    up_or_down_op,
-    sqlalchemy_module_prefix="sa.",
-    alembic_module_prefix="op.",
-    render_as_batch=False,
-    imports=(),
-    render_item=None,
-    migration_context=None,
-):
+    up_or_down_op: "UpgradeOps",
+    sqlalchemy_module_prefix: str = "sa.",
+    alembic_module_prefix: str = "op.",
+    render_as_batch: bool = False,
+    imports: Tuple[str, ...] = (),
+    render_item: None = None,
+    migration_context: Optional["MigrationContext"] = None,
+) -> str:
     """Render Python code given an :class:`.UpgradeOps` or
     :class:`.DowngradeOps` object.
 
@@ -173,7 +202,9 @@ def render_python_code(
     )
 
 
-def _render_migration_diffs(context, template_args):
+def _render_migration_diffs(
+    context: "MigrationContext", template_args: Dict[Any, Any]
+) -> None:
     """legacy, used by test_autogen_composition at the moment"""
 
     autogen_context = AutogenContext(context)
@@ -196,7 +227,7 @@ class AutogenContext:
     """Maintains configuration and state that's specific to an
     autogenerate operation."""
 
-    metadata = None
+    metadata: Optional["MetaData"] = None
     """The :class:`~sqlalchemy.schema.MetaData` object
     representing the destination.
 
@@ -214,7 +245,7 @@ class AutogenContext:
 
     """
 
-    connection = None
+    connection: Optional["Connection"] = None
     """The :class:`~sqlalchemy.engine.base.Connection` object currently
     connected to the database backend being compared.
 
@@ -223,7 +254,7 @@ class AutogenContext:
 
     """
 
-    dialect = None
+    dialect: Optional["Dialect"] = None
     """The :class:`~sqlalchemy.engine.Dialect` object currently in use.
 
     This is normally obtained from the
@@ -231,7 +262,7 @@ class AutogenContext:
 
     """
 
-    imports = None
+    imports: Set[str] = None  # type: ignore[assignment]
     """A ``set()`` which contains string Python import directives.
 
     The directives are to be rendered into the ``${imports}`` section
@@ -245,12 +276,16 @@ class AutogenContext:
 
     """
 
-    migration_context = None
+    migration_context: "MigrationContext" = None  # type: ignore[assignment]
     """The :class:`.MigrationContext` established by the ``env.py`` script."""
 
     def __init__(
-        self, migration_context, metadata=None, opts=None, autogenerate=True
-    ):
+        self,
+        migration_context: "MigrationContext",
+        metadata: Optional["MetaData"] = None,
+        opts: Optional[dict] = None,
+        autogenerate: bool = True,
+    ) -> None:
 
         if (
             autogenerate
@@ -301,20 +336,25 @@ class AutogenContext:
             self.dialect = self.migration_context.dialect
 
         self.imports = set()
-        self.opts = opts
-        self._has_batch = False
+        self.opts: Dict[str, Any] = opts
+        self._has_batch: bool = False
 
     @util.memoized_property
-    def inspector(self):
+    def inspector(self) -> "Inspector":
         return inspect(self.connection)
 
     @contextlib.contextmanager
-    def _within_batch(self):
+    def _within_batch(self) -> Iterator[None]:
         self._has_batch = True
         yield
         self._has_batch = False
 
-    def run_name_filters(self, name, type_, parent_names):
+    def run_name_filters(
+        self,
+        name: Optional[str],
+        type_: str,
+        parent_names: Dict[str, Optional[str]],
+    ) -> bool:
         """Run the context's name filters and return True if the targets
         should be part of the autogenerate operation.
 
@@ -348,7 +388,22 @@ class AutogenContext:
         else:
             return True
 
-    def run_object_filters(self, object_, name, type_, reflected, compare_to):
+    def run_object_filters(
+        self,
+        object_: Union[
+            "Table",
+            "Index",
+            "Column",
+            "UniqueConstraint",
+            "ForeignKeyConstraint",
+        ],
+        name: Optional[str],
+        type_: str,
+        reflected: bool,
+        compare_to: Optional[
+            Union["Table", "Index", "Column", "UniqueConstraint"]
+        ],
+    ) -> bool:
         """Run the context's object filters and return True if the targets
         should be part of the autogenerate operation.
 
@@ -414,11 +469,11 @@ class RevisionContext:
 
     def __init__(
         self,
-        config,
-        script_directory,
-        command_args,
-        process_revision_directives=None,
-    ):
+        config: "Config",
+        script_directory: "ScriptDirectory",
+        command_args: Dict[str, Any],
+        process_revision_directives: Optional[Callable] = None,
+    ) -> None:
         self.config = config
         self.script_directory = script_directory
         self.command_args = command_args
@@ -429,10 +484,10 @@ class RevisionContext:
         }
         self.generated_revisions = [self._default_revision()]
 
-    def _to_script(self, migration_script):
-        template_args = {}
-        for k, v in self.template_args.items():
-            template_args.setdefault(k, v)
+    def _to_script(
+        self, migration_script: "MigrationScript"
+    ) -> Optional["Script"]:
+        template_args: Dict[str, Any] = self.template_args.copy()
 
         if getattr(migration_script, "_needs_render", False):
             autogen_context = self._last_autogen_context
@@ -446,6 +501,7 @@ class RevisionContext:
                 autogen_context, migration_script, template_args
             )
 
+        assert migration_script.rev_id is not None
         return self.script_directory.generate_revision(
             migration_script.rev_id,
             migration_script.message,
@@ -458,13 +514,22 @@ class RevisionContext:
             **template_args
         )
 
-    def run_autogenerate(self, rev, migration_context):
+    def run_autogenerate(
+        self, rev: tuple, migration_context: "MigrationContext"
+    ):
         self._run_environment(rev, migration_context, True)
 
-    def run_no_autogenerate(self, rev, migration_context):
+    def run_no_autogenerate(
+        self, rev: tuple, migration_context: "MigrationContext"
+    ):
         self._run_environment(rev, migration_context, False)
 
-    def _run_environment(self, rev, migration_context, autogenerate):
+    def _run_environment(
+        self,
+        rev: tuple,
+        migration_context: "MigrationContext",
+        autogenerate: bool,
+    ):
         if autogenerate:
             if self.command_args["sql"]:
                 raise util.CommandError(
@@ -493,9 +558,10 @@ class RevisionContext:
                 ops.DowngradeOps([], downgrade_token=downgrade_token)
             )
 
-        self._last_autogen_context = autogen_context = AutogenContext(
+        autogen_context = AutogenContext(
             migration_context, autogenerate=autogenerate
         )
+        self._last_autogen_context: AutogenContext = autogen_context
 
         if autogenerate:
             compare._populate_migration_script(
@@ -514,20 +580,21 @@ class RevisionContext:
         for migration_script in self.generated_revisions:
             migration_script._needs_render = True
 
-    def _default_revision(self):
+    def _default_revision(self) -> "MigrationScript":
+        command_args: Dict[str, Any] = self.command_args
         op = ops.MigrationScript(
-            rev_id=self.command_args["rev_id"] or util.rev_id(),
-            message=self.command_args["message"],
+            rev_id=command_args["rev_id"] or util.rev_id(),
+            message=command_args["message"],
             upgrade_ops=ops.UpgradeOps([]),
             downgrade_ops=ops.DowngradeOps([]),
-            head=self.command_args["head"],
-            splice=self.command_args["splice"],
-            branch_label=self.command_args["branch_label"],
-            version_path=self.command_args["version_path"],
-            depends_on=self.command_args["depends_on"],
+            head=command_args["head"],
+            splice=command_args["splice"],
+            branch_label=command_args["branch_label"],
+            version_path=command_args["version_path"],
+            depends_on=command_args["depends_on"],
         )
         return op
 
-    def generate_scripts(self):
+    def generate_scripts(self) -> Iterator[Optional["Script"]]:
         for generated_revision in self.generated_revisions:
             yield self._to_script(generated_revision)
index dbb0706c464049acee7216e25207a7f7e81cd437..528b17ac4d88e85500b76c6ff35cd7d0db77fd58 100644 (file)
@@ -1,6 +1,16 @@
 import contextlib
 import logging
 import re
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy import event
 from sqlalchemy import inspect
@@ -14,10 +24,29 @@ from .. import util
 from ..operations import ops
 from ..util import sqla_compat
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from sqlalchemy.engine.reflection import Inspector
+    from sqlalchemy.sql.elements import quoted_name
+    from sqlalchemy.sql.schema import Column
+    from sqlalchemy.sql.schema import ForeignKeyConstraint
+    from sqlalchemy.sql.schema import Index
+    from sqlalchemy.sql.schema import Table
+    from sqlalchemy.sql.schema import UniqueConstraint
+
+    from alembic.autogenerate.api import AutogenContext
+    from alembic.operations.ops import AlterColumnOp
+    from alembic.operations.ops import MigrationScript
+    from alembic.operations.ops import ModifyTableOps
+    from alembic.operations.ops import UpgradeOps
+
 log = logging.getLogger(__name__)
 
 
-def _populate_migration_script(autogen_context, migration_script):
+def _populate_migration_script(
+    autogen_context: "AutogenContext", migration_script: "MigrationScript"
+) -> None:
     upgrade_ops = migration_script.upgrade_ops_list[-1]
     downgrade_ops = migration_script.downgrade_ops_list[-1]
 
@@ -28,14 +57,18 @@ def _populate_migration_script(autogen_context, migration_script):
 comparators = util.Dispatcher(uselist=True)
 
 
-def _produce_net_changes(autogen_context, upgrade_ops):
+def _produce_net_changes(
+    autogen_context: "AutogenContext", upgrade_ops: "UpgradeOps"
+) -> None:
 
     connection = autogen_context.connection
+    assert connection is not None
     include_schemas = autogen_context.opts.get("include_schemas", False)
 
-    inspector = inspect(connection)
+    inspector: "Inspector" = inspect(connection)
 
     default_schema = connection.dialect.default_schema_name
+    schemas: Set[Optional[str]]
     if include_schemas:
         schemas = set(inspector.get_schema_names())
         # replace default schema name with None
@@ -44,22 +77,27 @@ def _produce_net_changes(autogen_context, upgrade_ops):
         schemas.discard(default_schema)
         schemas.add(None)
     else:
-        schemas = [None]
+        schemas = {None}
 
     schemas = {
         s for s in schemas if autogen_context.run_name_filters(s, "schema", {})
     }
 
+    assert autogen_context.dialect is not None
     comparators.dispatch("schema", autogen_context.dialect.name)(
         autogen_context, upgrade_ops, schemas
     )
 
 
 @comparators.dispatch_for("schema")
-def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
+def _autogen_for_tables(
+    autogen_context: "AutogenContext",
+    upgrade_ops: "UpgradeOps",
+    schemas: Union[Set[None], Set[Optional[str]]],
+) -> None:
     inspector = autogen_context.inspector
 
-    conn_table_names = set()
+    conn_table_names: Set[Tuple[Optional[str], str]] = set()
 
     version_table_schema = (
         autogen_context.migration_context.version_table_schema
@@ -95,12 +133,12 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
 
 
 def _compare_tables(
-    conn_table_names,
-    metadata_table_names,
-    inspector,
-    upgrade_ops,
-    autogen_context,
-):
+    conn_table_names: "set",
+    metadata_table_names: "set",
+    inspector: "Inspector",
+    upgrade_ops: "UpgradeOps",
+    autogen_context: "AutogenContext",
+) -> None:
 
     default_schema = inspector.bind.dialect.default_schema_name
 
@@ -239,7 +277,7 @@ def _compare_tables(
                 upgrade_ops.ops.append(modify_table_ops)
 
 
-def _make_index(params, conn_table):
+def _make_index(params: Dict[str, Any], conn_table: "Table") -> "Index":
     ix = sa_schema.Index(
         params["name"],
         *[conn_table.c[cname] for cname in params["column_names"]],
@@ -251,7 +289,9 @@ def _make_index(params, conn_table):
     return ix
 
 
-def _make_unique_constraint(params, conn_table):
+def _make_unique_constraint(
+    params: Dict[str, Any], conn_table: "Table"
+) -> "UniqueConstraint":
     uq = sa_schema.UniqueConstraint(
         *[conn_table.c[cname] for cname in params["column_names"]],
         name=params["name"]
@@ -262,7 +302,9 @@ def _make_unique_constraint(params, conn_table):
     return uq
 
 
-def _make_foreign_key(params, conn_table):
+def _make_foreign_key(
+    params: Dict[str, Any], conn_table: "Table"
+) -> "ForeignKeyConstraint":
     tname = params["referred_table"]
     if params["referred_schema"]:
         tname = "%s.%s" % (params["referred_schema"], tname)
@@ -285,14 +327,14 @@ def _make_foreign_key(params, conn_table):
 
 @contextlib.contextmanager
 def _compare_columns(
-    schema,
-    tname,
-    conn_table,
-    metadata_table,
-    modify_table_ops,
-    autogen_context,
-    inspector,
-):
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    conn_table: "Table",
+    metadata_table: "Table",
+    modify_table_ops: "ModifyTableOps",
+    autogen_context: "AutogenContext",
+    inspector: "Inspector",
+) -> Iterator[None]:
     name = "%s.%s" % (schema, tname) if schema else tname
     metadata_col_names = OrderedSet(
         c.name for c in metadata_table.c if not c.system
@@ -357,7 +399,9 @@ def _compare_columns(
 
 
 class _constraint_sig:
-    def md_name_to_sql_name(self, context):
+    const: Union["UniqueConstraint", "ForeignKeyConstraint", "Index"]
+
+    def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]:
         return sqla_compat._get_constraint_final_name(
             self.const, context.dialect
         )
@@ -368,7 +412,7 @@ class _constraint_sig:
     def __ne__(self, other):
         return self.const != other.const
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(self.const)
 
 
@@ -376,37 +420,39 @@ class _uq_constraint_sig(_constraint_sig):
     is_index = False
     is_unique = True
 
-    def __init__(self, const):
+    def __init__(self, const: "UniqueConstraint") -> None:
         self.const = const
         self.name = const.name
         self.sig = tuple(sorted([col.name for col in const.columns]))
 
     @property
-    def column_names(self):
+    def column_names(self) -> List[str]:
         return [col.name for col in self.const.columns]
 
 
 class _ix_constraint_sig(_constraint_sig):
     is_index = True
 
-    def __init__(self, const):
+    def __init__(self, const: "Index") -> None:
         self.const = const
         self.name = const.name
         self.sig = tuple(sorted([col.name for col in const.columns]))
         self.is_unique = bool(const.unique)
 
-    def md_name_to_sql_name(self, context):
+    def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]:
         return sqla_compat._get_constraint_final_name(
             self.const, context.dialect
         )
 
     @property
-    def column_names(self):
+    def column_names(self) -> Union[List["quoted_name"], List[None]]:
         return sqla_compat._get_index_column_names(self.const)
 
 
 class _fk_constraint_sig(_constraint_sig):
-    def __init__(self, const, include_options=False):
+    def __init__(
+        self, const: "ForeignKeyConstraint", include_options: bool = False
+    ) -> None:
         self.const = const
         self.name = const.name
 
@@ -423,7 +469,7 @@ class _fk_constraint_sig(_constraint_sig):
             initially,
         ) = _fk_spec(const)
 
-        self.sig = (
+        self.sig: Tuple[Any, ...] = (
             self.source_schema,
             self.source_table,
             tuple(self.source_columns),
@@ -450,8 +496,13 @@ class _fk_constraint_sig(_constraint_sig):
 
 @comparators.dispatch_for("table")
 def _compare_indexes_and_uniques(
-    autogen_context, modify_ops, schema, tname, conn_table, metadata_table
-):
+    autogen_context: "AutogenContext",
+    modify_ops: "ModifyTableOps",
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    conn_table: Optional["Table"],
+    metadata_table: Optional["Table"],
+) -> None:
 
     inspector = autogen_context.inspector
     is_create_table = conn_table is None
@@ -469,7 +520,7 @@ def _compare_indexes_and_uniques(
         metadata_unique_constraints = set()
         metadata_indexes = set()
 
-    conn_uniques = conn_indexes = frozenset()
+    conn_uniques = conn_indexes = frozenset()  # type:ignore[var-annotated]
 
     supports_unique_constraints = False
 
@@ -479,7 +530,7 @@ def _compare_indexes_and_uniques(
         # 1b. ... and from connection, if the table exists
         if hasattr(inspector, "get_unique_constraints"):
             try:
-                conn_uniques = inspector.get_unique_constraints(
+                conn_uniques = inspector.get_unique_constraints(  # type:ignore[assignment] # noqa
                     tname, schema=schema
                 )
                 supports_unique_constraints = True
@@ -491,7 +542,7 @@ def _compare_indexes_and_uniques(
                 # not being present
                 pass
             else:
-                conn_uniques = [
+                conn_uniques = [  # type:ignore[assignment]
                     uq
                     for uq in conn_uniques
                     if autogen_context.run_name_filters(
@@ -504,11 +555,13 @@ def _compare_indexes_and_uniques(
                     if uq.get("duplicates_index"):
                         unique_constraints_duplicate_unique_indexes = True
         try:
-            conn_indexes = inspector.get_indexes(tname, schema=schema)
+            conn_indexes = inspector.get_indexes(  # type:ignore[assignment]
+                tname, schema=schema
+            )
         except NotImplementedError:
             pass
         else:
-            conn_indexes = [
+            conn_indexes = [  # type:ignore[assignment]
                 ix
                 for ix in conn_indexes
                 if autogen_context.run_name_filters(
@@ -522,14 +575,16 @@ def _compare_indexes_and_uniques(
         # into schema objects
         if is_drop_table:
             # for DROP TABLE uniques are inline, don't need them
-            conn_uniques = set()
+            conn_uniques = set()  # type:ignore[assignment]
         else:
-            conn_uniques = set(
+            conn_uniques = set(  # type:ignore[assignment]
                 _make_unique_constraint(uq_def, conn_table)
                 for uq_def in conn_uniques
             )
 
-        conn_indexes = set(_make_index(ix, conn_table) for ix in conn_indexes)
+        conn_indexes = set(  # type:ignore[assignment]
+            _make_index(ix, conn_table) for ix in conn_indexes
+        )
 
     # 2a. if the dialect dupes unique indexes as unique constraints
     # (mysql and oracle), correct for that
@@ -557,31 +612,39 @@ def _compare_indexes_and_uniques(
     # _constraint_sig() objects provide a consistent facade over both
     # Index and UniqueConstraint so we can easily work with them
     # interchangeably
-    metadata_unique_constraints = set(
+    metadata_unique_constraints_sig = set(
         _uq_constraint_sig(uq) for uq in metadata_unique_constraints
     )
 
-    metadata_indexes = set(_ix_constraint_sig(ix) for ix in metadata_indexes)
+    metadata_indexes_sig = set(
+        _ix_constraint_sig(ix) for ix in metadata_indexes
+    )
 
     conn_unique_constraints = set(
         _uq_constraint_sig(uq) for uq in conn_uniques
     )
 
-    conn_indexes = set(_ix_constraint_sig(ix) for ix in conn_indexes)
+    conn_indexes_sig = set(_ix_constraint_sig(ix) for ix in conn_indexes)
 
     # 5. index things by name, for those objects that have names
     metadata_names = dict(
-        (c.md_name_to_sql_name(autogen_context), c)
-        for c in metadata_unique_constraints.union(metadata_indexes)
+        (cast(str, c.md_name_to_sql_name(autogen_context)), c)
+        for c in metadata_unique_constraints_sig.union(
+            metadata_indexes_sig  # type:ignore[arg-type]
+        )
         if isinstance(c, _ix_constraint_sig)
         or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
     )
 
     conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints)
-    conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
+    conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = dict(
+        (c.name, c) for c in conn_indexes_sig
+    )
     conn_names = dict(
         (c.name, c)
-        for c in conn_unique_constraints.union(conn_indexes)
+        for c in conn_unique_constraints.union(
+            conn_indexes_sig  # type:ignore[arg-type]
+        )
         if c.name is not None
     )
 
@@ -596,12 +659,12 @@ def _compare_indexes_and_uniques(
     # constraints.
     conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints)
     metadata_uniques_by_sig = dict(
-        (uq.sig, uq) for uq in metadata_unique_constraints
+        (uq.sig, uq) for uq in metadata_unique_constraints_sig
     )
-    metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes)
+    metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes_sig)
     unnamed_metadata_uniques = dict(
         (uq.sig, uq)
-        for uq in metadata_unique_constraints
+        for uq in metadata_unique_constraints_sig
         if not sqla_compat._constraint_is_named(
             uq.const, autogen_context.dialect
         )
@@ -709,7 +772,9 @@ def _compare_indexes_and_uniques(
                 )
 
     for removed_name in sorted(set(conn_names).difference(metadata_names)):
-        conn_obj = conn_names[removed_name]
+        conn_obj: Union[_ix_constraint_sig, _uq_constraint_sig] = conn_names[
+            removed_name
+        ]
         if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
             continue
         elif removed_name in doubled_constraints:
@@ -831,14 +896,14 @@ def _correct_for_uq_duplicates_uix(
 
 @comparators.dispatch_for("column")
 def _compare_nullable(
-    autogen_context,
-    alter_column_op,
-    schema,
-    tname,
-    cname,
-    conn_col,
-    metadata_col,
-):
+    autogen_context: "AutogenContext",
+    alter_column_op: "AlterColumnOp",
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    cname: Union["quoted_name", str],
+    conn_col: "Column",
+    metadata_col: "Column",
+) -> None:
 
     metadata_col_nullable = metadata_col.nullable
     conn_col_nullable = conn_col.nullable
@@ -873,14 +938,14 @@ def _compare_nullable(
 
 @comparators.dispatch_for("column")
 def _setup_autoincrement(
-    autogen_context,
-    alter_column_op,
-    schema,
-    tname,
-    cname,
-    conn_col,
-    metadata_col,
-):
+    autogen_context: "AutogenContext",
+    alter_column_op: "AlterColumnOp",
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    cname: "quoted_name",
+    conn_col: "Column",
+    metadata_col: "Column",
+) -> None:
 
     if metadata_col.table._autoincrement_column is metadata_col:
         alter_column_op.kw["autoincrement"] = True
@@ -892,14 +957,14 @@ def _setup_autoincrement(
 
 @comparators.dispatch_for("column")
 def _compare_type(
-    autogen_context,
-    alter_column_op,
-    schema,
-    tname,
-    cname,
-    conn_col,
-    metadata_col,
-):
+    autogen_context: "AutogenContext",
+    alter_column_op: "AlterColumnOp",
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    cname: Union["quoted_name", str],
+    conn_col: "Column",
+    metadata_col: "Column",
+) -> None:
 
     conn_type = conn_col.type
     alter_column_op.existing_type = conn_type
@@ -935,8 +1000,10 @@ def _compare_type(
 
 
 def _render_server_default_for_compare(
-    metadata_default, metadata_col, autogen_context
-):
+    metadata_default: Optional[Any],
+    metadata_col: "Column",
+    autogen_context: "AutogenContext",
+) -> Optional[str]:
     rendered = _user_defined_render(
         "server_default", metadata_default, autogen_context
     )
@@ -963,7 +1030,7 @@ def _render_server_default_for_compare(
         return None
 
 
-def _normalize_computed_default(sqltext):
+def _normalize_computed_default(sqltext: str) -> str:
     """we want to warn if a computed sql expression has changed.  however
     we don't want false positives and the warning is not that critical.
     so filter out most forms of variability from the SQL text.
@@ -974,16 +1041,16 @@ def _normalize_computed_default(sqltext):
 
 
 def _compare_computed_default(
-    autogen_context,
-    alter_column_op,
-    schema,
-    tname,
-    cname,
-    conn_col,
-    metadata_col,
-):
+    autogen_context: "AutogenContext",
+    alter_column_op: "AlterColumnOp",
+    schema: Optional[str],
+    tname: "str",
+    cname: "str",
+    conn_col: "Column",
+    metadata_col: "Column",
+) -> None:
     rendered_metadata_default = str(
-        metadata_col.server_default.sqltext.compile(
+        cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile(
             dialect=autogen_context.dialect,
             compile_kwargs={"literal_binds": True},
         )
@@ -1017,7 +1084,7 @@ def _compare_computed_default(
         _warn_computed_not_supported(tname, cname)
 
 
-def _warn_computed_not_supported(tname, cname):
+def _warn_computed_not_supported(tname: str, cname: str) -> None:
     util.warn("Computed default on %s.%s cannot be modified" % (tname, cname))
 
 
@@ -1040,14 +1107,14 @@ def _compare_identity_default(
 
 @comparators.dispatch_for("column")
 def _compare_server_default(
-    autogen_context,
-    alter_column_op,
-    schema,
-    tname,
-    cname,
-    conn_col,
-    metadata_col,
-):
+    autogen_context: "AutogenContext",
+    alter_column_op: "AlterColumnOp",
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    cname: Union["quoted_name", str],
+    conn_col: "Column",
+    metadata_col: "Column",
+) -> Optional[bool]:
 
     metadata_default = metadata_col.server_default
     conn_col_default = conn_col.server_default
@@ -1065,14 +1132,16 @@ def _compare_server_default(
             return False
 
         else:
-            return _compare_computed_default(
-                autogen_context,
-                alter_column_op,
-                schema,
-                tname,
-                cname,
-                conn_col,
-                metadata_col,
+            return (
+                _compare_computed_default(  # type:ignore[func-returns-value]
+                    autogen_context,
+                    alter_column_op,
+                    schema,
+                    tname,
+                    cname,
+                    conn_col,
+                    metadata_col,
+                )
             )
     if sqla_compat._server_default_is_computed(conn_col_default):
         _warn_computed_not_supported(tname, cname)
@@ -1107,7 +1176,7 @@ def _compare_server_default(
         )
 
         rendered_conn_default = (
-            conn_col_default.arg.text if conn_col_default else None
+            cast(Any, conn_col_default).arg.text if conn_col_default else None
         )
 
         alter_column_op.existing_server_default = conn_col_default
@@ -1122,20 +1191,23 @@ def _compare_server_default(
             alter_column_op.modify_server_default = metadata_default
             log.info("Detected server default on column '%s.%s'", tname, cname)
 
+    return None
+
 
 @comparators.dispatch_for("column")
 def _compare_column_comment(
-    autogen_context,
-    alter_column_op,
-    schema,
-    tname,
-    cname,
-    conn_col,
-    metadata_col,
-):
-
+    autogen_context: "AutogenContext",
+    alter_column_op: "AlterColumnOp",
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    cname: "quoted_name",
+    conn_col: "Column",
+    metadata_col: "Column",
+) -> Optional["Literal[False]"]:
+
+    assert autogen_context.dialect is not None
     if not autogen_context.dialect.supports_comments:
-        return
+        return None
 
     metadata_comment = metadata_col.comment
     conn_col_comment = conn_col.comment
@@ -1148,16 +1220,18 @@ def _compare_column_comment(
         alter_column_op.modify_comment = metadata_comment
         log.info("Detected column comment '%s.%s'", tname, cname)
 
+    return None
+
 
 @comparators.dispatch_for("table")
 def _compare_foreign_keys(
-    autogen_context,
-    modify_table_ops,
-    schema,
-    tname,
-    conn_table,
-    metadata_table,
-):
+    autogen_context: "AutogenContext",
+    modify_table_ops: "ModifyTableOps",
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    conn_table: Optional["Table"],
+    metadata_table: Optional["Table"],
+) -> None:
 
     # if we're doing CREATE TABLE, all FKs are created
     # inline within the table def
@@ -1181,7 +1255,7 @@ def _compare_foreign_keys(
         )
     ]
 
-    backend_reflects_fk_options = conn_fks and "options" in conn_fks[0]
+    backend_reflects_fk_options = bool(conn_fks and "options" in conn_fks[0])
 
     conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks)
 
@@ -1268,14 +1342,15 @@ def _compare_foreign_keys(
 
 @comparators.dispatch_for("table")
 def _compare_table_comment(
-    autogen_context,
-    modify_table_ops,
-    schema,
-    tname,
-    conn_table,
-    metadata_table,
-):
-
+    autogen_context: "AutogenContext",
+    modify_table_ops: "ModifyTableOps",
+    schema: Optional[str],
+    tname: Union["quoted_name", str],
+    conn_table: Optional["Table"],
+    metadata_table: Optional["Table"],
+) -> None:
+
+    assert autogen_context.dialect is not None
     if not autogen_context.dialect.supports_comments:
         return
 
index 490d65cb6d9d8e91aceb7205273e9b51afe37a3e..90d49e5f373002a620b836618ebde12bb20d6e43 100644 (file)
@@ -1,11 +1,20 @@
 from collections import OrderedDict
 from io import StringIO
 import re
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
 
 from mako.pygen import PythonPrinter
 from sqlalchemy import schema as sa_schema
 from sqlalchemy import sql
 from sqlalchemy import types as sqltypes
+from sqlalchemy.sql.elements import conv
 
 from .. import util
 from ..operations import ops
@@ -13,34 +22,59 @@ from ..util import compat
 from ..util import sqla_compat
 from ..util.compat import string_types
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.elements import quoted_name
+    from sqlalchemy.sql.elements import TextClause
+    from sqlalchemy.sql.schema import CheckConstraint
+    from sqlalchemy.sql.schema import Column
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.schema import DefaultClause
+    from sqlalchemy.sql.schema import FetchedValue
+    from sqlalchemy.sql.schema import ForeignKey
+    from sqlalchemy.sql.schema import ForeignKeyConstraint
+    from sqlalchemy.sql.schema import Index
+    from sqlalchemy.sql.schema import MetaData
+    from sqlalchemy.sql.schema import PrimaryKeyConstraint
+    from sqlalchemy.sql.schema import UniqueConstraint
+    from sqlalchemy.sql.sqltypes import ARRAY
+    from sqlalchemy.sql.type_api import TypeEngine
+    from sqlalchemy.sql.type_api import Variant
+
+    from alembic.autogenerate.api import AutogenContext
+    from alembic.config import Config
+    from alembic.operations.ops import MigrationScript
+    from alembic.operations.ops import ModifyTableOps
+    from alembic.util.sqla_compat import Computed
+    from alembic.util.sqla_compat import Identity
 
-MAX_PYTHON_ARGS = 255
-
-try:
-    from sqlalchemy.sql.naming import conv
-
-    def _render_gen_name(autogen_context, name):
-        if isinstance(name, conv):
-            return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
-        else:
-            return name
 
+MAX_PYTHON_ARGS = 255
 
-except ImportError:
 
-    def _render_gen_name(autogen_context, name):
+def _render_gen_name(
+    autogen_context: "AutogenContext",
+    name: Optional[Union["quoted_name", str]],
+) -> Optional[Union["quoted_name", str, "_f_name"]]:
+    if isinstance(name, conv):
+        return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
+    else:
         return name
 
 
-def _indent(text):
+def _indent(text: str) -> str:
     text = re.compile(r"^", re.M).sub("    ", text).strip()
     text = re.compile(r" +$", re.M).sub("", text)
     return text
 
 
 def _render_python_into_templatevars(
-    autogen_context, migration_script, template_args
-):
+    autogen_context: "AutogenContext",
+    migration_script: "MigrationScript",
+    template_args: Dict[str, Union[str, "Config"]],
+) -> None:
     imports = autogen_context.imports
 
     for upgrade_ops, downgrade_ops in zip(
@@ -58,7 +92,10 @@ def _render_python_into_templatevars(
 default_renderers = renderers = util.Dispatcher()
 
 
-def _render_cmd_body(op_container, autogen_context):
+def _render_cmd_body(
+    op_container: "ops.OpContainer",
+    autogen_context: "AutogenContext",
+) -> str:
 
     buf = StringIO()
     printer = PythonPrinter(buf)
@@ -70,7 +107,7 @@ def _render_cmd_body(op_container, autogen_context):
     has_lines = False
     for op in op_container.ops:
         lines = render_op(autogen_context, op)
-        has_lines = has_lines or lines
+        has_lines = has_lines or bool(lines)
 
         for line in lines:
             printer.writeline(line)
@@ -83,18 +120,24 @@ def _render_cmd_body(op_container, autogen_context):
     return buf.getvalue()
 
 
-def render_op(autogen_context, op):
+def render_op(
+    autogen_context: "AutogenContext", op: "ops.MigrateOperation"
+) -> List[str]:
     renderer = renderers.dispatch(op)
     lines = util.to_list(renderer(autogen_context, op))
     return lines
 
 
-def render_op_text(autogen_context, op):
+def render_op_text(
+    autogen_context: "AutogenContext", op: "ops.MigrateOperation"
+) -> str:
     return "\n".join(render_op(autogen_context, op))
 
 
 @renderers.dispatch_for(ops.ModifyTableOps)
-def _render_modify_table(autogen_context, op):
+def _render_modify_table(
+    autogen_context: "AutogenContext", op: "ModifyTableOps"
+) -> List[str]:
     opts = autogen_context.opts
     render_as_batch = opts.get("render_as_batch", False)
 
@@ -121,7 +164,9 @@ def _render_modify_table(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.CreateTableCommentOp)
-def _render_create_table_comment(autogen_context, op):
+def _render_create_table_comment(
+    autogen_context: "AutogenContext", op: "ops.CreateTableCommentOp"
+) -> str:
 
     templ = (
         "{prefix}create_table_comment(\n"
@@ -144,7 +189,9 @@ def _render_create_table_comment(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.DropTableCommentOp)
-def _render_drop_table_comment(autogen_context, op):
+def _render_drop_table_comment(
+    autogen_context: "AutogenContext", op: "ops.DropTableCommentOp"
+) -> str:
 
     templ = (
         "{prefix}drop_table_comment(\n"
@@ -165,7 +212,9 @@ def _render_drop_table_comment(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.CreateTableOp)
-def _add_table(autogen_context, op):
+def _add_table(
+    autogen_context: "AutogenContext", op: "ops.CreateTableOp"
+) -> str:
     table = op.to_table()
 
     args = [
@@ -188,14 +237,14 @@ def _add_table(autogen_context, op):
     )
 
     if len(args) > MAX_PYTHON_ARGS:
-        args = "*[" + ",\n".join(args) + "]"
+        args_str = "*[" + ",\n".join(args) + "]"
     else:
-        args = ",\n".join(args)
+        args_str = ",\n".join(args)
 
     text = "%(prefix)screate_table(%(tablename)r,\n%(args)s" % {
         "tablename": _ident(op.table_name),
         "prefix": _alembic_autogenerate_prefix(autogen_context),
-        "args": args,
+        "args": args_str,
     }
     if op.schema:
         text += ",\nschema=%r" % _ident(op.schema)
@@ -215,7 +264,9 @@ def _add_table(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.DropTableOp)
-def _drop_table(autogen_context, op):
+def _drop_table(
+    autogen_context: "AutogenContext", op: "ops.DropTableOp"
+) -> str:
     text = "%(prefix)sdrop_table(%(tname)r" % {
         "prefix": _alembic_autogenerate_prefix(autogen_context),
         "tname": _ident(op.table_name),
@@ -227,7 +278,9 @@ def _drop_table(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.CreateIndexOp)
-def _add_index(autogen_context, op):
+def _add_index(
+    autogen_context: "AutogenContext", op: "ops.CreateIndexOp"
+) -> str:
     index = op.to_index()
 
     has_batch = autogen_context._has_batch
@@ -243,6 +296,7 @@ def _add_index(autogen_context, op):
             "unique=%(unique)r%(schema)s%(kwargs)s)"
         )
 
+    assert index.table is not None
     text = tmpl % {
         "prefix": _alembic_autogenerate_prefix(autogen_context),
         "name": _render_gen_name(autogen_context, index.name),
@@ -271,7 +325,9 @@ def _add_index(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.DropIndexOp)
-def _drop_index(autogen_context, op):
+def _drop_index(
+    autogen_context: "AutogenContext", op: "ops.DropIndexOp"
+) -> str:
     index = op.to_index()
 
     has_batch = autogen_context._has_batch
@@ -306,12 +362,16 @@ def _drop_index(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.CreateUniqueConstraintOp)
-def _add_unique_constraint(autogen_context, op):
+def _add_unique_constraint(
+    autogen_context: "AutogenContext", op: "ops.CreateUniqueConstraintOp"
+) -> List[str]:
     return [_uq_constraint(op.to_constraint(), autogen_context, True)]
 
 
 @renderers.dispatch_for(ops.CreateForeignKeyOp)
-def _add_fk_constraint(autogen_context, op):
+def _add_fk_constraint(
+    autogen_context: "AutogenContext", op: "ops.CreateForeignKeyOp"
+) -> str:
 
     args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
     if not autogen_context._has_batch:
@@ -358,7 +418,9 @@ def _add_check_constraint(constraint, autogen_context):
 
 
 @renderers.dispatch_for(ops.DropConstraintOp)
-def _drop_constraint(autogen_context, op):
+def _drop_constraint(
+    autogen_context: "AutogenContext", op: "ops.DropConstraintOp"
+) -> str:
 
     if autogen_context._has_batch:
         template = "%(prefix)sdrop_constraint" "(%(name)r, type_=%(type)r)"
@@ -379,7 +441,9 @@ def _drop_constraint(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.AddColumnOp)
-def _add_column(autogen_context, op):
+def _add_column(
+    autogen_context: "AutogenContext", op: "ops.AddColumnOp"
+) -> str:
 
     schema, tname, column = op.schema, op.table_name, op.column
     if autogen_context._has_batch:
@@ -399,7 +463,9 @@ def _add_column(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.DropColumnOp)
-def _drop_column(autogen_context, op):
+def _drop_column(
+    autogen_context: "AutogenContext", op: "ops.DropColumnOp"
+) -> str:
 
     schema, tname, column_name = op.schema, op.table_name, op.column_name
 
@@ -421,7 +487,9 @@ def _drop_column(autogen_context, op):
 
 
 @renderers.dispatch_for(ops.AlterColumnOp)
-def _alter_column(autogen_context, op):
+def _alter_column(
+    autogen_context: "AutogenContext", op: "ops.AlterColumnOp"
+) -> str:
 
     tname = op.table_name
     cname = op.column_name
@@ -481,15 +549,15 @@ def _alter_column(autogen_context, op):
 
 
 class _f_name:
-    def __init__(self, prefix, name):
+    def __init__(self, prefix: str, name: conv) -> None:
         self.prefix = prefix
         self.name = name
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%sf(%r)" % (self.prefix, _ident(self.name))
 
 
-def _ident(name):
+def _ident(name: Optional[Union["quoted_name", str]]) -> Optional[str]:
     """produce a __repr__() object for a string identifier that may
     use quoted_name() in SQLAlchemy 0.9 and greater.
 
@@ -506,8 +574,11 @@ def _ident(name):
 
 
 def _render_potential_expr(
-    value, autogen_context, wrap_in_text=True, is_server_default=False
-):
+    value: Any,
+    autogen_context: "AutogenContext",
+    wrap_in_text: bool = True,
+    is_server_default: bool = False,
+) -> str:
     if isinstance(value, sql.ClauseElement):
 
         if wrap_in_text:
@@ -526,7 +597,9 @@ def _render_potential_expr(
         return repr(value)
 
 
-def _get_index_rendered_expressions(idx, autogen_context):
+def _get_index_rendered_expressions(
+    idx: "Index", autogen_context: "AutogenContext"
+) -> List[str]:
     return [
         repr(_ident(getattr(exp, "name", None)))
         if isinstance(exp, sa_schema.Column)
@@ -535,8 +608,12 @@ def _get_index_rendered_expressions(idx, autogen_context):
     ]
 
 
-def _uq_constraint(constraint, autogen_context, alter):
-    opts = []
+def _uq_constraint(
+    constraint: "UniqueConstraint",
+    autogen_context: "AutogenContext",
+    alter: bool,
+) -> str:
+    opts: List[Tuple[str, Any]] = []
 
     has_batch = autogen_context._has_batch
 
@@ -578,18 +655,20 @@ def _user_autogenerate_prefix(autogen_context, target):
         return prefix
 
 
-def _sqlalchemy_autogenerate_prefix(autogen_context):
+def _sqlalchemy_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
     return autogen_context.opts["sqlalchemy_module_prefix"] or ""
 
 
-def _alembic_autogenerate_prefix(autogen_context):
+def _alembic_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
     if autogen_context._has_batch:
         return "batch_op."
     else:
         return autogen_context.opts["alembic_module_prefix"] or ""
 
 
-def _user_defined_render(type_, object_, autogen_context):
+def _user_defined_render(
+    type_: str, object_: Any, autogen_context: "AutogenContext"
+) -> Union[str, "Literal[False]"]:
     if "render_item" in autogen_context.opts:
         render = autogen_context.opts["render_item"]
         if render:
@@ -599,17 +678,17 @@ def _user_defined_render(type_, object_, autogen_context):
     return False
 
 
-def _render_column(column, autogen_context):
+def _render_column(column: "Column", autogen_context: "AutogenContext") -> str:
     rendered = _user_defined_render("column", column, autogen_context)
     if rendered is not False:
         return rendered
 
-    args = []
-    opts = []
+    args: List[str] = []
+    opts: List[Tuple[str, Any]] = []
 
     if column.server_default:
 
-        rendered = _render_server_default(
+        rendered = _render_server_default(  # type:ignore[assignment]
             column.server_default, autogen_context
         )
         if rendered:
@@ -655,21 +734,29 @@ def _render_column(column, autogen_context):
     }
 
 
-def _should_render_server_default_positionally(server_default):
+def _should_render_server_default_positionally(
+    server_default: Union["Computed", "DefaultClause"]
+) -> bool:
     return sqla_compat._server_default_is_computed(
         server_default
     ) or sqla_compat._server_default_is_identity(server_default)
 
 
-def _render_server_default(default, autogen_context, repr_=True):
+def _render_server_default(
+    default: Optional[
+        Union["FetchedValue", str, "TextClause", "ColumnElement"]
+    ],
+    autogen_context: "AutogenContext",
+    repr_: bool = True,
+) -> Optional[str]:
     rendered = _user_defined_render("server_default", default, autogen_context)
     if rendered is not False:
         return rendered
 
     if sqla_compat._server_default_is_computed(default):
-        return _render_computed(default, autogen_context)
+        return _render_computed(cast("Computed", default), autogen_context)
     elif sqla_compat._server_default_is_identity(default):
-        return _render_identity(default, autogen_context)
+        return _render_identity(cast("Identity", default), autogen_context)
     elif isinstance(default, sa_schema.DefaultClause):
         if isinstance(default.arg, compat.string_types):
             default = default.arg
@@ -681,10 +768,12 @@ def _render_server_default(default, autogen_context, repr_=True):
     if isinstance(default, string_types) and repr_:
         default = repr(re.sub(r"^'|'$", "", default))
 
-    return default
+    return cast(str, default)
 
 
-def _render_computed(computed, autogen_context):
+def _render_computed(
+    computed: "Computed", autogen_context: "AutogenContext"
+) -> str:
     text = _render_potential_expr(
         computed.sqltext, autogen_context, wrap_in_text=False
     )
@@ -699,7 +788,9 @@ def _render_computed(computed, autogen_context):
     }
 
 
-def _render_identity(identity, autogen_context):
+def _render_identity(
+    identity: "Identity", autogen_context: "AutogenContext"
+) -> str:
     # always=None means something different than always=False
     kwargs = OrderedDict(always=identity.always)
     if identity.on_null is not None:
@@ -712,7 +803,7 @@ def _render_identity(identity, autogen_context):
     }
 
 
-def _get_identity_options(identity_options):
+def _get_identity_options(identity_options: "Identity") -> OrderedDict:
     kwargs = OrderedDict()
     for attr in sqla_compat._identity_options_attrs:
         value = getattr(identity_options, attr, None)
@@ -721,7 +812,7 @@ def _get_identity_options(identity_options):
     return kwargs
 
 
-def _repr_type(type_, autogen_context):
+def _repr_type(type_: "TypeEngine", autogen_context: "AutogenContext") -> str:
     rendered = _user_defined_render("type", type_, autogen_context)
     if rendered is not False:
         return rendered
@@ -736,7 +827,9 @@ def _repr_type(type_, autogen_context):
     mod = type(type_).__module__
     imports = autogen_context.imports
     if mod.startswith("sqlalchemy.dialects"):
-        dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
+        match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod)
+        assert match is not None
+        dname = match.group(1)
         if imports is not None:
             imports.add("from sqlalchemy.dialects import %s" % dname)
         if impl_rt:
@@ -759,14 +852,22 @@ def _repr_type(type_, autogen_context):
         return "%s%r" % (prefix, type_)
 
 
-def _render_ARRAY_type(type_, autogen_context):
-    return _render_type_w_subtype(
-        type_, autogen_context, "item_type", r"(.+?\()"
+def _render_ARRAY_type(
+    type_: "ARRAY", autogen_context: "AutogenContext"
+) -> str:
+    return cast(
+        str,
+        _render_type_w_subtype(
+            type_, autogen_context, "item_type", r"(.+?\()"
+        ),
     )
 
 
-def _render_Variant_type(type_, autogen_context):
+def _render_Variant_type(
+    type_: "Variant", autogen_context: "AutogenContext"
+) -> str:
     base = _repr_type(type_.impl, autogen_context)
+    assert base is not None and base is not False
     for dialect in sorted(type_.mapping):
         typ = type_.mapping[dialect]
         base += ".with_variant(%s, %r)" % (
@@ -777,8 +878,12 @@ def _render_Variant_type(type_, autogen_context):
 
 
 def _render_type_w_subtype(
-    type_, autogen_context, attrname, regexp, prefix=None
-):
+    type_: "TypeEngine",
+    autogen_context: "AutogenContext",
+    attrname: str,
+    regexp: str,
+    prefix: Optional[str] = None,
+) -> Union[Optional[str], "Literal[False]"]:
     outer_repr = repr(type_)
     inner_type = getattr(type_, attrname, None)
     if inner_type is None:
@@ -795,7 +900,9 @@ def _render_type_w_subtype(
 
     mod = type(type_).__module__
     if mod.startswith("sqlalchemy.dialects"):
-        dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
+        match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod)
+        assert match is not None
+        dname = match.group(1)
         return "%s.%s" % (dname, outer_type)
     elif mod.startswith("sqlalchemy"):
         prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
@@ -807,7 +914,11 @@ def _render_type_w_subtype(
 _constraint_renderers = util.Dispatcher()
 
 
-def _render_constraint(constraint, autogen_context, namespace_metadata):
+def _render_constraint(
+    constraint: "Constraint",
+    autogen_context: "AutogenContext",
+    namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
     try:
         renderer = _constraint_renderers.dispatch(constraint)
     except ValueError:
@@ -818,7 +929,11 @@ def _render_constraint(constraint, autogen_context, namespace_metadata):
 
 
 @_constraint_renderers.dispatch_for(sa_schema.PrimaryKeyConstraint)
-def _render_primary_key(constraint, autogen_context, namespace_metadata):
+def _render_primary_key(
+    constraint: "PrimaryKeyConstraint",
+    autogen_context: "AutogenContext",
+    namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
     rendered = _user_defined_render("primary_key", constraint, autogen_context)
     if rendered is not False:
         return rendered
@@ -840,12 +955,16 @@ def _render_primary_key(constraint, autogen_context, namespace_metadata):
     }
 
 
-def _fk_colspec(fk, metadata_schema, namespace_metadata):
+def _fk_colspec(
+    fk: "ForeignKey",
+    metadata_schema: Optional[str],
+    namespace_metadata: "MetaData",
+) -> str:
     """Implement a 'safe' version of ForeignKey._get_colspec() that
     won't fail if the remote table can't be resolved.
 
     """
-    colspec = fk._get_colspec()
+    colspec = fk._get_colspec()  # type:ignore[attr-defined]
     tokens = colspec.split(".")
     tname, colname = tokens[-2:]
 
@@ -873,7 +992,9 @@ def _fk_colspec(fk, metadata_schema, namespace_metadata):
     return colspec
 
 
-def _populate_render_fk_opts(constraint, opts):
+def _populate_render_fk_opts(
+    constraint: "ForeignKeyConstraint", opts: List[Tuple[str, str]]
+) -> None:
 
     if constraint.onupdate:
         opts.append(("onupdate", repr(constraint.onupdate)))
@@ -888,7 +1009,11 @@ def _populate_render_fk_opts(constraint, opts):
 
 
 @_constraint_renderers.dispatch_for(sa_schema.ForeignKeyConstraint)
-def _render_foreign_key(constraint, autogen_context, namespace_metadata):
+def _render_foreign_key(
+    constraint: "ForeignKeyConstraint",
+    autogen_context: "AutogenContext",
+    namespace_metadata: "MetaData",
+) -> Optional[str]:
     rendered = _user_defined_render("foreign_key", constraint, autogen_context)
     if rendered is not False:
         return rendered
@@ -908,7 +1033,8 @@ def _render_foreign_key(constraint, autogen_context, namespace_metadata):
         % {
             "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
             "cols": ", ".join(
-                "%r" % _ident(f.parent.name) for f in constraint.elements
+                "%r" % _ident(cast("Column", f.parent).name)
+                for f in constraint.elements
             ),
             "refcols": ", ".join(
                 repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
@@ -922,7 +1048,11 @@ def _render_foreign_key(constraint, autogen_context, namespace_metadata):
 
 
 @_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint)
-def _render_unique_constraint(constraint, autogen_context, namespace_metadata):
+def _render_unique_constraint(
+    constraint: "UniqueConstraint",
+    autogen_context: "AutogenContext",
+    namespace_metadata: Optional["MetaData"],
+) -> str:
     rendered = _user_defined_render("unique", constraint, autogen_context)
     if rendered is not False:
         return rendered
@@ -931,7 +1061,11 @@ def _render_unique_constraint(constraint, autogen_context, namespace_metadata):
 
 
 @_constraint_renderers.dispatch_for(sa_schema.CheckConstraint)
-def _render_check_constraint(constraint, autogen_context, namespace_metadata):
+def _render_check_constraint(
+    constraint: "CheckConstraint",
+    autogen_context: "AutogenContext",
+    namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
     rendered = _user_defined_render("check", constraint, autogen_context)
     if rendered is not False:
         return rendered
@@ -941,9 +1075,14 @@ def _render_check_constraint(constraint, autogen_context, namespace_metadata):
     # ideally SQLAlchemy would give us more of a first class
     # way to detect this.
     if (
-        constraint._create_rule
-        and hasattr(constraint._create_rule, "target")
-        and isinstance(constraint._create_rule.target, sqltypes.TypeEngine)
+        constraint._create_rule  # type:ignore[attr-defined]
+        and hasattr(
+            constraint._create_rule, "target"  # type:ignore[attr-defined]
+        )
+        and isinstance(
+            constraint._create_rule.target,  # type:ignore[attr-defined]
+            sqltypes.TypeEngine,
+        )
     ):
         return None
     opts = []
@@ -963,7 +1102,9 @@ def _render_check_constraint(constraint, autogen_context, namespace_metadata):
 
 
 @renderers.dispatch_for(ops.ExecuteSQLOp)
-def _execute_sql(autogen_context, op):
+def _execute_sql(
+    autogen_context: "AutogenContext", op: "ops.ExecuteSQLOp"
+) -> str:
     if not isinstance(op.sqltext, string_types):
         raise NotImplementedError(
             "Autogenerate rendering of SQL Expression language constructs "
index ba9a06dd54829fa01419175b0be00db19ca8b424..0fdd3982776783154ceca60a735e9199105a3665 100644 (file)
@@ -1,6 +1,25 @@
+from typing import Any
+from typing import Callable
+from typing import Iterator
+from typing import List
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
+
 from alembic import util
 from alembic.operations import ops
 
+if TYPE_CHECKING:
+    from alembic.operations.ops import AddColumnOp
+    from alembic.operations.ops import AlterColumnOp
+    from alembic.operations.ops import CreateTableOp
+    from alembic.operations.ops import MigrateOperation
+    from alembic.operations.ops import MigrationScript
+    from alembic.operations.ops import ModifyTableOps
+    from alembic.operations.ops import OpContainer
+    from alembic.runtime.migration import MigrationContext
+    from alembic.script.revision import Revision
+
 
 class Rewriter:
     """A helper object that allows easy 'rewriting' of ops streams.
@@ -32,10 +51,10 @@ class Rewriter:
 
     _chained = None
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.dispatch = util.Dispatcher()
 
-    def chain(self, other):
+    def chain(self, other: "Rewriter") -> "Rewriter":
         """Produce a "chain" of this :class:`.Rewriter` to another.
 
         This allows two rewriters to operate serially on a stream,
@@ -70,7 +89,16 @@ class Rewriter:
         wr._chained = other
         return wr
 
-    def rewrites(self, operator):
+    def rewrites(
+        self,
+        operator: Union[
+            Type["AddColumnOp"],
+            Type["MigrateOperation"],
+            Type["AlterColumnOp"],
+            Type["CreateTableOp"],
+            Type["ModifyTableOps"],
+        ],
+    ) -> Callable:
         """Register a function as rewriter for a given type.
 
         The function should receive three arguments, which are
@@ -85,7 +113,12 @@ class Rewriter:
         """
         return self.dispatch.dispatch_for(operator)
 
-    def _rewrite(self, context, revision, directive):
+    def _rewrite(
+        self,
+        context: "MigrationContext",
+        revision: "Revision",
+        directive: "MigrateOperation",
+    ) -> Iterator["MigrateOperation"]:
         try:
             _rewriter = self.dispatch.dispatch(directive)
         except ValueError:
@@ -96,20 +129,30 @@ class Rewriter:
                 yield directive
             else:
                 for r_directive in util.to_list(
-                    _rewriter(context, revision, directive)
+                    _rewriter(context, revision, directive), []
                 ):
                     r_directive._mutations = r_directive._mutations.union(
                         [self]
                     )
                     yield r_directive
 
-    def __call__(self, context, revision, directives):
+    def __call__(
+        self,
+        context: "MigrationContext",
+        revision: "Revision",
+        directives: List["MigrationScript"],
+    ) -> None:
         self.process_revision_directives(context, revision, directives)
         if self._chained:
             self._chained(context, revision, directives)
 
     @_traverse.dispatch_for(ops.MigrationScript)
-    def _traverse_script(self, context, revision, directive):
+    def _traverse_script(
+        self,
+        context: "MigrationContext",
+        revision: "Revision",
+        directive: "MigrationScript",
+    ) -> None:
         upgrade_ops_list = []
         for upgrade_ops in directive.upgrade_ops_list:
             ret = self._traverse_for(context, revision, upgrade_ops)
@@ -131,26 +174,51 @@ class Rewriter:
         directive.downgrade_ops = downgrade_ops_list
 
     @_traverse.dispatch_for(ops.OpContainer)
-    def _traverse_op_container(self, context, revision, directive):
+    def _traverse_op_container(
+        self,
+        context: "MigrationContext",
+        revision: "Revision",
+        directive: "OpContainer",
+    ) -> None:
         self._traverse_list(context, revision, directive.ops)
 
     @_traverse.dispatch_for(ops.MigrateOperation)
-    def _traverse_any_directive(self, context, revision, directive):
+    def _traverse_any_directive(
+        self,
+        context: "MigrationContext",
+        revision: "Revision",
+        directive: "MigrateOperation",
+    ) -> None:
         pass
 
-    def _traverse_for(self, context, revision, directive):
+    def _traverse_for(
+        self,
+        context: "MigrationContext",
+        revision: "Revision",
+        directive: "MigrateOperation",
+    ) -> Any:
         directives = list(self._rewrite(context, revision, directive))
         for directive in directives:
             traverser = self._traverse.dispatch(directive)
             traverser(self, context, revision, directive)
         return directives
 
-    def _traverse_list(self, context, revision, directives):
+    def _traverse_list(
+        self,
+        context: "MigrationContext",
+        revision: "Revision",
+        directives: Any,
+    ) -> None:
         dest = []
         for directive in directives:
             dest.extend(self._traverse_for(context, revision, directive))
 
         directives[:] = dest
 
-    def process_revision_directives(self, context, revision, directives):
+    def process_revision_directives(
+        self,
+        context: "MigrationContext",
+        revision: "Revision",
+        directives: List["MigrationScript"],
+    ) -> None:
         self._traverse_list(context, revision, directives)
index ada458d3b00d34d4fb5f38fc26c41478968c7c3e..1e794602b4c7711bb80f7b3af1ce22b0ba9cc8c0 100644 (file)
@@ -1,10 +1,20 @@
 import os
+from typing import Callable
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from . import autogenerate as autogen
 from . import util
 from .runtime.environment import EnvironmentContext
 from .script import ScriptDirectory
 
+if TYPE_CHECKING:
+    from alembic.config import Config
+    from alembic.script.base import Script
+
 
 def list_templates(config):
     """List available templates.
@@ -25,7 +35,12 @@ def list_templates(config):
     config.print_stdout("\n  alembic init --template generic ./scripts")
 
 
-def init(config, directory, template="generic", package=False):
+def init(
+    config: "Config",
+    directory: str,
+    template: str = "generic",
+    package: bool = False,
+) -> None:
     """Initialize a new scripts directory.
 
     :param config: a :class:`.Config` object.
@@ -71,8 +86,8 @@ def init(config, directory, template="generic", package=False):
     for file_ in os.listdir(template_dir):
         file_path = os.path.join(template_dir, file_)
         if file_ == "alembic.ini.mako":
-            config_file = os.path.abspath(config.config_file_name)
-            if os.access(config_file, os.F_OK):
+            config_file = os.path.abspath(cast(str, config.config_file_name))
+            if os.access(cast(str, config_file), os.F_OK):
                 util.msg("File %s already exists, skipping" % config_file)
             else:
                 script._generate_template(
@@ -88,7 +103,7 @@ def init(config, directory, template="generic", package=False):
             os.path.join(os.path.abspath(versions), "__init__.py"),
         ]:
             file_ = util.status("Adding %s" % path, open, path, "w")
-            file_.close()
+            file_.close()  # type:ignore[attr-defined]
 
     util.msg(
         "Please edit configuration/connection/logging "
@@ -97,18 +112,18 @@ def init(config, directory, template="generic", package=False):
 
 
 def revision(
-    config,
-    message=None,
-    autogenerate=False,
-    sql=False,
-    head="head",
-    splice=False,
-    branch_label=None,
-    version_path=None,
-    rev_id=None,
-    depends_on=None,
-    process_revision_directives=None,
-):
+    config: "Config",
+    message: Optional[str] = None,
+    autogenerate: bool = False,
+    sql: bool = False,
+    head: str = "head",
+    splice: bool = False,
+    branch_label: Optional[str] = None,
+    version_path: Optional[str] = None,
+    rev_id: Optional[str] = None,
+    depends_on: Optional[str] = None,
+    process_revision_directives: Optional[Callable] = None,
+) -> Union[Optional["Script"], List[Optional["Script"]]]:
     """Create a new revision file.
 
     :param config: a :class:`.Config` object.
@@ -223,7 +238,13 @@ def revision(
         return scripts
 
 
-def merge(config, revisions, message=None, branch_label=None, rev_id=None):
+def merge(
+    config: "Config",
+    revisions: str,
+    message: Optional[str] = None,
+    branch_label: Optional[str] = None,
+    rev_id: Optional[str] = None,
+) -> Optional["Script"]:
     """Merge two revisions together.  Creates a new migration file.
 
     :param config: a :class:`.Config` instance
@@ -243,7 +264,7 @@ def merge(config, revisions, message=None, branch_label=None, rev_id=None):
 
     script = ScriptDirectory.from_config(config)
     template_args = {
-        "config": config  # Let templates use config for
+        "config": "config"  # Let templates use config for
         # e.g. multiple databases
     }
     return script.generate_revision(
@@ -252,11 +273,16 @@ def merge(config, revisions, message=None, branch_label=None, rev_id=None):
         refresh=True,
         head=revisions,
         branch_labels=branch_label,
-        **template_args
+        **template_args  # type:ignore[arg-type]
     )
 
 
-def upgrade(config, revision, sql=False, tag=None):
+def upgrade(
+    config: "Config",
+    revision: str,
+    sql: bool = False,
+    tag: Optional[str] = None,
+) -> None:
     """Upgrade to a later version.
 
     :param config: a :class:`.Config` instance.
@@ -294,7 +320,12 @@ def upgrade(config, revision, sql=False, tag=None):
         script.run_env()
 
 
-def downgrade(config, revision, sql=False, tag=None):
+def downgrade(
+    config: "Config",
+    revision: str,
+    sql: bool = False,
+    tag: Optional[str] = None,
+) -> None:
     """Revert to a previous version.
 
     :param config: a :class:`.Config` instance.
@@ -360,7 +391,12 @@ def show(config, rev):
             config.print_stdout(sc.log_entry)
 
 
-def history(config, rev_range=None, verbose=False, indicate_current=False):
+def history(
+    config: "Config",
+    rev_range: Optional[str] = None,
+    verbose: bool = False,
+    indicate_current: bool = False,
+) -> None:
     """List changeset scripts in chronological order.
 
     :param config: a :class:`.Config` instance.
@@ -372,7 +408,8 @@ def history(config, rev_range=None, verbose=False, indicate_current=False):
     :param indicate_current: indicate current revision.
 
     """
-
+    base: Optional[str]
+    head: Optional[str]
     script = ScriptDirectory.from_config(config)
     if rev_range is not None:
         if ":" not in rev_range:
@@ -478,7 +515,7 @@ def branches(config, verbose=False):
             )
 
 
-def current(config, verbose=False):
+def current(config: "Config", verbose: bool = False) -> None:
     """Display the current revision for a database.
 
     :param config: a :class:`.Config` instance.
@@ -506,7 +543,13 @@ def current(config, verbose=False):
         script.run_env()
 
 
-def stamp(config, revision, sql=False, tag=None, purge=False):
+def stamp(
+    config: "Config",
+    revision: str,
+    sql: bool = False,
+    tag: Optional[str] = None,
+    purge: bool = False,
+) -> None:
     """'stamp' the revision table with the given revision; don't
     run any migrations.
 
@@ -570,7 +613,7 @@ def stamp(config, revision, sql=False, tag=None, purge=False):
         script.run_env()
 
 
-def edit(config, rev):
+def edit(config: "Config", rev: str) -> None:
     """Edit revision script(s) using $EDITOR.
 
     :param config: a :class:`.Config` instance.
index b8b465d1283bf897c8766f1947482de0fd5f6929..dbcd106ff9f11ec636f49f2d41cce0de8ca511c6 100644 (file)
@@ -1,8 +1,13 @@
 from argparse import ArgumentParser
+from argparse import Namespace
 from configparser import ConfigParser
 import inspect
 import os
 import sys
+from typing import Dict
+from typing import Optional
+from typing import overload
+from typing import TextIO
 
 from . import __version__
 from . import command
@@ -86,14 +91,14 @@ class Config:
 
     def __init__(
         self,
-        file_=None,
-        ini_section="alembic",
-        output_buffer=None,
-        stdout=sys.stdout,
-        cmd_opts=None,
-        config_args=util.immutabledict(),
-        attributes=None,
-    ):
+        file_: Optional[str] = None,
+        ini_section: str = "alembic",
+        output_buffer: Optional[TextIO] = None,
+        stdout: TextIO = sys.stdout,
+        cmd_opts: Optional[Namespace] = None,
+        config_args: util.immutabledict = util.immutabledict(),
+        attributes: Optional[dict] = None,
+    ) -> None:
         """Construct a new :class:`.Config`"""
         self.config_file_name = file_
         self.config_ini_section = ini_section
@@ -104,7 +109,7 @@ class Config:
         if attributes:
             self.attributes.update(attributes)
 
-    cmd_opts = None
+    cmd_opts: Optional[Namespace] = None
     """The command-line options passed to the ``alembic`` script.
 
     Within an ``env.py`` script this can be accessed via the
@@ -116,10 +121,10 @@ class Config:
 
     """
 
-    config_file_name = None
+    config_file_name: Optional[str] = None
     """Filesystem path to the .ini file in use."""
 
-    config_ini_section = None
+    config_ini_section: str = None  # type:ignore[assignment]
     """Name of the config file section to read basic configuration
     from.  Defaults to ``alembic``, that is the ``[alembic]`` section
     of the .ini file.  This value is modified using the ``-n/--name``
@@ -147,7 +152,7 @@ class Config:
         """
         return {}
 
-    def print_stdout(self, text, *arg):
+    def print_stdout(self, text: str, *arg) -> None:
         """Render a message to standard out.
 
         When :meth:`.Config.print_stdout` is called with additional args
@@ -191,7 +196,7 @@ class Config:
             file_config.add_section(self.config_ini_section)
         return file_config
 
-    def get_template_directory(self):
+    def get_template_directory(self) -> str:
         """Return the directory where Alembic setup templates are found.
 
         This method is used by the alembic ``init`` and ``list_templates``
@@ -203,7 +208,19 @@ class Config:
         package_dir = os.path.abspath(os.path.dirname(alembic.__file__))
         return os.path.join(package_dir, "templates")
 
-    def get_section(self, name, default=None):
+    @overload
+    def get_section(
+        self, name: str, default: Dict[str, str]
+    ) -> Dict[str, str]:
+        ...
+
+    @overload
+    def get_section(
+        self, name: str, default: Optional[Dict[str, str]] = ...
+    ) -> Optional[Dict[str, str]]:
+        ...
+
+    def get_section(self, name: str, default=None):
         """Return all the configuration options from a given .ini file section
         as a dictionary.
 
@@ -213,7 +230,7 @@ class Config:
 
         return dict(self.file_config.items(name))
 
-    def set_main_option(self, name, value):
+    def set_main_option(self, name: str, value: str) -> None:
         """Set an option programmatically within the 'main' section.
 
         This overrides whatever was in the .ini file.
@@ -230,10 +247,10 @@ class Config:
         """
         self.set_section_option(self.config_ini_section, name, value)
 
-    def remove_main_option(self, name):
+    def remove_main_option(self, name: str) -> None:
         self.file_config.remove_option(self.config_ini_section, name)
 
-    def set_section_option(self, section, name, value):
+    def set_section_option(self, section: str, name: str, value: str) -> None:
         """Set an option programmatically within the given section.
 
         The section is created if it doesn't exist already.
@@ -257,7 +274,9 @@ class Config:
             self.file_config.add_section(section)
         self.file_config.set(section, name, value)
 
-    def get_section_option(self, section, name, default=None):
+    def get_section_option(
+        self, section: str, name: str, default: Optional[str] = None
+    ) -> Optional[str]:
         """Return an option from the given section of the .ini file."""
         if not self.file_config.has_section(section):
             raise util.CommandError(
@@ -269,6 +288,16 @@ class Config:
         else:
             return default
 
+    @overload
+    def get_main_option(self, name: str, default: str) -> str:
+        ...
+
+    @overload
+    def get_main_option(
+        self, name: str, default: Optional[str] = None
+    ) -> Optional[str]:
+        ...
+
     def get_main_option(self, name, default=None):
         """Return an option from the 'main' section of the .ini file.
 
@@ -281,10 +310,10 @@ class Config:
 
 
 class CommandLine:
-    def __init__(self, prog=None):
+    def __init__(self, prog: Optional[str] = None) -> None:
         self._generate_args(prog)
 
-    def _generate_args(self, prog):
+    def _generate_args(self, prog: Optional[str]) -> None:
         def add_options(fn, parser, positional, kwargs):
             kwargs_opts = {
                 "template": (
@@ -515,7 +544,7 @@ class CommandLine:
                         else:
                             help_text.append(line.strip())
                 else:
-                    help_text = ""
+                    help_text = []
                 subparser = subparsers.add_parser(
                     fn.__name__, help=" ".join(help_text)
                 )
@@ -523,7 +552,7 @@ class CommandLine:
                 subparser.set_defaults(cmd=(fn, positional, kwarg))
         self.parser = parser
 
-    def run_cmd(self, config, options):
+    def run_cmd(self, config: Config, options: Namespace) -> None:
         fn, positional, kwarg = options.cmd
 
         try:
index da81c72206cbab03fbe4061df636b100a3c454ab..022dc244d3294dad6da6f4d75713f1116aada4f0 100644 (file)
@@ -1,4 +1,7 @@
 import functools
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy import exc
 from sqlalchemy import Integer
@@ -14,6 +17,20 @@ from ..util.sqla_compat import _fk_spec  # noqa
 from ..util.sqla_compat import _is_type_bound  # noqa
 from ..util.sqla_compat import _table_for_constraint  # noqa
 
+if TYPE_CHECKING:
+    from sqlalchemy.sql.compiler import Compiled
+    from sqlalchemy.sql.compiler import DDLCompiler
+    from sqlalchemy.sql.elements import TextClause
+    from sqlalchemy.sql.functions import Function
+    from sqlalchemy.sql.schema import FetchedValue
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from .impl import DefaultImpl
+    from ..util.sqla_compat import Computed
+    from ..util.sqla_compat import Identity
+
+_ServerDefault = Union["TextClause", "FetchedValue", "Function", str]
+
 
 class AlterTable(DDLElement):
 
@@ -24,13 +41,22 @@ class AlterTable(DDLElement):
 
     """
 
-    def __init__(self, table_name, schema=None):
+    def __init__(
+        self,
+        table_name: str,
+        schema: Optional[Union["quoted_name", str]] = None,
+    ) -> None:
         self.table_name = table_name
         self.schema = schema
 
 
 class RenameTable(AlterTable):
-    def __init__(self, old_table_name, new_table_name, schema=None):
+    def __init__(
+        self,
+        old_table_name: str,
+        new_table_name: Union["quoted_name", str],
+        schema: Optional[Union["quoted_name", str]] = None,
+    ) -> None:
         super(RenameTable, self).__init__(old_table_name, schema=schema)
         self.new_table_name = new_table_name
 
@@ -38,14 +64,14 @@ class RenameTable(AlterTable):
 class AlterColumn(AlterTable):
     def __init__(
         self,
-        name,
-        column_name,
-        schema=None,
-        existing_type=None,
-        existing_nullable=None,
-        existing_server_default=None,
-        existing_comment=None,
-    ):
+        name: str,
+        column_name: str,
+        schema: Optional[str] = None,
+        existing_type: Optional["TypeEngine"] = None,
+        existing_nullable: Optional[bool] = None,
+        existing_server_default: Optional[_ServerDefault] = None,
+        existing_comment: Optional[str] = None,
+    ) -> None:
         super(AlterColumn, self).__init__(name, schema=schema)
         self.column_name = column_name
         self.existing_type = (
@@ -59,62 +85,94 @@ class AlterColumn(AlterTable):
 
 
 class ColumnNullable(AlterColumn):
-    def __init__(self, name, column_name, nullable, **kw):
+    def __init__(
+        self, name: str, column_name: str, nullable: bool, **kw
+    ) -> None:
         super(ColumnNullable, self).__init__(name, column_name, **kw)
         self.nullable = nullable
 
 
 class ColumnType(AlterColumn):
-    def __init__(self, name, column_name, type_, **kw):
+    def __init__(
+        self, name: str, column_name: str, type_: "TypeEngine", **kw
+    ) -> None:
         super(ColumnType, self).__init__(name, column_name, **kw)
         self.type_ = sqltypes.to_instance(type_)
 
 
 class ColumnName(AlterColumn):
-    def __init__(self, name, column_name, newname, **kw):
+    def __init__(
+        self, name: str, column_name: str, newname: str, **kw
+    ) -> None:
         super(ColumnName, self).__init__(name, column_name, **kw)
         self.newname = newname
 
 
 class ColumnDefault(AlterColumn):
-    def __init__(self, name, column_name, default, **kw):
+    def __init__(
+        self,
+        name: str,
+        column_name: str,
+        default: Optional[_ServerDefault],
+        **kw
+    ) -> None:
         super(ColumnDefault, self).__init__(name, column_name, **kw)
         self.default = default
 
 
 class ComputedColumnDefault(AlterColumn):
-    def __init__(self, name, column_name, default, **kw):
+    def __init__(
+        self, name: str, column_name: str, default: Optional["Computed"], **kw
+    ) -> None:
         super(ComputedColumnDefault, self).__init__(name, column_name, **kw)
         self.default = default
 
 
 class IdentityColumnDefault(AlterColumn):
-    def __init__(self, name, column_name, default, impl, **kw):
+    def __init__(
+        self,
+        name: str,
+        column_name: str,
+        default: Optional["Identity"],
+        impl: "DefaultImpl",
+        **kw
+    ) -> None:
         super(IdentityColumnDefault, self).__init__(name, column_name, **kw)
         self.default = default
         self.impl = impl
 
 
 class AddColumn(AlterTable):
-    def __init__(self, name, column, schema=None):
+    def __init__(
+        self,
+        name: str,
+        column: "Column",
+        schema: Optional[Union["quoted_name", str]] = None,
+    ) -> None:
         super(AddColumn, self).__init__(name, schema=schema)
         self.column = column
 
 
 class DropColumn(AlterTable):
-    def __init__(self, name, column, schema=None):
+    def __init__(
+        self, name: str, column: "Column", schema: Optional[str] = None
+    ) -> None:
         super(DropColumn, self).__init__(name, schema=schema)
         self.column = column
 
 
 class ColumnComment(AlterColumn):
-    def __init__(self, name, column_name, comment, **kw):
+    def __init__(
+        self, name: str, column_name: str, comment: Optional[str], **kw
+    ) -> None:
         super(ColumnComment, self).__init__(name, column_name, **kw)
         self.comment = comment
 
 
 @compiles(RenameTable)
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+    element: "RenameTable", compiler: "DDLCompiler", **kw
+) -> str:
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_table_name(compiler, element.new_table_name, element.schema),
@@ -122,7 +180,9 @@ def visit_rename_table(element, compiler, **kw):
 
 
 @compiles(AddColumn)
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+    element: "AddColumn", compiler: "DDLCompiler", **kw
+) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         add_column(compiler, element.column, **kw),
@@ -130,7 +190,9 @@ def visit_add_column(element, compiler, **kw):
 
 
 @compiles(DropColumn)
-def visit_drop_column(element, compiler, **kw):
+def visit_drop_column(
+    element: "DropColumn", compiler: "DDLCompiler", **kw
+) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         drop_column(compiler, element.column.name, **kw),
@@ -138,7 +200,9 @@ def visit_drop_column(element, compiler, **kw):
 
 
 @compiles(ColumnNullable)
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+    element: "ColumnNullable", compiler: "DDLCompiler", **kw
+) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -147,7 +211,9 @@ def visit_column_nullable(element, compiler, **kw):
 
 
 @compiles(ColumnType)
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+    element: "ColumnType", compiler: "DDLCompiler", **kw
+) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -156,7 +222,9 @@ def visit_column_type(element, compiler, **kw):
 
 
 @compiles(ColumnName)
-def visit_column_name(element, compiler, **kw):
+def visit_column_name(
+    element: "ColumnName", compiler: "DDLCompiler", **kw
+) -> str:
     return "%s RENAME %s TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
@@ -165,7 +233,9 @@ def visit_column_name(element, compiler, **kw):
 
 
 @compiles(ColumnDefault)
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+    element: "ColumnDefault", compiler: "DDLCompiler", **kw
+) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -176,7 +246,9 @@ def visit_column_default(element, compiler, **kw):
 
 
 @compiles(ComputedColumnDefault)
-def visit_computed_column(element, compiler, **kw):
+def visit_computed_column(
+    element: "ComputedColumnDefault", compiler: "DDLCompiler", **kw
+):
     raise exc.CompileError(
         'Adding or removing a "computed" construct, e.g. GENERATED '
         "ALWAYS AS, to or from an existing column is not supported."
@@ -184,7 +256,9 @@ def visit_computed_column(element, compiler, **kw):
 
 
 @compiles(IdentityColumnDefault)
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+    element: "IdentityColumnDefault", compiler: "DDLCompiler", **kw
+):
     raise exc.CompileError(
         'Adding, removing or modifying an "identity" construct, '
         "e.g. GENERATED AS IDENTITY, to or from an existing "
@@ -192,7 +266,9 @@ def visit_identity_column(element, compiler, **kw):
     )
 
 
-def quote_dotted(name, quote):
+def quote_dotted(
+    name: Union["quoted_name", str], quote: functools.partial
+) -> Union["quoted_name", str]:
     """quote the elements of a dotted name"""
 
     if isinstance(name, quoted_name):
@@ -201,7 +277,11 @@ def quote_dotted(name, quote):
     return result
 
 
-def format_table_name(compiler, name, schema):
+def format_table_name(
+    compiler: "Compiled",
+    name: Union["quoted_name", str],
+    schema: Optional[Union["quoted_name", str]],
+) -> Union["quoted_name", str]:
     quote = functools.partial(compiler.preparer.quote)
     if schema:
         return quote_dotted(schema, quote) + "." + quote(name)
@@ -209,33 +289,42 @@ def format_table_name(compiler, name, schema):
         return quote(name)
 
 
-def format_column_name(compiler, name):
+def format_column_name(
+    compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
+) -> Union["quoted_name", str]:
     return compiler.preparer.quote(name)
 
 
-def format_server_default(compiler, default):
+def format_server_default(
+    compiler: "DDLCompiler",
+    default: Optional[_ServerDefault],
+) -> str:
     return compiler.get_column_default_string(
         Column("x", Integer, server_default=default)
     )
 
 
-def format_type(compiler, type_):
+def format_type(compiler: "DDLCompiler", type_: "TypeEngine") -> str:
     return compiler.dialect.type_compiler.process(type_)
 
 
-def alter_table(compiler, name, schema):
+def alter_table(
+    compiler: "DDLCompiler",
+    name: str,
+    schema: Optional[str],
+) -> str:
     return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
 
 
-def drop_column(compiler, name):
+def drop_column(compiler: "DDLCompiler", name: str, **kw) -> str:
     return "DROP COLUMN %s" % format_column_name(compiler, name)
 
 
-def alter_column(compiler, name):
+def alter_column(compiler: "DDLCompiler", name: str) -> str:
     return "ALTER COLUMN %s" % format_column_name(compiler, name)
 
 
-def add_column(compiler, column, **kw):
+def add_column(compiler: "DDLCompiler", column: "Column", **kw) -> str:
     text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
 
     const = " ".join(
index 710509c2fe0d39d5dffa5b1cba623f1291b3b19f..2ca316c7f47226293369486b10d654a5e34dea64 100644 (file)
@@ -1,5 +1,16 @@
 from collections import namedtuple
 import re
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy import cast
 from sqlalchemy import schema
@@ -11,16 +22,49 @@ from ..util import sqla_compat
 from ..util.compat import string_types
 from ..util.compat import text_type
 
+if TYPE_CHECKING:
+    from io import StringIO
+    from typing import Literal
+
+    from sqlalchemy.engine import Connection
+    from sqlalchemy.engine import Dialect
+    from sqlalchemy.engine.cursor import CursorResult
+    from sqlalchemy.engine.cursor import LegacyCursorResult
+    from sqlalchemy.engine.reflection import Inspector
+    from sqlalchemy.sql.dml import Update
+    from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.elements import quoted_name
+    from sqlalchemy.sql.elements import TextClause
+    from sqlalchemy.sql.schema import Column
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.schema import ForeignKeyConstraint
+    from sqlalchemy.sql.schema import Index
+    from sqlalchemy.sql.schema import Table
+    from sqlalchemy.sql.schema import UniqueConstraint
+    from sqlalchemy.sql.selectable import TableClause
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from .base import _ServerDefault
+    from ..autogenerate.api import AutogenContext
+    from ..operations.batch import ApplyBatchImpl
+    from ..operations.batch import BatchOperationsImpl
+
 
 class ImplMeta(type):
-    def __init__(cls, classname, bases, dict_):
+    def __init__(
+        cls,
+        classname: str,
+        bases: Tuple[Type["DefaultImpl"]],
+        dict_: Dict[str, Any],
+    ):
         newtype = type.__init__(cls, classname, bases, dict_)
         if "__dialect__" in dict_:
             _impls[dict_["__dialect__"]] = cls
         return newtype
 
 
-_impls = {}
+_impls: dict = {}
 
 Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
 
@@ -43,27 +87,27 @@ class DefaultImpl(metaclass=ImplMeta):
 
     transactional_ddl = False
     command_terminator = ";"
-    type_synonyms = ({"NUMERIC", "DECIMAL"},)
-    type_arg_extract = ()
+    type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
+    type_arg_extract: Sequence[str] = ()
     # on_null is known to be supported only by oracle
-    identity_attrs_ignore = ("on_null",)
+    identity_attrs_ignore: Tuple[str, ...] = ("on_null",)
 
     def __init__(
         self,
-        dialect,
-        connection,
-        as_sql,
-        transactional_ddl,
-        output_buffer,
-        context_opts,
-    ):
+        dialect: "Dialect",
+        connection: Optional["Connection"],
+        as_sql: bool,
+        transactional_ddl: Optional[bool],
+        output_buffer: Optional["StringIO"],
+        context_opts: Dict[str, Any],
+    ) -> None:
         self.dialect = dialect
         self.connection = connection
         self.as_sql = as_sql
         self.literal_binds = context_opts.get("literal_binds", False)
 
         self.output_buffer = output_buffer
-        self.memo = {}
+        self.memo: dict = {}
         self.context_opts = context_opts
         if transactional_ddl is not None:
             self.transactional_ddl = transactional_ddl
@@ -75,14 +119,17 @@ class DefaultImpl(metaclass=ImplMeta):
                 )
 
     @classmethod
-    def get_by_dialect(cls, dialect):
+    def get_by_dialect(cls, dialect: "Dialect") -> Any:
         return _impls[dialect.name]
 
-    def static_output(self, text):
+    def static_output(self, text: str) -> None:
+        assert self.output_buffer is not None
         self.output_buffer.write(text_type(text + "\n\n"))
         self.output_buffer.flush()
 
-    def requires_recreate_in_batch(self, batch_op):
+    def requires_recreate_in_batch(
+        self, batch_op: "BatchOperationsImpl"
+    ) -> bool:
         """Return True if the given :class:`.BatchOperationsImpl`
         would need the table to be recreated and copied in order to
         proceed.
@@ -93,7 +140,9 @@ class DefaultImpl(metaclass=ImplMeta):
         """
         return False
 
-    def prep_table_for_batch(self, batch_impl, table):
+    def prep_table_for_batch(
+        self, batch_impl: "ApplyBatchImpl", table: "Table"
+    ) -> None:
         """perform any operations needed on a table before a new
         one is created to replace it in batch mode.
 
@@ -103,16 +152,16 @@ class DefaultImpl(metaclass=ImplMeta):
         """
 
     @property
-    def bind(self):
+    def bind(self) -> Optional["Connection"]:
         return self.connection
 
     def _exec(
         self,
-        construct,
-        execution_options=None,
-        multiparams=(),
-        params=util.immutabledict(),
-    ):
+        construct: Union["ClauseElement", str],
+        execution_options: None = None,
+        multiparams: Sequence[dict] = (),
+        params: Dict[str, int] = util.immutabledict(),
+    ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
         if isinstance(construct, string_types):
             construct = text(construct)
         if self.as_sql:
@@ -135,35 +184,43 @@ class DefaultImpl(metaclass=ImplMeta):
                 .strip()
                 + self.command_terminator
             )
+            return None
         else:
             conn = self.connection
+            assert conn is not None
             if execution_options:
                 conn = conn.execution_options(**execution_options)
             if params:
+                assert isinstance(multiparams, tuple)
                 multiparams += (params,)
 
             return conn.execute(construct, multiparams)
 
-    def execute(self, sql, execution_options=None):
+    def execute(
+        self,
+        sql: Union["Update", "TextClause", str],
+        execution_options: None = None,
+    ) -> None:
         self._exec(sql, execution_options)
 
     def alter_column(
         self,
-        table_name,
-        column_name,
-        nullable=None,
-        server_default=False,
-        name=None,
-        type_=None,
-        schema=None,
-        autoincrement=None,
-        comment=False,
-        existing_comment=None,
-        existing_type=None,
-        existing_server_default=None,
-        existing_nullable=None,
-        existing_autoincrement=None,
-    ):
+        table_name: str,
+        column_name: str,
+        nullable: Optional[bool] = None,
+        server_default: Union["_ServerDefault", "Literal[False]"] = False,
+        name: Optional[str] = None,
+        type_: Optional["TypeEngine"] = None,
+        schema: Optional[str] = None,
+        autoincrement: Optional[bool] = None,
+        comment: Optional[Union[str, "Literal[False]"]] = False,
+        existing_comment: Optional[str] = None,
+        existing_type: Optional["TypeEngine"] = None,
+        existing_server_default: Optional["_ServerDefault"] = None,
+        existing_nullable: Optional[bool] = None,
+        existing_autoincrement: Optional[bool] = None,
+        **kw: Any
+    ) -> None:
         if autoincrement is not None or existing_autoincrement is not None:
             util.warn(
                 "autoincrement and existing_autoincrement "
@@ -185,6 +242,13 @@ class DefaultImpl(metaclass=ImplMeta):
             )
         if server_default is not False:
             kw = {}
+            cls_: Type[
+                Union[
+                    base.ComputedColumnDefault,
+                    base.IdentityColumnDefault,
+                    base.ColumnDefault,
+                ]
+            ]
             if sqla_compat._server_default_is_computed(
                 server_default, existing_server_default
             ):
@@ -200,7 +264,7 @@ class DefaultImpl(metaclass=ImplMeta):
                 cls_(
                     table_name,
                     column_name,
-                    server_default,
+                    server_default,  # type:ignore[arg-type]
                     schema=schema,
                     existing_type=existing_type,
                     existing_server_default=existing_server_default,
@@ -251,25 +315,41 @@ class DefaultImpl(metaclass=ImplMeta):
                 )
             )
 
-    def add_column(self, table_name, column, schema=None):
+    def add_column(
+        self,
+        table_name: str,
+        column: "Column",
+        schema: Optional[Union[str, "quoted_name"]] = None,
+    ) -> None:
         self._exec(base.AddColumn(table_name, column, schema=schema))
 
-    def drop_column(self, table_name, column, schema=None, **kw):
+    def drop_column(
+        self,
+        table_name: str,
+        column: "Column",
+        schema: Optional[str] = None,
+        **kw
+    ) -> None:
         self._exec(base.DropColumn(table_name, column, schema=schema))
 
-    def add_constraint(self, const):
+    def add_constraint(self, const: Any) -> None:
         if const._create_rule is None or const._create_rule(self):
             self._exec(schema.AddConstraint(const))
 
-    def drop_constraint(self, const):
+    def drop_constraint(self, const: "Constraint") -> None:
         self._exec(schema.DropConstraint(const))
 
-    def rename_table(self, old_table_name, new_table_name, schema=None):
+    def rename_table(
+        self,
+        old_table_name: str,
+        new_table_name: Union[str, "quoted_name"],
+        schema: Optional[Union[str, "quoted_name"]] = None,
+    ) -> None:
         self._exec(
             base.RenameTable(old_table_name, new_table_name, schema=schema)
         )
 
-    def create_table(self, table):
+    def create_table(self, table: "Table") -> None:
         table.dispatch.before_create(
             table, self.connection, checkfirst=False, _ddl_runner=self
         )
@@ -292,25 +372,30 @@ class DefaultImpl(metaclass=ImplMeta):
             if comment and with_comment:
                 self.create_column_comment(column)
 
-    def drop_table(self, table):
+    def drop_table(self, table: "Table") -> None:
         self._exec(schema.DropTable(table))
 
-    def create_index(self, index):
+    def create_index(self, index: "Index") -> None:
         self._exec(schema.CreateIndex(index))
 
-    def create_table_comment(self, table):
+    def create_table_comment(self, table: "Table") -> None:
         self._exec(schema.SetTableComment(table))
 
-    def drop_table_comment(self, table):
+    def drop_table_comment(self, table: "Table") -> None:
         self._exec(schema.DropTableComment(table))
 
-    def create_column_comment(self, column):
+    def create_column_comment(self, column: "ColumnElement") -> None:
         self._exec(schema.SetColumnComment(column))
 
-    def drop_index(self, index):
+    def drop_index(self, index: "Index") -> None:
         self._exec(schema.DropIndex(index))
 
-    def bulk_insert(self, table, rows, multiinsert=True):
+    def bulk_insert(
+        self,
+        table: Union["TableClause", "Table"],
+        rows: List[dict],
+        multiinsert: bool = True,
+    ) -> None:
         if not isinstance(rows, list):
             raise TypeError("List expected")
         elif rows and not isinstance(rows[0], dict):
@@ -349,7 +434,7 @@ class DefaultImpl(metaclass=ImplMeta):
                             sqla_compat._insert_inline(table).values(**row)
                         )
 
-    def _tokenize_column_type(self, column):
+    def _tokenize_column_type(self, column: "Column") -> Params:
         definition = self.dialect.type_compiler.process(column.type).lower()
 
         # tokenize the SQLAlchemy-generated version of a type, so that
@@ -387,7 +472,9 @@ class DefaultImpl(metaclass=ImplMeta):
 
         return params
 
-    def _column_types_match(self, inspector_params, metadata_params):
+    def _column_types_match(
+        self, inspector_params: "Params", metadata_params: "Params"
+    ) -> bool:
         if inspector_params.token0 == metadata_params.token0:
             return True
 
@@ -407,7 +494,9 @@ class DefaultImpl(metaclass=ImplMeta):
                 return True
         return False
 
-    def _column_args_match(self, inspected_params, meta_params):
+    def _column_args_match(
+        self, inspected_params: "Params", meta_params: "Params"
+    ) -> bool:
         """We want to compare column parameters. However, we only want
         to compare parameters that are set. If they both have `collation`,
         we want to make sure they are the same. However, if only one
@@ -438,7 +527,9 @@ class DefaultImpl(metaclass=ImplMeta):
 
         return True
 
-    def compare_type(self, inspector_column, metadata_column):
+    def compare_type(
+        self, inspector_column: "Column", metadata_column: "Column"
+    ) -> bool:
         """Returns True if there ARE differences between the types of the two
         columns. Takes impl.type_synonyms into account between retrospected
         and metadata types
@@ -463,11 +554,11 @@ class DefaultImpl(metaclass=ImplMeta):
 
     def correct_for_autogen_constraints(
         self,
-        conn_uniques,
-        conn_indexes,
-        metadata_unique_constraints,
-        metadata_indexes,
-    ):
+        conn_uniques: Union[Set["UniqueConstraint"]],
+        conn_indexes: Union[Set["Index"]],
+        metadata_unique_constraints: Set["UniqueConstraint"],
+        metadata_indexes: Set["Index"],
+    ) -> None:
         pass
 
     def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
@@ -476,7 +567,9 @@ class DefaultImpl(metaclass=ImplMeta):
                 existing_transfer["expr"], new_type
             )
 
-    def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+    def render_ddl_sql_expr(
+        self, expr: "ClauseElement", is_server_default: bool = False, **kw
+    ) -> str:
         """Render a SQL expression that is typically a server default,
         index expression, etc.
 
@@ -489,10 +582,16 @@ class DefaultImpl(metaclass=ImplMeta):
         )
         return text_type(expr.compile(dialect=self.dialect, **compile_kw))
 
-    def _compat_autogen_column_reflect(self, inspector):
+    def _compat_autogen_column_reflect(
+        self, inspector: "Inspector"
+    ) -> Callable:
         return self.autogen_column_reflect
 
-    def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
+    def correct_for_autogen_foreignkeys(
+        self,
+        conn_fks: Set["ForeignKeyConstraint"],
+        metadata_fks: Set["ForeignKeyConstraint"],
+    ) -> None:
         pass
 
     def autogen_column_reflect(self, inspector, table, column_info):
@@ -504,7 +603,7 @@ class DefaultImpl(metaclass=ImplMeta):
 
         """
 
-    def start_migrations(self):
+    def start_migrations(self) -> None:
         """A hook called when :meth:`.EnvironmentContext.run_migrations`
         is called.
 
@@ -512,7 +611,7 @@ class DefaultImpl(metaclass=ImplMeta):
 
         """
 
-    def emit_begin(self):
+    def emit_begin(self) -> None:
         """Emit the string ``BEGIN``, or the backend-specific
         equivalent, on the current connection context.
 
@@ -522,7 +621,7 @@ class DefaultImpl(metaclass=ImplMeta):
         """
         self.static_output("BEGIN" + self.command_terminator)
 
-    def emit_commit(self):
+    def emit_commit(self) -> None:
         """Emit the string ``COMMIT``, or the backend-specific
         equivalent, on the current connection context.
 
@@ -532,7 +631,9 @@ class DefaultImpl(metaclass=ImplMeta):
         """
         self.static_output("COMMIT" + self.command_terminator)
 
-    def render_type(self, type_obj, autogen_context):
+    def render_type(
+        self, type_obj: "TypeEngine", autogen_context: "AutogenContext"
+    ) -> Union[str, "Literal[False]"]:
         return False
 
     def _compare_identity_default(self, metadata_identity, inspector_identity):
index 8a99ee6e869b0270fdbb35132e238d49fc7d43fd..9e1ef76eae5cc65612f015b8a0d1914d6a545b6c 100644 (file)
@@ -1,9 +1,15 @@
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
 from sqlalchemy import types as sqltypes
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.schema import Column
 from sqlalchemy.schema import CreateIndex
-from sqlalchemy.sql.expression import ClauseElement
-from sqlalchemy.sql.expression import Executable
+from sqlalchemy.sql.base import Executable
+from sqlalchemy.sql.elements import ClauseElement
 
 from .base import AddColumn
 from .base import alter_column
@@ -21,6 +27,20 @@ from .impl import DefaultImpl
 from .. import util
 from ..util import sqla_compat
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from sqlalchemy.dialects.mssql.base import MSDDLCompiler
+    from sqlalchemy.dialects.mssql.base import MSSQLCompiler
+    from sqlalchemy.engine.cursor import CursorResult
+    from sqlalchemy.engine.cursor import LegacyCursorResult
+    from sqlalchemy.sql.schema import Index
+    from sqlalchemy.sql.schema import Table
+    from sqlalchemy.sql.selectable import TableClause
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from .base import _ServerDefault
+
 
 class MSSQLImpl(DefaultImpl):
     __dialect__ = "mssql"
@@ -40,40 +60,44 @@ class MSSQLImpl(DefaultImpl):
         "order",
     )
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg, **kw) -> None:
         super(MSSQLImpl, self).__init__(*arg, **kw)
         self.batch_separator = self.context_opts.get(
             "mssql_batch_separator", self.batch_separator
         )
 
-    def _exec(self, construct, *args, **kw):
+    def _exec(
+        self, construct: Any, *args, **kw
+    ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
         result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
         return result
 
-    def emit_begin(self):
+    def emit_begin(self) -> None:
         self.static_output("BEGIN TRANSACTION" + self.command_terminator)
 
-    def emit_commit(self):
+    def emit_commit(self) -> None:
         super(MSSQLImpl, self).emit_commit()
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
 
-    def alter_column(
+    def alter_column(  # type:ignore[override]
         self,
-        table_name,
-        column_name,
-        nullable=None,
-        server_default=False,
-        name=None,
-        type_=None,
-        schema=None,
-        existing_type=None,
-        existing_server_default=None,
-        existing_nullable=None,
-        **kw
-    ):
+        table_name: str,
+        column_name: str,
+        nullable: Optional[bool] = None,
+        server_default: Optional[
+            Union["_ServerDefault", "Literal[False]"]
+        ] = False,
+        name: Optional[str] = None,
+        type_: Optional["TypeEngine"] = None,
+        schema: Optional[str] = None,
+        existing_type: Optional["TypeEngine"] = None,
+        existing_server_default: Optional["_ServerDefault"] = None,
+        existing_nullable: Optional[bool] = None,
+        **kw: Any
+    ) -> None:
 
         if nullable is not None:
             if existing_type is None:
@@ -138,17 +162,20 @@ class MSSQLImpl(DefaultImpl):
                 table_name, column_name, schema=schema, name=name
             )
 
-    def create_index(self, index):
+    def create_index(self, index: "Index") -> None:
         # this likely defaults to None if not present, so get()
         # should normally not return the default value.  being
         # defensive in any case
         mssql_include = index.kwargs.get("mssql_include", None) or ()
+        assert index.table is not None
         for col in mssql_include:
             if col not in index.table.c:
                 index.table.append_column(Column(col, sqltypes.NullType))
         self._exec(CreateIndex(index))
 
-    def bulk_insert(self, table, rows, **kw):
+    def bulk_insert(  # type:ignore[override]
+        self, table: Union["TableClause", "Table"], rows: List[dict], **kw: Any
+    ) -> None:
         if self.as_sql:
             self._exec(
                 "SET IDENTITY_INSERT %s ON"
@@ -162,7 +189,13 @@ class MSSQLImpl(DefaultImpl):
         else:
             super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
 
-    def drop_column(self, table_name, column, schema=None, **kw):
+    def drop_column(
+        self,
+        table_name: str,
+        column: "Column",
+        schema: Optional[str] = None,
+        **kw
+    ) -> None:
         drop_default = kw.pop("mssql_drop_default", False)
         if drop_default:
             self._exec(
@@ -222,7 +255,13 @@ class MSSQLImpl(DefaultImpl):
 
 
 class _ExecDropConstraint(Executable, ClauseElement):
-    def __init__(self, tname, colname, type_, schema):
+    def __init__(
+        self,
+        tname: str,
+        colname: Union["Column", str],
+        type_: str,
+        schema: Optional[str],
+    ) -> None:
         self.tname = tname
         self.colname = colname
         self.type_ = type_
@@ -230,14 +269,18 @@ class _ExecDropConstraint(Executable, ClauseElement):
 
 
 class _ExecDropFKConstraint(Executable, ClauseElement):
-    def __init__(self, tname, colname, schema):
+    def __init__(
+        self, tname: str, colname: "Column", schema: Optional[str]
+    ) -> None:
         self.tname = tname
         self.colname = colname
         self.schema = schema
 
 
 @compiles(_ExecDropConstraint, "mssql")
-def _exec_drop_col_constraint(element, compiler, **kw):
+def _exec_drop_col_constraint(
+    element: "_ExecDropConstraint", compiler: "MSSQLCompiler", **kw
+) -> str:
     schema, tname, colname, type_ = (
         element.schema,
         element.tname,
@@ -261,7 +304,9 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
 
 
 @compiles(_ExecDropFKConstraint, "mssql")
-def _exec_drop_col_fk_constraint(element, compiler, **kw):
+def _exec_drop_col_fk_constraint(
+    element: "_ExecDropFKConstraint", compiler: "MSSQLCompiler", **kw
+) -> str:
     schema, tname, colname = element.schema, element.tname, element.colname
 
     return """declare @const_name varchar(256)
@@ -279,19 +324,23 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
 
 
 @compiles(AddColumn, "mssql")
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+    element: "AddColumn", compiler: "MSDDLCompiler", **kw
+) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         mssql_add_column(compiler, element.column, **kw),
     )
 
 
-def mssql_add_column(compiler, column, **kw):
+def mssql_add_column(compiler: "MSDDLCompiler", column: "Column", **kw) -> str:
     return "ADD %s" % compiler.get_column_specification(column, **kw)
 
 
 @compiles(ColumnNullable, "mssql")
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+    element: "ColumnNullable", compiler: "MSDDLCompiler", **kw
+) -> str:
     return "%s %s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -301,7 +350,9 @@ def visit_column_nullable(element, compiler, **kw):
 
 
 @compiles(ColumnDefault, "mssql")
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+    element: "ColumnDefault", compiler: "MSDDLCompiler", **kw
+) -> str:
     # TODO: there can also be a named constraint
     # with ADD CONSTRAINT here
     return "%s ADD DEFAULT %s FOR %s" % (
@@ -312,7 +363,9 @@ def visit_column_default(element, compiler, **kw):
 
 
 @compiles(ColumnName, "mssql")
-def visit_rename_column(element, compiler, **kw):
+def visit_rename_column(
+    element: "ColumnName", compiler: "MSDDLCompiler", **kw
+) -> str:
     return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
         format_table_name(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
@@ -321,7 +374,9 @@ def visit_rename_column(element, compiler, **kw):
 
 
 @compiles(ColumnType, "mssql")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+    element: "ColumnType", compiler: "MSDDLCompiler", **kw
+) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -330,7 +385,9 @@ def visit_column_type(element, compiler, **kw):
 
 
 @compiles(RenameTable, "mssql")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+    element: "RenameTable", compiler: "MSDDLCompiler", **kw
+) -> str:
     return "EXEC sp_rename '%s', %s" % (
         format_table_name(compiler, element.table_name, element.schema),
         format_table_name(compiler, element.new_table_name, None),
index 4761f75edd9e4b5c592fdc0378765d311efb0e7b..94895605d006bd8a6c2c259d06243679b4282739 100644 (file)
@@ -1,4 +1,8 @@
 import re
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy import schema
 from sqlalchemy import types as sqltypes
@@ -19,6 +23,16 @@ from ..util import sqla_compat
 from ..util.sqla_compat import _is_mariadb
 from ..util.sqla_compat import _is_type_bound
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
+    from sqlalchemy.sql.ddl import DropConstraint
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from .base import _ServerDefault
+
 
 class MySQLImpl(DefaultImpl):
     __dialect__ = "mysql"
@@ -27,24 +41,24 @@ class MySQLImpl(DefaultImpl):
     type_synonyms = DefaultImpl.type_synonyms + ({"BOOL", "TINYINT"},)
     type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
 
-    def alter_column(
+    def alter_column(  # type:ignore[override]
         self,
-        table_name,
-        column_name,
-        nullable=None,
-        server_default=False,
-        name=None,
-        type_=None,
-        schema=None,
-        existing_type=None,
-        existing_server_default=None,
-        existing_nullable=None,
-        autoincrement=None,
-        existing_autoincrement=None,
-        comment=False,
-        existing_comment=None,
-        **kw
-    ):
+        table_name: str,
+        column_name: str,
+        nullable: Optional[bool] = None,
+        server_default: Union["_ServerDefault", "Literal[False]"] = False,
+        name: Optional[str] = None,
+        type_: Optional["TypeEngine"] = None,
+        schema: Optional[str] = None,
+        existing_type: Optional["TypeEngine"] = None,
+        existing_server_default: Optional["_ServerDefault"] = None,
+        existing_nullable: Optional[bool] = None,
+        autoincrement: Optional[bool] = None,
+        existing_autoincrement: Optional[bool] = None,
+        comment: Optional[Union[str, "Literal[False]"]] = False,
+        existing_comment: Optional[str] = None,
+        **kw: Any
+    ) -> None:
         if sqla_compat._server_default_is_identity(
             server_default, existing_server_default
         ) or sqla_compat._server_default_is_computed(
@@ -126,16 +140,24 @@ class MySQLImpl(DefaultImpl):
                 )
             )
 
-    def drop_constraint(self, const):
+    def drop_constraint(
+        self,
+        const: "Constraint",
+    ) -> None:
         if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
             return
 
         super(MySQLImpl, self).drop_constraint(const)
 
-    def _is_mysql_allowed_functional_default(self, type_, server_default):
+    def _is_mysql_allowed_functional_default(
+        self,
+        type_: Optional["TypeEngine"],
+        server_default: Union["_ServerDefault", "Literal[False]"],
+    ) -> bool:
         return (
             type_ is not None
-            and type_._type_affinity is sqltypes.DateTime
+            and type_._type_affinity  # type:ignore[attr-defined]
+            is sqltypes.DateTime
             and server_default is not None
         )
 
@@ -268,7 +290,13 @@ class MariaDBImpl(MySQLImpl):
 
 
 class MySQLAlterDefault(AlterColumn):
-    def __init__(self, name, column_name, default, schema=None):
+    def __init__(
+        self,
+        name: str,
+        column_name: str,
+        default: "_ServerDefault",
+        schema: Optional[str] = None,
+    ) -> None:
         super(AlterColumn, self).__init__(name, schema=schema)
         self.column_name = column_name
         self.default = default
@@ -277,16 +305,16 @@ class MySQLAlterDefault(AlterColumn):
 class MySQLChangeColumn(AlterColumn):
     def __init__(
         self,
-        name,
-        column_name,
-        schema=None,
-        newname=None,
-        type_=None,
-        nullable=None,
-        default=False,
-        autoincrement=None,
-        comment=False,
-    ):
+        name: str,
+        column_name: str,
+        schema: Optional[str] = None,
+        newname: Optional[str] = None,
+        type_: Optional["TypeEngine"] = None,
+        nullable: Optional[bool] = None,
+        default: Optional[Union["_ServerDefault", "Literal[False]"]] = False,
+        autoincrement: Optional[bool] = None,
+        comment: Optional[Union[str, "Literal[False]"]] = False,
+    ) -> None:
         super(AlterColumn, self).__init__(name, schema=schema)
         self.column_name = column_name
         self.nullable = nullable
@@ -318,7 +346,9 @@ def _mysql_doesnt_support_individual(element, compiler, **kw):
 
 
 @compiles(MySQLAlterDefault, "mysql", "mariadb")
-def _mysql_alter_default(element, compiler, **kw):
+def _mysql_alter_default(
+    element: "MySQLAlterDefault", compiler: "MySQLDDLCompiler", **kw
+) -> str:
     return "%s ALTER COLUMN %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
@@ -329,7 +359,9 @@ def _mysql_alter_default(element, compiler, **kw):
 
 
 @compiles(MySQLModifyColumn, "mysql", "mariadb")
-def _mysql_modify_column(element, compiler, **kw):
+def _mysql_modify_column(
+    element: "MySQLModifyColumn", compiler: "MySQLDDLCompiler", **kw
+) -> str:
     return "%s MODIFY %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
@@ -345,7 +377,9 @@ def _mysql_modify_column(element, compiler, **kw):
 
 
 @compiles(MySQLChangeColumn, "mysql", "mariadb")
-def _mysql_change_column(element, compiler, **kw):
+def _mysql_change_column(
+    element: "MySQLChangeColumn", compiler: "MySQLDDLCompiler", **kw
+) -> str:
     return "%s CHANGE %s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
@@ -362,8 +396,13 @@ def _mysql_change_column(element, compiler, **kw):
 
 
 def _mysql_colspec(
-    compiler, nullable, server_default, type_, autoincrement, comment
-):
+    compiler: "MySQLDDLCompiler",
+    nullable: Optional[bool],
+    server_default: Optional[Union["_ServerDefault", "Literal[False]"]],
+    type_: "TypeEngine",
+    autoincrement: Optional[bool],
+    comment: Optional[Union[str, "Literal[False]"]],
+) -> str:
     spec = "%s %s" % (
         compiler.dialect.type_compiler.process(type_),
         "NULL" if nullable else "NOT NULL",
@@ -381,7 +420,9 @@ def _mysql_colspec(
 
 
 @compiles(schema.DropConstraint, "mysql", "mariadb")
-def _mysql_drop_constraint(element, compiler, **kw):
+def _mysql_drop_constraint(
+    element: "DropConstraint", compiler: "MySQLDDLCompiler", **kw
+) -> str:
     """Redefine SQLAlchemy's drop constraint to
     raise errors for invalid constraint type."""
 
@@ -394,7 +435,8 @@ def _mysql_drop_constraint(element, compiler, **kw):
             schema.UniqueConstraint,
         ),
     ):
-        return compiler.visit_drop_constraint(element, **kw)
+        assert not kw
+        return compiler.visit_drop_constraint(element)
     elif isinstance(constraint, schema.CheckConstraint):
         # note that SQLAlchemy as of 1.2 does not yet support
         # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
index 90f93d27cfbbc4de5a1a530b974ed5b06397542e..915edb82a842aa46261718a029d3230751f61a67 100644 (file)
@@ -1,3 +1,8 @@
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import sqltypes
 
@@ -16,6 +21,12 @@ from .base import IdentityColumnDefault
 from .base import RenameTable
 from .impl import DefaultImpl
 
+if TYPE_CHECKING:
+    from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
+    from sqlalchemy.engine.cursor import CursorResult
+    from sqlalchemy.engine.cursor import LegacyCursorResult
+    from sqlalchemy.sql.schema import Column
+
 
 class OracleImpl(DefaultImpl):
     __dialect__ = "oracle"
@@ -28,27 +39,31 @@ class OracleImpl(DefaultImpl):
     )
     identity_attrs_ignore = ()
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg, **kw) -> None:
         super(OracleImpl, self).__init__(*arg, **kw)
         self.batch_separator = self.context_opts.get(
             "oracle_batch_separator", self.batch_separator
         )
 
-    def _exec(self, construct, *args, **kw):
+    def _exec(
+        self, construct: Any, *args, **kw
+    ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
         result = super(OracleImpl, self)._exec(construct, *args, **kw)
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
         return result
 
-    def emit_begin(self):
+    def emit_begin(self) -> None:
         self._exec("SET TRANSACTION READ WRITE")
 
-    def emit_commit(self):
+    def emit_commit(self) -> None:
         self._exec("COMMIT")
 
 
 @compiles(AddColumn, "oracle")
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+    element: "AddColumn", compiler: "OracleDDLCompiler", **kw
+) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         add_column(compiler, element.column, **kw),
@@ -56,7 +71,9 @@ def visit_add_column(element, compiler, **kw):
 
 
 @compiles(ColumnNullable, "oracle")
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+    element: "ColumnNullable", compiler: "OracleDDLCompiler", **kw
+) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -65,7 +82,9 @@ def visit_column_nullable(element, compiler, **kw):
 
 
 @compiles(ColumnType, "oracle")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+    element: "ColumnType", compiler: "OracleDDLCompiler", **kw
+) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -74,7 +93,9 @@ def visit_column_type(element, compiler, **kw):
 
 
 @compiles(ColumnName, "oracle")
-def visit_column_name(element, compiler, **kw):
+def visit_column_name(
+    element: "ColumnName", compiler: "OracleDDLCompiler", **kw
+) -> str:
     return "%s RENAME COLUMN %s TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
@@ -83,7 +104,9 @@ def visit_column_name(element, compiler, **kw):
 
 
 @compiles(ColumnDefault, "oracle")
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+    element: "ColumnDefault", compiler: "OracleDDLCompiler", **kw
+) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -94,7 +117,9 @@ def visit_column_default(element, compiler, **kw):
 
 
 @compiles(ColumnComment, "oracle")
-def visit_column_comment(element, compiler, **kw):
+def visit_column_comment(
+    element: "ColumnComment", compiler: "OracleDDLCompiler", **kw
+) -> str:
     ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
 
     comment = compiler.sql_compiler.render_literal_value(
@@ -110,23 +135,27 @@ def visit_column_comment(element, compiler, **kw):
 
 
 @compiles(RenameTable, "oracle")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+    element: "RenameTable", compiler: "OracleDDLCompiler", **kw
+) -> str:
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_table_name(compiler, element.new_table_name, None),
     )
 
 
-def alter_column(compiler, name):
+def alter_column(compiler: "OracleDDLCompiler", name: str) -> str:
     return "MODIFY %s" % format_column_name(compiler, name)
 
 
-def add_column(compiler, column, **kw):
+def add_column(compiler: "OracleDDLCompiler", column: "Column", **kw) -> str:
     return "ADD %s" % compiler.get_column_specification(column, **kw)
 
 
 @compiles(IdentityColumnDefault, "oracle")
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+    element: "IdentityColumnDefault", compiler: "OracleDDLCompiler", **kw
+):
     text = "%s %s " % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
index 7468f082dade0bab780a757266b639f0d399a3ed..c894649a5917dde7877f5b9bb9bc26975ca9c2a1 100644 (file)
@@ -1,5 +1,13 @@
 import logging
 import re
+from typing import Any
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy import Column
 from sqlalchemy import Numeric
@@ -8,8 +16,8 @@ from sqlalchemy import types as sqltypes
 from sqlalchemy.dialects.postgresql import BIGINT
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import INTEGER
-from sqlalchemy.sql.expression import ColumnClause
-from sqlalchemy.sql.expression import UnaryExpression
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.elements import UnaryExpression
 from sqlalchemy.types import NULLTYPE
 
 from .base import alter_column
@@ -32,6 +40,25 @@ from ..operations.base import Operations
 from ..util import compat
 from ..util import sqla_compat
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from sqlalchemy.dialects.postgresql.array import ARRAY
+    from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
+    from sqlalchemy.dialects.postgresql.hstore import HSTORE
+    from sqlalchemy.dialects.postgresql.json import JSON
+    from sqlalchemy.dialects.postgresql.json import JSONB
+    from sqlalchemy.sql.elements import BinaryExpression
+    from sqlalchemy.sql.elements import quoted_name
+    from sqlalchemy.sql.schema import MetaData
+    from sqlalchemy.sql.schema import Table
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from .base import _ServerDefault
+    from ..autogenerate.api import AutogenContext
+    from ..autogenerate.render import _f_name
+    from ..runtime.migration import MigrationContext
+
 
 log = logging.getLogger(__name__)
 
@@ -94,22 +121,22 @@ class PostgresqlImpl(DefaultImpl):
             )
         )
 
-    def alter_column(
+    def alter_column(  # type:ignore[override]
         self,
-        table_name,
-        column_name,
-        nullable=None,
-        server_default=False,
-        name=None,
-        type_=None,
-        schema=None,
-        autoincrement=None,
-        existing_type=None,
-        existing_server_default=None,
-        existing_nullable=None,
-        existing_autoincrement=None,
-        **kw
-    ):
+        table_name: str,
+        column_name: str,
+        nullable: Optional[bool] = None,
+        server_default: Union["_ServerDefault", "Literal[False]"] = False,
+        name: Optional[str] = None,
+        type_: Optional["TypeEngine"] = None,
+        schema: Optional[str] = None,
+        autoincrement: Optional[bool] = None,
+        existing_type: Optional["TypeEngine"] = None,
+        existing_server_default: Optional["_ServerDefault"] = None,
+        existing_nullable: Optional[bool] = None,
+        existing_autoincrement: Optional[bool] = None,
+        **kw: Any
+    ) -> None:
 
         using = kw.pop("postgresql_using", None)
 
@@ -218,7 +245,9 @@ class PostgresqlImpl(DefaultImpl):
                     )
                     metadata_indexes.discard(idx)
 
-    def render_type(self, type_, autogen_context):
+    def render_type(
+        self, type_: "TypeEngine", autogen_context: "AutogenContext"
+    ) -> Union[str, "Literal[False]"]:
         mod = type(type_).__module__
         if not mod.startswith("sqlalchemy.dialects.postgresql"):
             return False
@@ -229,29 +258,51 @@ class PostgresqlImpl(DefaultImpl):
 
         return False
 
-    def _render_HSTORE_type(self, type_, autogen_context):
-        return render._render_type_w_subtype(
-            type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
+    def _render_HSTORE_type(
+        self, type_: "HSTORE", autogen_context: "AutogenContext"
+    ) -> str:
+        return cast(
+            str,
+            render._render_type_w_subtype(
+                type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
+            ),
         )
 
-    def _render_ARRAY_type(self, type_, autogen_context):
-        return render._render_type_w_subtype(
-            type_, autogen_context, "item_type", r"(.+?\()"
+    def _render_ARRAY_type(
+        self, type_: "ARRAY", autogen_context: "AutogenContext"
+    ) -> str:
+        return cast(
+            str,
+            render._render_type_w_subtype(
+                type_, autogen_context, "item_type", r"(.+?\()"
+            ),
         )
 
-    def _render_JSON_type(self, type_, autogen_context):
-        return render._render_type_w_subtype(
-            type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+    def _render_JSON_type(
+        self, type_: "JSON", autogen_context: "AutogenContext"
+    ) -> str:
+        return cast(
+            str,
+            render._render_type_w_subtype(
+                type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+            ),
         )
 
-    def _render_JSONB_type(self, type_, autogen_context):
-        return render._render_type_w_subtype(
-            type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+    def _render_JSONB_type(
+        self, type_: "JSONB", autogen_context: "AutogenContext"
+    ) -> str:
+        return cast(
+            str,
+            render._render_type_w_subtype(
+                type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+            ),
         )
 
 
 class PostgresqlColumnType(AlterColumn):
-    def __init__(self, name, column_name, type_, **kw):
+    def __init__(
+        self, name: str, column_name: str, type_: "TypeEngine", **kw
+    ) -> None:
         using = kw.pop("using", None)
         super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
         self.type_ = sqltypes.to_instance(type_)
@@ -259,7 +310,9 @@ class PostgresqlColumnType(AlterColumn):
 
 
 @compiles(RenameTable, "postgresql")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+    element: RenameTable, compiler: "PGDDLCompiler", **kw
+) -> str:
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_table_name(compiler, element.new_table_name, None),
@@ -267,7 +320,9 @@ def visit_rename_table(element, compiler, **kw):
 
 
 @compiles(PostgresqlColumnType, "postgresql")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+    element: PostgresqlColumnType, compiler: "PGDDLCompiler", **kw
+) -> str:
     return "%s %s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -277,7 +332,9 @@ def visit_column_type(element, compiler, **kw):
 
 
 @compiles(ColumnComment, "postgresql")
-def visit_column_comment(element, compiler, **kw):
+def visit_column_comment(
+    element: "ColumnComment", compiler: "PGDDLCompiler", **kw
+) -> str:
     ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
     comment = (
         compiler.sql_compiler.render_literal_value(
@@ -297,7 +354,9 @@ def visit_column_comment(element, compiler, **kw):
 
 
 @compiles(IdentityColumnDefault, "postgresql")
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+    element: "IdentityColumnDefault", compiler: "PGDDLCompiler", **kw
+):
     text = "%s %s " % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -341,14 +400,17 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
 
     def __init__(
         self,
-        constraint_name,
-        table_name,
-        elements,
-        where=None,
-        schema=None,
-        _orig_constraint=None,
+        constraint_name: Optional[str],
+        table_name: Union[str, "quoted_name"],
+        elements: Union[
+            Sequence[Tuple[str, str]],
+            Sequence[Tuple["ColumnClause", str]],
+        ],
+        where: Optional[Union["BinaryExpression", str]] = None,
+        schema: Optional[str] = None,
+        _orig_constraint: Optional["ExcludeConstraint"] = None,
         **kw
-    ):
+    ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.elements = elements
@@ -358,13 +420,18 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
         self.kw = kw
 
     @classmethod
-    def from_constraint(cls, constraint):
+    def from_constraint(  # type:ignore[override]
+        cls, constraint: "ExcludeConstraint"
+    ) -> "CreateExcludeConstraintOp":
         constraint_table = sqla_compat._table_for_constraint(constraint)
 
         return cls(
             constraint.name,
             constraint_table.name,
-            [(expr, op) for expr, name, op in constraint._render_exprs],
+            [
+                (expr, op)
+                for expr, name, op in constraint._render_exprs  # type:ignore[attr-defined] # noqa
+            ],
             where=constraint.where,
             schema=constraint_table.schema,
             _orig_constraint=constraint,
@@ -373,7 +440,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
             using=constraint.using,
         )
 
-    def to_constraint(self, migration_context=None):
+    def to_constraint(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "ExcludeConstraint":
         if self._orig_constraint is not None:
             return self._orig_constraint
         schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -384,15 +453,24 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
             where=self.where,
             **self.kw
         )
-        for expr, name, oper in excl._render_exprs:
+        for (
+            expr,
+            name,
+            oper,
+        ) in excl._render_exprs:  # type:ignore[attr-defined]
             t.append_column(Column(name, NULLTYPE))
         t.append_constraint(excl)
         return excl
 
     @classmethod
     def create_exclude_constraint(
-        cls, operations, constraint_name, table_name, *elements, **kw
-    ):
+        cls,
+        operations: "Operations",
+        constraint_name: str,
+        table_name: str,
+        *elements: Any,
+        **kw: Any
+    ) -> Optional["Table"]:
         """Issue an alter to create an EXCLUDE constraint using the
         current migration context.
 
@@ -453,14 +531,18 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
 
 
 @render.renderers.dispatch_for(CreateExcludeConstraintOp)
-def _add_exclude_constraint(autogen_context, op):
+def _add_exclude_constraint(
+    autogen_context: "AutogenContext", op: "CreateExcludeConstraintOp"
+) -> str:
     return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
 
 
 @render._constraint_renderers.dispatch_for(ExcludeConstraint)
 def _render_inline_exclude_constraint(
-    constraint, autogen_context, namespace_metadata
-):
+    constraint: "ExcludeConstraint",
+    autogen_context: "AutogenContext",
+    namespace_metadata: "MetaData",
+) -> str:
     rendered = render._user_defined_render(
         "exclude", constraint, autogen_context
     )
@@ -470,7 +552,7 @@ def _render_inline_exclude_constraint(
     return _exclude_constraint(constraint, autogen_context, False)
 
 
-def _postgresql_autogenerate_prefix(autogen_context):
+def _postgresql_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
 
     imports = autogen_context.imports
     if imports is not None:
@@ -478,8 +560,12 @@ def _postgresql_autogenerate_prefix(autogen_context):
     return "postgresql."
 
 
-def _exclude_constraint(constraint, autogen_context, alter):
-    opts = []
+def _exclude_constraint(
+    constraint: "ExcludeConstraint",
+    autogen_context: "AutogenContext",
+    alter: bool,
+) -> str:
+    opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = []
 
     has_batch = autogen_context._has_batch
 
@@ -509,7 +595,7 @@ def _exclude_constraint(constraint, autogen_context, alter):
                     _render_potential_column(sqltext, autogen_context),
                     opstring,
                 )
-                for sqltext, name, opstring in constraint._render_exprs
+                for sqltext, name, opstring in constraint._render_exprs  # type:ignore[attr-defined] # noqa
             ]
         )
         if constraint.where is not None:
@@ -528,7 +614,7 @@ def _exclude_constraint(constraint, autogen_context, alter):
         args = [
             "(%s, %r)"
             % (_render_potential_column(sqltext, autogen_context), opstring)
-            for sqltext, name, opstring in constraint._render_exprs
+            for sqltext, name, opstring in constraint._render_exprs  # type:ignore[attr-defined] # noqa
         ]
         if constraint.where is not None:
             args.append(
@@ -544,7 +630,9 @@ def _exclude_constraint(constraint, autogen_context, alter):
         }
 
 
-def _render_potential_column(value, autogen_context):
+def _render_potential_column(
+    value: Union["ColumnClause", "Column"], autogen_context: "AutogenContext"
+) -> str:
     if isinstance(value, ColumnClause):
         template = "%(prefix)scolumn(%(name)r)"
 
index cb790ea7b524e6b13f9e83ac8c614f8716fc2d4c..2f4ed77362c3cb152b71b18b53de3a2731d4f677 100644 (file)
@@ -1,4 +1,9 @@
 import re
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy import cast
 from sqlalchemy import JSON
@@ -8,6 +13,17 @@ from sqlalchemy import sql
 from .impl import DefaultImpl
 from .. import util
 
+if TYPE_CHECKING:
+    from sqlalchemy.engine.reflection import Inspector
+    from sqlalchemy.sql.elements import Cast
+    from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql.schema import Column
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.schema import Table
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from ..operations.batch import BatchOperationsImpl
+
 
 class SQLiteImpl(DefaultImpl):
     __dialect__ = "sqlite"
@@ -17,7 +33,9 @@ class SQLiteImpl(DefaultImpl):
     see: http://bugs.python.org/issue10740
     """
 
-    def requires_recreate_in_batch(self, batch_op):
+    def requires_recreate_in_batch(
+        self, batch_op: "BatchOperationsImpl"
+    ) -> bool:
         """Return True if the given :class:`.BatchOperationsImpl`
         would need the table to be recreated and copied in order to
         proceed.
@@ -44,16 +62,16 @@ class SQLiteImpl(DefaultImpl):
         else:
             return False
 
-    def add_constraint(self, const):
+    def add_constraint(self, const: "Constraint"):
         # attempt to distinguish between an
         # auto-gen constraint and an explicit one
-        if const._create_rule is None:
+        if const._create_rule is None:  # type:ignore[attr-defined]
             raise NotImplementedError(
                 "No support for ALTER of constraints in SQLite dialect"
                 "Please refer to the batch mode feature which allows for "
                 "SQLite migrations using a copy-and-move strategy."
             )
-        elif const._create_rule(self):
+        elif const._create_rule(self):  # type:ignore[attr-defined]
             util.warn(
                 "Skipping unsupported ALTER for "
                 "creation of implicit constraint"
@@ -61,8 +79,8 @@ class SQLiteImpl(DefaultImpl):
                 "SQLite migrations using a copy-and-move strategy."
             )
 
-    def drop_constraint(self, const):
-        if const._create_rule is None:
+    def drop_constraint(self, const: "Constraint"):
+        if const._create_rule is None:  # type:ignore[attr-defined]
             raise NotImplementedError(
                 "No support for ALTER of constraints in SQLite dialect"
                 "Please refer to the batch mode feature which allows for "
@@ -71,11 +89,11 @@ class SQLiteImpl(DefaultImpl):
 
     def compare_server_default(
         self,
-        inspector_column,
-        metadata_column,
-        rendered_metadata_default,
-        rendered_inspector_default,
-    ):
+        inspector_column: "Column",
+        metadata_column: "Column",
+        rendered_metadata_default: Optional[str],
+        rendered_inspector_default: Optional[str],
+    ) -> bool:
 
         if rendered_metadata_default is not None:
             rendered_metadata_default = re.sub(
@@ -93,7 +111,9 @@ class SQLiteImpl(DefaultImpl):
 
         return rendered_inspector_default != rendered_metadata_default
 
-    def _guess_if_default_is_unparenthesized_sql_expr(self, expr):
+    def _guess_if_default_is_unparenthesized_sql_expr(
+        self, expr: Optional[str]
+    ) -> bool:
         """Determine if a server default is a SQL expression or a constant.
 
         There are too many assertions that expect server defaults to round-trip
@@ -112,7 +132,12 @@ class SQLiteImpl(DefaultImpl):
         else:
             return True
 
-    def autogen_column_reflect(self, inspector, table, column_info):
+    def autogen_column_reflect(
+        self,
+        inspector: "Inspector",
+        table: "Table",
+        column_info: Dict[str, Any],
+    ) -> None:
         # SQLite expression defaults require parenthesis when sent
         # as DDL
         if self._guess_if_default_is_unparenthesized_sql_expr(
@@ -120,7 +145,9 @@ class SQLiteImpl(DefaultImpl):
         ):
             column_info["default"] = "(%s)" % (column_info["default"],)
 
-    def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+    def render_ddl_sql_expr(
+        self, expr: "ClauseElement", is_server_default: bool = False, **kw
+    ) -> str:
         # SQLite expression defaults require parenthesis when sent
         # as DDL
         str_expr = super(SQLiteImpl, self).render_ddl_sql_expr(
@@ -134,9 +161,15 @@ class SQLiteImpl(DefaultImpl):
             str_expr = "(%s)" % (str_expr,)
         return str_expr
 
-    def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
+    def cast_for_batch_migrate(
+        self,
+        existing: "Column",
+        existing_transfer: Dict[str, Union["TypeEngine", "Cast"]],
+        new_type: "TypeEngine",
+    ) -> None:
         if (
-            existing.type._type_affinity is not new_type._type_affinity
+            existing.type._type_affinity  # type:ignore[attr-defined]
+            is not new_type._type_affinity  # type:ignore[attr-defined]
             and not isinstance(new_type, JSON)
         ):
             existing_transfer["expr"] = cast(
diff --git a/alembic/environment.py b/alembic/environment.py
new file mode 100644 (file)
index 0000000..adfc93e
--- /dev/null
@@ -0,0 +1 @@
+from .runtime.environment import *  # noqa
diff --git a/alembic/migration.py b/alembic/migration.py
new file mode 100644 (file)
index 0000000..02626e2
--- /dev/null
@@ -0,0 +1 @@
+from .runtime.migration import *  # noqa
index cd1408046297cb0de75afb4beba5f8c49f33b228..d4ec7b17e7ad8d6faafe644ade9b8665595cb509 100644 (file)
@@ -1,5 +1,13 @@
 from contextlib import contextmanager
 import textwrap
+from typing import Any
+from typing import Callable
+from typing import Iterator
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
+from sqlalchemy.sql.elements import conv
 
 from . import batch
 from . import schemaobj
@@ -8,12 +16,15 @@ from ..util import sqla_compat
 from ..util.compat import inspect_formatargspec
 from ..util.compat import inspect_getargspec
 
-__all__ = ("Operations", "BatchOperations")
+if TYPE_CHECKING:
+    from sqlalchemy.engine import Connection
 
-try:
-    from sqlalchemy.sql.naming import conv
-except:
-    conv = None
+    from .batch import BatchOperationsImpl
+    from .ops import MigrateOperation
+    from ..runtime.migration import MigrationContext
+    from ..util.sqla_compat import _literal_bindparam
+
+__all__ = ("Operations", "BatchOperations")
 
 
 class Operations(util.ModuleClsProxy):
@@ -49,7 +60,11 @@ class Operations(util.ModuleClsProxy):
 
     _to_impl = util.Dispatcher()
 
-    def __init__(self, migration_context, impl=None):
+    def __init__(
+        self,
+        migration_context: "MigrationContext",
+        impl: Optional["BatchOperationsImpl"] = None,
+    ) -> None:
         """Construct a new :class:`.Operations`
 
         :param migration_context: a :class:`.MigrationContext`
@@ -65,7 +80,9 @@ class Operations(util.ModuleClsProxy):
         self.schema_obj = schemaobj.SchemaObjects(migration_context)
 
     @classmethod
-    def register_operation(cls, name, sourcename=None):
+    def register_operation(
+        cls, name: str, sourcename: Optional[str] = None
+    ) -> Callable:
         """Register a new operation for this class.
 
         This method is normally used to add new operations
@@ -142,7 +159,7 @@ class Operations(util.ModuleClsProxy):
         return register
 
     @classmethod
-    def implementation_for(cls, op_cls):
+    def implementation_for(cls, op_cls: Any) -> Callable:
         """Register an implementation for a given :class:`.MigrateOperation`.
 
         This is part of the operation extensibility API.
@@ -161,7 +178,9 @@ class Operations(util.ModuleClsProxy):
 
     @classmethod
     @contextmanager
-    def context(cls, migration_context):
+    def context(
+        cls, migration_context: "MigrationContext"
+    ) -> Iterator["Operations"]:
         op = Operations(migration_context)
         op._install_proxy()
         yield op
@@ -342,7 +361,7 @@ class Operations(util.ModuleClsProxy):
 
         return self.migration_context
 
-    def invoke(self, operation):
+    def invoke(self, operation: "MigrateOperation") -> Any:
         """Given a :class:`.MigrateOperation`, invoke it in terms of
         this :class:`.Operations` instance.
 
@@ -352,7 +371,7 @@ class Operations(util.ModuleClsProxy):
         )
         return fn(self, operation)
 
-    def f(self, name):
+    def f(self, name: str) -> "conv":
         """Indicate a string name that has already had a naming convention
         applied to it.
 
@@ -385,20 +404,14 @@ class Operations(util.ModuleClsProxy):
             CONSTRAINT ck_bool_t_x CHECK (x in (1, 0)))
 
         The function is rendered in the output of autogenerate when
-        a particular constraint name is already converted, for SQLAlchemy
-        version **0.9.4 and greater only**.   Even though ``naming_convention``
-        was introduced in 0.9.2, the string disambiguation service is new
-        as of 0.9.4.
+        a particular constraint name is already converted.
 
         """
-        if conv:
-            return conv(name)
-        else:
-            raise NotImplementedError(
-                "op.f() feature requires SQLAlchemy 0.9.4 or greater."
-            )
+        return conv(name)
 
-    def inline_literal(self, value, type_=None):
+    def inline_literal(
+        self, value: Union[str, int], type_: None = None
+    ) -> "_literal_bindparam":
         r"""Produce an 'inline literal' expression, suitable for
         using in an INSERT, UPDATE, or DELETE statement.
 
@@ -442,7 +455,7 @@ class Operations(util.ModuleClsProxy):
         """
         return sqla_compat._literal_bindparam(None, value, type_=type_)
 
-    def get_bind(self):
+    def get_bind(self) -> "Connection":
         """Return the current 'bind'.
 
         Under normal circumstances, this is the
index 656b8686bb9f78a887792c3af50139196c910804..ee1fe0578db2febc1cba2d7628e9fd1447c6daa6 100644 (file)
@@ -1,3 +1,12 @@
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
 from sqlalchemy import CheckConstraint
 from sqlalchemy import Column
 from sqlalchemy import ForeignKeyConstraint
@@ -21,6 +30,18 @@ from ..util.sqla_compat import _is_type_bound
 from ..util.sqla_compat import _remove_column_from_collection
 from ..util.sqla_compat import _select
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from sqlalchemy.engine import Dialect
+    from sqlalchemy.sql.elements import ColumnClause
+    from sqlalchemy.sql.elements import quoted_name
+    from sqlalchemy.sql.functions import Function
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from ..ddl.impl import DefaultImpl
+
 
 class BatchOperationsImpl:
     def __init__(
@@ -61,14 +82,14 @@ class BatchOperationsImpl:
         self.batch = []
 
     @property
-    def dialect(self):
+    def dialect(self) -> "Dialect":
         return self.operations.impl.dialect
 
     @property
-    def impl(self):
+    def impl(self) -> "DefaultImpl":
         return self.operations.impl
 
-    def _should_recreate(self):
+    def _should_recreate(self) -> bool:
         if self.recreate == "auto":
             return self.operations.impl.requires_recreate_in_batch(self)
         elif self.recreate == "always":
@@ -76,7 +97,7 @@ class BatchOperationsImpl:
         else:
             return False
 
-    def flush(self):
+    def flush(self) -> None:
         should_recreate = self._should_recreate()
 
         with _ensure_scope_for_ddl(self.impl.connection):
@@ -118,10 +139,10 @@ class BatchOperationsImpl:
 
                 batch_impl._create(self.impl)
 
-    def alter_column(self, *arg, **kw):
+    def alter_column(self, *arg, **kw) -> None:
         self.batch.append(("alter_column", arg, kw))
 
-    def add_column(self, *arg, **kw):
+    def add_column(self, *arg, **kw) -> None:
         if (
             "insert_before" in kw or "insert_after" in kw
         ) and not self._should_recreate():
@@ -131,22 +152,22 @@ class BatchOperationsImpl:
             )
         self.batch.append(("add_column", arg, kw))
 
-    def drop_column(self, *arg, **kw):
+    def drop_column(self, *arg, **kw) -> None:
         self.batch.append(("drop_column", arg, kw))
 
-    def add_constraint(self, const):
+    def add_constraint(self, const: "Constraint") -> None:
         self.batch.append(("add_constraint", (const,), {}))
 
-    def drop_constraint(self, const):
+    def drop_constraint(self, const: "Constraint") -> None:
         self.batch.append(("drop_constraint", (const,), {}))
 
     def rename_table(self, *arg, **kw):
         self.batch.append(("rename_table", arg, kw))
 
-    def create_index(self, idx):
+    def create_index(self, idx: "Index") -> None:
         self.batch.append(("create_index", (idx,), {}))
 
-    def drop_index(self, idx):
+    def drop_index(self, idx: "Index") -> None:
         self.batch.append(("drop_index", (idx,), {}))
 
     def create_table_comment(self, table):
@@ -168,22 +189,24 @@ class BatchOperationsImpl:
 class ApplyBatchImpl:
     def __init__(
         self,
-        impl,
-        table,
-        table_args,
-        table_kwargs,
-        reflected,
-        partial_reordering=(),
-    ):
+        impl: "DefaultImpl",
+        table: "Table",
+        table_args: tuple,
+        table_kwargs: Dict[str, Any],
+        reflected: bool,
+        partial_reordering: tuple = (),
+    ) -> None:
         self.impl = impl
         self.table = table  # this is a Table object
         self.table_args = table_args
         self.table_kwargs = table_kwargs
         self.temp_table_name = self._calc_temp_name(table.name)
-        self.new_table = None
+        self.new_table: Optional[Table] = None
 
         self.partial_reordering = partial_reordering  # tuple of tuples
-        self.add_col_ordering = ()  # tuple of tuples
+        self.add_col_ordering: Tuple[
+            Tuple[str, str], ...
+        ] = ()  # tuple of tuples
 
         self.column_transfers = OrderedDict(
             (c.name, {"expr": c}) for c in self.table.c
@@ -194,12 +217,12 @@ class ApplyBatchImpl:
         self._grab_table_elements()
 
     @classmethod
-    def _calc_temp_name(cls, tablename):
+    def _calc_temp_name(cls, tablename: "quoted_name") -> str:
         return ("_alembic_tmp_%s" % tablename)[0:50]
 
-    def _grab_table_elements(self):
+    def _grab_table_elements(self) -> None:
         schema = self.table.schema
-        self.columns = OrderedDict()
+        self.columns: Dict[str, "Column"] = OrderedDict()
         for c in self.table.c:
             c_copy = _copy(c, schema=schema)
             c_copy.unique = c_copy.index = False
@@ -208,11 +231,11 @@ class ApplyBatchImpl:
             if isinstance(c.type, SchemaEventTarget):
                 assert c_copy.type is not c.type
             self.columns[c.name] = c_copy
-        self.named_constraints = {}
+        self.named_constraints: Dict[str, "Constraint"] = {}
         self.unnamed_constraints = []
         self.col_named_constraints = {}
-        self.indexes = {}
-        self.new_indexes = {}
+        self.indexes: Dict[str, "Index"] = {}
+        self.new_indexes: Dict[str, "Index"] = {}
 
         for const in self.table.constraints:
             if _is_type_bound(const):
@@ -238,7 +261,7 @@ class ApplyBatchImpl:
         for k in self.table.kwargs:
             self.table_kwargs.setdefault(k, self.table.kwargs[k])
 
-    def _adjust_self_columns_for_partial_reordering(self):
+    def _adjust_self_columns_for_partial_reordering(self) -> None:
         pairs = set()
 
         col_by_idx = list(self.columns)
@@ -258,17 +281,17 @@ class ApplyBatchImpl:
         # this can happen if some columns were dropped and not removed
         # from existing_ordering.  this should be prevented already, but
         # conservatively making sure this didn't happen
-        pairs = [p for p in pairs if p[0] != p[1]]
+        pairs_list = [p for p in pairs if p[0] != p[1]]
 
         sorted_ = list(
-            topological.sort(pairs, col_by_idx, deterministic_order=True)
+            topological.sort(pairs_list, col_by_idx, deterministic_order=True)
         )
         self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
         self.column_transfers = OrderedDict(
             (k, self.column_transfers[k]) for k in sorted_
         )
 
-    def _transfer_elements_to_new_table(self):
+    def _transfer_elements_to_new_table(self) -> None:
         assert self.new_table is None, "Can only create new table once"
 
         m = MetaData()
@@ -296,6 +319,7 @@ class ApplyBatchImpl:
             if not const_columns.issubset(self.column_transfers):
                 continue
 
+            const_copy: "Constraint"
             if isinstance(const, ForeignKeyConstraint):
                 if _fk_is_self_referential(const):
                     # for self-referential constraint, refer to the
@@ -320,8 +344,9 @@ class ApplyBatchImpl:
                 self._setup_referent(m, const)
             new_table.append_constraint(const_copy)
 
-    def _gather_indexes_from_both_tables(self):
-        idx = []
+    def _gather_indexes_from_both_tables(self) -> List["Index"]:
+        assert self.new_table is not None
+        idx: List[Index] = []
         idx.extend(self.indexes.values())
         for index in self.new_indexes.values():
             idx.append(
@@ -334,8 +359,12 @@ class ApplyBatchImpl:
             )
         return idx
 
-    def _setup_referent(self, metadata, constraint):
-        spec = constraint.elements[0]._get_colspec()
+    def _setup_referent(
+        self, metadata: "MetaData", constraint: "ForeignKeyConstraint"
+    ) -> None:
+        spec = constraint.elements[
+            0
+        ]._get_colspec()  # type:ignore[attr-defined]
         parts = spec.split(".")
         tname = parts[-2]
         if len(parts) == 3:
@@ -345,10 +374,14 @@ class ApplyBatchImpl:
 
         if tname != self.temp_table_name:
             key = sql_schema._get_table_key(tname, referent_schema)
+
+            def colspec(elem: Any):
+                return elem._get_colspec()
+
             if key in metadata.tables:
                 t = metadata.tables[key]
                 for elem in constraint.elements:
-                    colname = elem._get_colspec().split(".")[-1]
+                    colname = colspec(elem).split(".")[-1]
                     if colname not in t.c:
                         t.append_column(Column(colname, sqltypes.NULLTYPE))
             else:
@@ -358,17 +391,18 @@ class ApplyBatchImpl:
                     *[
                         Column(n, sqltypes.NULLTYPE)
                         for n in [
-                            elem._get_colspec().split(".")[-1]
+                            colspec(elem).split(".")[-1]
                             for elem in constraint.elements
                         ]
                     ],
                     schema=referent_schema
                 )
 
-    def _create(self, op_impl):
+    def _create(self, op_impl: "DefaultImpl") -> None:
         self._transfer_elements_to_new_table()
 
         op_impl.prep_table_for_batch(self, self.table)
+        assert self.new_table is not None
         op_impl.create_table(self.new_table)
 
         try:
@@ -405,18 +439,18 @@ class ApplyBatchImpl:
 
     def alter_column(
         self,
-        table_name,
-        column_name,
-        nullable=None,
-        server_default=False,
-        name=None,
-        type_=None,
-        autoincrement=None,
-        comment=False,
+        table_name: str,
+        column_name: str,
+        nullable: Optional[bool] = None,
+        server_default: Optional[Union["Function", str, bool]] = False,
+        name: Optional[str] = None,
+        type_: Optional["TypeEngine"] = None,
+        autoincrement: None = None,
+        comment: Union[str, "Literal[False]"] = False,
         **kw
-    ):
+    ) -> None:
         existing = self.columns[column_name]
-        existing_transfer = self.column_transfers[column_name]
+        existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
         if name is not None and name != column_name:
             # note that we don't change '.key' - we keep referring
             # to the renamed column by its old key in _create().  neat!
@@ -431,8 +465,8 @@ class ApplyBatchImpl:
             # we also ignore the drop_constraint that will come here from
             # Operations.implementation_for(alter_column)
             if isinstance(existing.type, SchemaEventTarget):
-                existing.type._create_events = (
-                    existing.type.create_constraint
+                existing.type._create_events = (  # type:ignore[attr-defined]
+                    existing.type.create_constraint  # type:ignore[attr-defined] # noqa
                 ) = False
 
             self.impl.cast_for_batch_migrate(
@@ -452,7 +486,11 @@ class ApplyBatchImpl:
             if server_default is None:
                 existing.server_default = None
             else:
-                sql_schema.DefaultClause(server_default)._set_parent(existing)
+                sql_schema.DefaultClause(
+                    server_default
+                )._set_parent(  # type:ignore[attr-defined]
+                    existing
+                )
         if autoincrement is not None:
             existing.autoincrement = bool(autoincrement)
 
@@ -460,8 +498,11 @@ class ApplyBatchImpl:
             existing.comment = comment
 
     def _setup_dependencies_for_add_column(
-        self, colname, insert_before, insert_after
-    ):
+        self,
+        colname: str,
+        insert_before: Optional[str],
+        insert_after: Optional[str],
+    ) -> None:
         index_cols = self.existing_ordering
         col_indexes = {name: i for i, name in enumerate(index_cols)}
 
@@ -505,8 +546,13 @@ class ApplyBatchImpl:
             self.add_col_ordering += ((index_cols[-1], colname),)
 
     def add_column(
-        self, table_name, column, insert_before=None, insert_after=None, **kw
-    ):
+        self,
+        table_name: str,
+        column: "Column",
+        insert_before: Optional[str] = None,
+        insert_after: Optional[str] = None,
+        **kw
+    ) -> None:
         self._setup_dependencies_for_add_column(
             column.name, insert_before, insert_after
         )
@@ -515,7 +561,9 @@ class ApplyBatchImpl:
         self.columns[column.name] = _copy(column, schema=self.table.schema)
         self.column_transfers[column.name] = {}
 
-    def drop_column(self, table_name, column, **kw):
+    def drop_column(
+        self, table_name: str, column: Union["ColumnClause", "Column"], **kw
+    ) -> None:
         if column.name in self.table.primary_key.columns:
             _remove_column_from_collection(
                 self.table.primary_key.columns, column
@@ -546,7 +594,7 @@ class ApplyBatchImpl:
 
         """
 
-    def add_constraint(self, const):
+    def add_constraint(self, const: "Constraint") -> None:
         if not const.name:
             raise ValueError("Constraint must have a name")
         if isinstance(const, sql_schema.PrimaryKeyConstraint):
@@ -555,7 +603,7 @@ class ApplyBatchImpl:
 
         self.named_constraints[const.name] = const
 
-    def drop_constraint(self, const):
+    def drop_constraint(self, const: "Constraint") -> None:
         if not const.name:
             raise ValueError("Constraint must have a name")
         try:
@@ -566,7 +614,7 @@ class ApplyBatchImpl:
                     if col_const.name == const.name:
                         self.columns[col.name].constraints.remove(col_const)
             else:
-                const = self.named_constraints.pop(const.name)
+                const = self.named_constraints.pop(cast(str, const.name))
         except KeyError:
             if _is_type_bound(const):
                 # type-bound constraints are only included in the new
@@ -580,10 +628,10 @@ class ApplyBatchImpl:
                 for col in const.columns:
                     self.columns[col.name].primary_key = False
 
-    def create_index(self, idx):
+    def create_index(self, idx: "Index") -> None:
         self.new_indexes[idx.name] = idx
 
-    def drop_index(self, idx):
+    def drop_index(self, idx: "Index") -> None:
         try:
             del self.indexes[idx.name]
         except KeyError:
index 89793554f407119986ef9e3667b7652c9145716f..d5ddbc94df946649cdc7231c50cfefa6ed4d888a 100644 (file)
@@ -1,4 +1,19 @@
+from abc import abstractmethod
 import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import FrozenSet
+from typing import Iterator
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy.types import NULLTYPE
 
@@ -8,6 +23,33 @@ from .base import Operations
 from .. import util
 from ..util import sqla_compat
 
+if TYPE_CHECKING:
+    from sqlalchemy.sql.dml import Insert
+    from sqlalchemy.sql.dml import Update
+    from sqlalchemy.sql.elements import BinaryExpression
+    from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.elements import conv
+    from sqlalchemy.sql.elements import quoted_name
+    from sqlalchemy.sql.elements import TextClause
+    from sqlalchemy.sql.functions import Function
+    from sqlalchemy.sql.schema import CheckConstraint
+    from sqlalchemy.sql.schema import Column
+    from sqlalchemy.sql.schema import Computed
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.schema import ForeignKeyConstraint
+    from sqlalchemy.sql.schema import Identity
+    from sqlalchemy.sql.schema import Index
+    from sqlalchemy.sql.schema import MetaData
+    from sqlalchemy.sql.schema import PrimaryKeyConstraint
+    from sqlalchemy.sql.schema import SchemaItem
+    from sqlalchemy.sql.schema import Table
+    from sqlalchemy.sql.schema import UniqueConstraint
+    from sqlalchemy.sql.selectable import TableClause
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from ..autogenerate.rewriter import Rewriter
+    from ..runtime.migration import MigrationContext
+
 
 class MigrateOperation:
     """base class for migration command and organization objects.
@@ -32,7 +74,13 @@ class MigrateOperation:
         """
         return {}
 
-    _mutations = frozenset()
+    _mutations: FrozenSet["Rewriter"] = frozenset()
+
+    def reverse(self) -> "MigrateOperation":
+        raise NotImplementedError
+
+    def to_diff_tuple(self) -> Tuple[Any, ...]:
+        raise NotImplementedError
 
 
 class AddConstraintOp(MigrateOperation):
@@ -45,7 +93,7 @@ class AddConstraintOp(MigrateOperation):
         raise NotImplementedError()
 
     @classmethod
-    def register_add_constraint(cls, type_):
+    def register_add_constraint(cls, type_: str) -> Callable:
         def go(klass):
             cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint)
             return klass
@@ -53,15 +101,21 @@ class AddConstraintOp(MigrateOperation):
         return go
 
     @classmethod
-    def from_constraint(cls, constraint):
+    def from_constraint(cls, constraint: "Constraint") -> "AddConstraintOp":
         return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
             constraint
         )
 
-    def reverse(self):
+    @abstractmethod
+    def to_constraint(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "Constraint":
+        pass
+
+    def reverse(self) -> "DropConstraintOp":
         return DropConstraintOp.from_constraint(self.to_constraint())
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Tuple[str, "Constraint"]:
         return ("add_constraint", self.to_constraint())
 
 
@@ -72,29 +126,34 @@ class DropConstraintOp(MigrateOperation):
 
     def __init__(
         self,
-        constraint_name,
-        table_name,
-        type_=None,
-        schema=None,
-        _reverse=None,
-    ):
+        constraint_name: Optional[str],
+        table_name: str,
+        type_: Optional[str] = None,
+        schema: Optional[str] = None,
+        _reverse: Optional["AddConstraintOp"] = None,
+    ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.constraint_type = type_
         self.schema = schema
         self._reverse = _reverse
 
-    def reverse(self):
+    def reverse(self) -> "AddConstraintOp":
         return AddConstraintOp.from_constraint(self.to_constraint())
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(
+        self,
+    ) -> Tuple[str, "SchemaItem"]:
         if self.constraint_type == "foreignkey":
             return ("remove_fk", self.to_constraint())
         else:
             return ("remove_constraint", self.to_constraint())
 
     @classmethod
-    def from_constraint(cls, constraint):
+    def from_constraint(
+        cls,
+        constraint: "Constraint",
+    ) -> "DropConstraintOp":
         types = {
             "unique_constraint": "unique",
             "foreign_key_constraint": "foreignkey",
@@ -113,7 +172,9 @@ class DropConstraintOp(MigrateOperation):
             _reverse=AddConstraintOp.from_constraint(constraint),
         )
 
-    def to_constraint(self):
+    def to_constraint(
+        self,
+    ) -> "Constraint":
 
         if self._reverse is not None:
             constraint = self._reverse.to_constraint()
@@ -131,8 +192,13 @@ class DropConstraintOp(MigrateOperation):
 
     @classmethod
     def drop_constraint(
-        cls, operations, constraint_name, table_name, type_=None, schema=None
-    ):
+        cls,
+        operations: "Operations",
+        constraint_name: str,
+        table_name: str,
+        type_: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Optional["Table"]:
         r"""Drop a constraint of the given name, typically via DROP CONSTRAINT.
 
         :param constraint_name: name of the constraint.
@@ -150,7 +216,12 @@ class DropConstraintOp(MigrateOperation):
         return operations.invoke(op)
 
     @classmethod
-    def batch_drop_constraint(cls, operations, constraint_name, type_=None):
+    def batch_drop_constraint(
+        cls,
+        operations: "BatchOperations",
+        constraint_name: str,
+        type_: Optional[str] = None,
+    ) -> None:
         """Issue a "drop constraint" instruction using the
         current batch migration context.
 
@@ -182,8 +253,13 @@ class CreatePrimaryKeyOp(AddConstraintOp):
     constraint_type = "primarykey"
 
     def __init__(
-        self, constraint_name, table_name, columns, schema=None, **kw
-    ):
+        self,
+        constraint_name: Optional[str],
+        table_name: str,
+        columns: Sequence[str],
+        schema: Optional[str] = None,
+        **kw
+    ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.columns = columns
@@ -191,18 +267,23 @@ class CreatePrimaryKeyOp(AddConstraintOp):
         self.kw = kw
 
     @classmethod
-    def from_constraint(cls, constraint):
+    def from_constraint(cls, constraint: "Constraint") -> "CreatePrimaryKeyOp":
         constraint_table = sqla_compat._table_for_constraint(constraint)
+        pk_constraint = cast("PrimaryKeyConstraint", constraint)
+
         return cls(
-            constraint.name,
+            pk_constraint.name,
             constraint_table.name,
-            constraint.columns.keys(),
+            pk_constraint.columns.keys(),
             schema=constraint_table.schema,
-            **constraint.dialect_kwargs,
+            **pk_constraint.dialect_kwargs,
         )
 
-    def to_constraint(self, migration_context=None):
+    def to_constraint(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "PrimaryKeyConstraint":
         schema_obj = schemaobj.SchemaObjects(migration_context)
+
         return schema_obj.primary_key_constraint(
             self.constraint_name,
             self.table_name,
@@ -213,8 +294,13 @@ class CreatePrimaryKeyOp(AddConstraintOp):
 
     @classmethod
     def create_primary_key(
-        cls, operations, constraint_name, table_name, columns, schema=None
-    ):
+        cls,
+        operations: "Operations",
+        constraint_name: str,
+        table_name: str,
+        columns: List[str],
+        schema: Optional[str] = None,
+    ) -> Optional["Table"]:
         """Issue a "create primary key" instruction using the current
         migration context.
 
@@ -255,7 +341,12 @@ class CreatePrimaryKeyOp(AddConstraintOp):
         return operations.invoke(op)
 
     @classmethod
-    def batch_create_primary_key(cls, operations, constraint_name, columns):
+    def batch_create_primary_key(
+        cls,
+        operations: "BatchOperations",
+        constraint_name: str,
+        columns: List[str],
+    ) -> None:
         """Issue a "create primary key" instruction using the
         current batch migration context.
 
@@ -287,8 +378,13 @@ class CreateUniqueConstraintOp(AddConstraintOp):
     constraint_type = "unique"
 
     def __init__(
-        self, constraint_name, table_name, columns, schema=None, **kw
-    ):
+        self,
+        constraint_name: Optional[str],
+        table_name: str,
+        columns: Sequence[str],
+        schema: Optional[str] = None,
+        **kw
+    ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.columns = columns
@@ -296,24 +392,31 @@ class CreateUniqueConstraintOp(AddConstraintOp):
         self.kw = kw
 
     @classmethod
-    def from_constraint(cls, constraint):
+    def from_constraint(
+        cls, constraint: "Constraint"
+    ) -> "CreateUniqueConstraintOp":
+
         constraint_table = sqla_compat._table_for_constraint(constraint)
 
-        kw = {}
-        if constraint.deferrable:
-            kw["deferrable"] = constraint.deferrable
-        if constraint.initially:
-            kw["initially"] = constraint.initially
-        kw.update(constraint.dialect_kwargs)
+        uq_constraint = cast("UniqueConstraint", constraint)
+
+        kw: dict = {}
+        if uq_constraint.deferrable:
+            kw["deferrable"] = uq_constraint.deferrable
+        if uq_constraint.initially:
+            kw["initially"] = uq_constraint.initially
+        kw.update(uq_constraint.dialect_kwargs)
         return cls(
-            constraint.name,
+            uq_constraint.name,
             constraint_table.name,
-            [c.name for c in constraint.columns],
+            [c.name for c in uq_constraint.columns],
             schema=constraint_table.schema,
             **kw,
         )
 
-    def to_constraint(self, migration_context=None):
+    def to_constraint(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "UniqueConstraint":
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.unique_constraint(
             self.constraint_name,
@@ -326,13 +429,13 @@ class CreateUniqueConstraintOp(AddConstraintOp):
     @classmethod
     def create_unique_constraint(
         cls,
-        operations,
-        constraint_name,
-        table_name,
-        columns,
-        schema=None,
+        operations: "Operations",
+        constraint_name: Optional[str],
+        table_name: str,
+        columns: Sequence[str],
+        schema: Optional[str] = None,
         **kw
-    ):
+    ) -> Any:
         """Issue a "create unique constraint" instruction using the
         current migration context.
 
@@ -376,8 +479,12 @@ class CreateUniqueConstraintOp(AddConstraintOp):
 
     @classmethod
     def batch_create_unique_constraint(
-        cls, operations, constraint_name, columns, **kw
-    ):
+        cls,
+        operations: "BatchOperations",
+        constraint_name: str,
+        columns: Sequence[str],
+        **kw
+    ) -> Any:
         """Issue a "create unique constraint" instruction using the
         current batch migration context.
 
@@ -406,13 +513,13 @@ class CreateForeignKeyOp(AddConstraintOp):
 
     def __init__(
         self,
-        constraint_name,
-        source_table,
-        referent_table,
-        local_cols,
-        remote_cols,
+        constraint_name: Optional[str],
+        source_table: str,
+        referent_table: str,
+        local_cols: List[str],
+        remote_cols: List[str],
         **kw
-    ):
+    ) -> None:
         self.constraint_name = constraint_name
         self.source_table = source_table
         self.referent_table = referent_table
@@ -420,22 +527,24 @@ class CreateForeignKeyOp(AddConstraintOp):
         self.remote_cols = remote_cols
         self.kw = kw
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Tuple[str, "ForeignKeyConstraint"]:
         return ("add_fk", self.to_constraint())
 
     @classmethod
-    def from_constraint(cls, constraint):
-        kw = {}
-        if constraint.onupdate:
-            kw["onupdate"] = constraint.onupdate
-        if constraint.ondelete:
-            kw["ondelete"] = constraint.ondelete
-        if constraint.initially:
-            kw["initially"] = constraint.initially
-        if constraint.deferrable:
-            kw["deferrable"] = constraint.deferrable
-        if constraint.use_alter:
-            kw["use_alter"] = constraint.use_alter
+    def from_constraint(cls, constraint: "Constraint") -> "CreateForeignKeyOp":
+
+        fk_constraint = cast("ForeignKeyConstraint", constraint)
+        kw: dict = {}
+        if fk_constraint.onupdate:
+            kw["onupdate"] = fk_constraint.onupdate
+        if fk_constraint.ondelete:
+            kw["ondelete"] = fk_constraint.ondelete
+        if fk_constraint.initially:
+            kw["initially"] = fk_constraint.initially
+        if fk_constraint.deferrable:
+            kw["deferrable"] = fk_constraint.deferrable
+        if fk_constraint.use_alter:
+            kw["use_alter"] = fk_constraint.use_alter
 
         (
             source_schema,
@@ -448,13 +557,13 @@ class CreateForeignKeyOp(AddConstraintOp):
             ondelete,
             deferrable,
             initially,
-        ) = sqla_compat._fk_spec(constraint)
+        ) = sqla_compat._fk_spec(fk_constraint)
 
         kw["source_schema"] = source_schema
         kw["referent_schema"] = target_schema
-        kw.update(constraint.dialect_kwargs)
+        kw.update(fk_constraint.dialect_kwargs)
         return cls(
-            constraint.name,
+            fk_constraint.name,
             source_table,
             target_table,
             source_columns,
@@ -462,7 +571,9 @@ class CreateForeignKeyOp(AddConstraintOp):
             **kw,
         )
 
-    def to_constraint(self, migration_context=None):
+    def to_constraint(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "ForeignKeyConstraint":
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.foreign_key_constraint(
             self.constraint_name,
@@ -476,21 +587,21 @@ class CreateForeignKeyOp(AddConstraintOp):
     @classmethod
     def create_foreign_key(
         cls,
-        operations,
-        constraint_name,
-        source_table,
-        referent_table,
-        local_cols,
-        remote_cols,
-        onupdate=None,
-        ondelete=None,
-        deferrable=None,
-        initially=None,
-        match=None,
-        source_schema=None,
-        referent_schema=None,
+        operations: "Operations",
+        constraint_name: str,
+        source_table: str,
+        referent_table: str,
+        local_cols: List[str],
+        remote_cols: List[str],
+        onupdate: Optional[str] = None,
+        ondelete: Optional[str] = None,
+        deferrable: Optional[bool] = None,
+        initially: Optional[str] = None,
+        match: Optional[str] = None,
+        source_schema: Optional[str] = None,
+        referent_schema: Optional[str] = None,
         **dialect_kw
-    ):
+    ) -> Optional["Table"]:
         """Issue a "create foreign key" instruction using the
         current migration context.
 
@@ -556,19 +667,19 @@ class CreateForeignKeyOp(AddConstraintOp):
     @classmethod
     def batch_create_foreign_key(
         cls,
-        operations,
-        constraint_name,
-        referent_table,
-        local_cols,
-        remote_cols,
-        referent_schema=None,
-        onupdate=None,
-        ondelete=None,
-        deferrable=None,
-        initially=None,
-        match=None,
+        operations: "BatchOperations",
+        constraint_name: str,
+        referent_table: str,
+        local_cols: List[str],
+        remote_cols: List[str],
+        referent_schema: Optional[str] = None,
+        onupdate: None = None,
+        ondelete: None = None,
+        deferrable: None = None,
+        initially: None = None,
+        match: None = None,
         **dialect_kw
-    ):
+    ) -> None:
         """Issue a "create foreign key" instruction using the
         current batch migration context.
 
@@ -618,8 +729,13 @@ class CreateCheckConstraintOp(AddConstraintOp):
     constraint_type = "check"
 
     def __init__(
-        self, constraint_name, table_name, condition, schema=None, **kw
-    ):
+        self,
+        constraint_name: Optional[str],
+        table_name: str,
+        condition: Union["TextClause", "ColumnElement[Any]"],
+        schema: Optional[str] = None,
+        **kw
+    ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.condition = condition
@@ -627,18 +743,26 @@ class CreateCheckConstraintOp(AddConstraintOp):
         self.kw = kw
 
     @classmethod
-    def from_constraint(cls, constraint):
+    def from_constraint(
+        cls, constraint: "Constraint"
+    ) -> "CreateCheckConstraintOp":
         constraint_table = sqla_compat._table_for_constraint(constraint)
 
+        ck_constraint = cast("CheckConstraint", constraint)
+
         return cls(
-            constraint.name,
+            ck_constraint.name,
             constraint_table.name,
-            constraint.sqltext,
+            cast(
+                "Union[TextClause, ColumnElement[Any]]", ck_constraint.sqltext
+            ),
             schema=constraint_table.schema,
-            **constraint.dialect_kwargs,
+            **ck_constraint.dialect_kwargs,
         )
 
-    def to_constraint(self, migration_context=None):
+    def to_constraint(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "CheckConstraint":
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.check_constraint(
             self.constraint_name,
@@ -651,13 +775,13 @@ class CreateCheckConstraintOp(AddConstraintOp):
     @classmethod
     def create_check_constraint(
         cls,
-        operations,
-        constraint_name,
-        table_name,
-        condition,
-        schema=None,
+        operations: "Operations",
+        constraint_name: Optional[str],
+        table_name: str,
+        condition: "BinaryExpression",
+        schema: Optional[str] = None,
         **kw
-    ):
+    ) -> Optional["Table"]:
         """Issue a "create check constraint" instruction using the
         current migration context.
 
@@ -703,8 +827,12 @@ class CreateCheckConstraintOp(AddConstraintOp):
 
     @classmethod
     def batch_create_check_constraint(
-        cls, operations, constraint_name, condition, **kw
-    ):
+        cls,
+        operations: "BatchOperations",
+        constraint_name: str,
+        condition: "TextClause",
+        **kw
+    ) -> Optional["Table"]:
         """Issue a "create check constraint" instruction using the
         current batch migration context.
 
@@ -732,8 +860,14 @@ class CreateIndexOp(MigrateOperation):
     """Represent a create index operation."""
 
     def __init__(
-        self, index_name, table_name, columns, schema=None, unique=False, **kw
-    ):
+        self,
+        index_name: str,
+        table_name: str,
+        columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
+        schema: Optional[str] = None,
+        unique: bool = False,
+        **kw
+    ) -> None:
         self.index_name = index_name
         self.table_name = table_name
         self.columns = columns
@@ -741,14 +875,15 @@ class CreateIndexOp(MigrateOperation):
         self.unique = unique
         self.kw = kw
 
-    def reverse(self):
+    def reverse(self) -> "DropIndexOp":
         return DropIndexOp.from_index(self.to_index())
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Tuple[str, "Index"]:
         return ("add_index", self.to_index())
 
     @classmethod
-    def from_index(cls, index):
+    def from_index(cls, index: "Index") -> "CreateIndexOp":
+        assert index.table is not None
         return cls(
             index.name,
             index.table.name,
@@ -758,7 +893,9 @@ class CreateIndexOp(MigrateOperation):
             **index.kwargs,
         )
 
-    def to_index(self, migration_context=None):
+    def to_index(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "Index":
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         idx = schema_obj.index(
@@ -774,14 +911,14 @@ class CreateIndexOp(MigrateOperation):
     @classmethod
     def create_index(
         cls,
-        operations,
-        index_name,
-        table_name,
-        columns,
-        schema=None,
-        unique=False,
+        operations: Operations,
+        index_name: str,
+        table_name: str,
+        columns: Sequence[Union[str, "TextClause", "Function"]],
+        schema: Optional[str] = None,
+        unique: bool = False,
         **kw
-    ):
+    ) -> Optional["Table"]:
         r"""Issue a "create index" instruction using the current
         migration context.
 
@@ -829,7 +966,13 @@ class CreateIndexOp(MigrateOperation):
         return operations.invoke(op)
 
     @classmethod
-    def batch_create_index(cls, operations, index_name, columns, **kw):
+    def batch_create_index(
+        cls,
+        operations: "BatchOperations",
+        index_name: str,
+        columns: List[str],
+        **kw
+    ) -> Optional["Table"]:
         """Issue a "create index" instruction using the
         current batch migration context.
 
@@ -855,22 +998,28 @@ class DropIndexOp(MigrateOperation):
     """Represent a drop index operation."""
 
     def __init__(
-        self, index_name, table_name=None, schema=None, _reverse=None, **kw
-    ):
+        self,
+        index_name: Union["quoted_name", str, "conv"],
+        table_name: Optional[str] = None,
+        schema: Optional[str] = None,
+        _reverse: Optional["CreateIndexOp"] = None,
+        **kw
+    ) -> None:
         self.index_name = index_name
         self.table_name = table_name
         self.schema = schema
         self._reverse = _reverse
         self.kw = kw
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Tuple[str, "Index"]:
         return ("remove_index", self.to_index())
 
-    def reverse(self):
+    def reverse(self) -> "CreateIndexOp":
         return CreateIndexOp.from_index(self.to_index())
 
     @classmethod
-    def from_index(cls, index):
+    def from_index(cls, index: "Index") -> "DropIndexOp":
+        assert index.table is not None
         return cls(
             index.name,
             index.table.name,
@@ -879,7 +1028,9 @@ class DropIndexOp(MigrateOperation):
             **index.kwargs,
         )
 
-    def to_index(self, migration_context=None):
+    def to_index(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "Index":
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         # need a dummy column name here since SQLAlchemy
@@ -894,8 +1045,13 @@ class DropIndexOp(MigrateOperation):
 
     @classmethod
     def drop_index(
-        cls, operations, index_name, table_name=None, schema=None, **kw
-    ):
+        cls,
+        operations: "Operations",
+        index_name: str,
+        table_name: Optional[str] = None,
+        schema: Optional[str] = None,
+        **kw
+    ) -> Optional["Table"]:
         r"""Issue a "drop index" instruction using the current
         migration context.
 
@@ -921,7 +1077,9 @@ class DropIndexOp(MigrateOperation):
         return operations.invoke(op)
 
     @classmethod
-    def batch_drop_index(cls, operations, index_name, **kw):
+    def batch_drop_index(
+        cls, operations: BatchOperations, index_name: str, **kw
+    ) -> Optional["Table"]:
         """Issue a "drop index" instruction using the
         current batch migration context.
 
@@ -946,13 +1104,13 @@ class CreateTableOp(MigrateOperation):
 
     def __init__(
         self,
-        table_name,
-        columns,
-        schema=None,
-        _namespace_metadata=None,
-        _constraints_included=False,
+        table_name: str,
+        columns: Sequence[Union["Column", "Constraint"]],
+        schema: Optional[str] = None,
+        _namespace_metadata: Optional["MetaData"] = None,
+        _constraints_included: bool = False,
         **kw
-    ):
+    ) -> None:
         self.table_name = table_name
         self.columns = columns
         self.schema = schema
@@ -963,22 +1121,24 @@ class CreateTableOp(MigrateOperation):
         self._namespace_metadata = _namespace_metadata
         self._constraints_included = _constraints_included
 
-    def reverse(self):
+    def reverse(self) -> "DropTableOp":
         return DropTableOp.from_table(
             self.to_table(), _namespace_metadata=self._namespace_metadata
         )
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Tuple[str, "Table"]:
         return ("add_table", self.to_table())
 
     @classmethod
-    def from_table(cls, table, _namespace_metadata=None):
+    def from_table(
+        cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None
+    ) -> "CreateTableOp":
         if _namespace_metadata is None:
             _namespace_metadata = table.metadata
 
         return cls(
             table.name,
-            list(table.c) + list(table.constraints),
+            list(table.c) + list(table.constraints),  # type:ignore[arg-type]
             schema=table.schema,
             _namespace_metadata=_namespace_metadata,
             # given a Table() object, this Table will contain full Index()
@@ -989,12 +1149,14 @@ class CreateTableOp(MigrateOperation):
             # not doubled up. see #844 #848
             _constraints_included=True,
             comment=table.comment,
-            info=table.info.copy(),
+            info=dict(table.info),
             prefixes=list(table._prefixes),
             **table.kwargs,
         )
 
-    def to_table(self, migration_context=None):
+    def to_table(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "Table":
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         return schema_obj.table(
@@ -1009,7 +1171,9 @@ class CreateTableOp(MigrateOperation):
         )
 
     @classmethod
-    def create_table(cls, operations, table_name, *columns, **kw):
+    def create_table(
+        cls, operations: "Operations", table_name: str, *columns, **kw
+    ) -> Optional["Table"]:
         r"""Issue a "create table" instruction using the current migration
         context.
 
@@ -1094,7 +1258,13 @@ class CreateTableOp(MigrateOperation):
 class DropTableOp(MigrateOperation):
     """Represent a drop table operation."""
 
-    def __init__(self, table_name, schema=None, table_kw=None, _reverse=None):
+    def __init__(
+        self,
+        table_name: str,
+        schema: Optional[str] = None,
+        table_kw: Optional[MutableMapping[Any, Any]] = None,
+        _reverse: Optional["CreateTableOp"] = None,
+    ) -> None:
         self.table_name = table_name
         self.schema = schema
         self.table_kw = table_kw or {}
@@ -1103,20 +1273,22 @@ class DropTableOp(MigrateOperation):
         self.prefixes = self.table_kw.pop("prefixes", None)
         self._reverse = _reverse
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Tuple[str, "Table"]:
         return ("remove_table", self.to_table())
 
-    def reverse(self):
+    def reverse(self) -> "CreateTableOp":
         return CreateTableOp.from_table(self.to_table())
 
     @classmethod
-    def from_table(cls, table, _namespace_metadata=None):
+    def from_table(
+        cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None
+    ) -> "DropTableOp":
         return cls(
             table.name,
             schema=table.schema,
             table_kw={
                 "comment": table.comment,
-                "info": table.info.copy(),
+                "info": dict(table.info),
                 "prefixes": list(table._prefixes),
                 **table.kwargs,
             },
@@ -1125,7 +1297,9 @@ class DropTableOp(MigrateOperation):
             ),
         )
 
-    def to_table(self, migration_context=None):
+    def to_table(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "Table":
         if self._reverse:
             cols_and_constraints = self._reverse.columns
         else:
@@ -1139,14 +1313,21 @@ class DropTableOp(MigrateOperation):
             info=self.info.copy() if self.info else {},
             prefixes=list(self.prefixes) if self.prefixes else [],
             schema=self.schema,
-            _constraints_included=bool(self._reverse)
-            and self._reverse._constraints_included,
+            _constraints_included=self._reverse._constraints_included
+            if self._reverse
+            else False,
             **self.table_kw,
         )
         return t
 
     @classmethod
-    def drop_table(cls, operations, table_name, schema=None, **kw):
+    def drop_table(
+        cls,
+        operations: "Operations",
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any
+    ) -> None:
         r"""Issue a "drop table" instruction using the current
         migration context.
 
@@ -1171,7 +1352,11 @@ class DropTableOp(MigrateOperation):
 class AlterTableOp(MigrateOperation):
     """Represent an alter table operation."""
 
-    def __init__(self, table_name, schema=None):
+    def __init__(
+        self,
+        table_name: str,
+        schema: Optional[str] = None,
+    ) -> None:
         self.table_name = table_name
         self.schema = schema
 
@@ -1180,14 +1365,23 @@ class AlterTableOp(MigrateOperation):
 class RenameTableOp(AlterTableOp):
     """Represent a rename table operation."""
 
-    def __init__(self, old_table_name, new_table_name, schema=None):
+    def __init__(
+        self,
+        old_table_name: str,
+        new_table_name: str,
+        schema: Optional[str] = None,
+    ) -> None:
         super(RenameTableOp, self).__init__(old_table_name, schema=schema)
         self.new_table_name = new_table_name
 
     @classmethod
     def rename_table(
-        cls, operations, old_table_name, new_table_name, schema=None
-    ):
+        cls,
+        operations: "Operations",
+        old_table_name: str,
+        new_table_name: str,
+        schema: Optional[str] = None,
+    ) -> Optional["Table"]:
         """Emit an ALTER TABLE to rename a table.
 
         :param old_table_name: old name.
@@ -1210,8 +1404,12 @@ class CreateTableCommentOp(AlterTableOp):
     """Represent a COMMENT ON `table` operation."""
 
     def __init__(
-        self, table_name, comment, schema=None, existing_comment=None
-    ):
+        self,
+        table_name: str,
+        comment: Optional[str],
+        schema: Optional[str] = None,
+        existing_comment: Optional[str] = None,
+    ) -> None:
         self.table_name = table_name
         self.comment = comment
         self.existing_comment = existing_comment
@@ -1220,12 +1418,12 @@ class CreateTableCommentOp(AlterTableOp):
     @classmethod
     def create_table_comment(
         cls,
-        operations,
-        table_name,
-        comment,
-        existing_comment=None,
-        schema=None,
-    ):
+        operations: "Operations",
+        table_name: str,
+        comment: Optional[str],
+        existing_comment: None = None,
+        schema: Optional[str] = None,
+    ) -> Optional["Table"]:
         """Emit a COMMENT ON operation to set the comment for a table.
 
         .. versionadded:: 1.0.6
@@ -1317,15 +1515,24 @@ class CreateTableCommentOp(AlterTableOp):
 class DropTableCommentOp(AlterTableOp):
     """Represent an operation to remove the comment from a table."""
 
-    def __init__(self, table_name, schema=None, existing_comment=None):
+    def __init__(
+        self,
+        table_name: str,
+        schema: Optional[str] = None,
+        existing_comment: Optional[str] = None,
+    ) -> None:
         self.table_name = table_name
         self.existing_comment = existing_comment
         self.schema = schema
 
     @classmethod
     def drop_table_comment(
-        cls, operations, table_name, existing_comment=None, schema=None
-    ):
+        cls,
+        operations: "Operations",
+        table_name: str,
+        existing_comment: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Optional["Table"]:
         """Issue a "drop table comment" operation to
         remove an existing comment set on a table.
 
@@ -1388,20 +1595,20 @@ class AlterColumnOp(AlterTableOp):
 
     def __init__(
         self,
-        table_name,
-        column_name,
-        schema=None,
-        existing_type=None,
-        existing_server_default=False,
-        existing_nullable=None,
-        existing_comment=None,
-        modify_nullable=None,
-        modify_comment=False,
-        modify_server_default=False,
-        modify_name=None,
-        modify_type=None,
+        table_name: str,
+        column_name: str,
+        schema: Optional[str] = None,
+        existing_type: Optional[Any] = None,
+        existing_server_default: Any = False,
+        existing_nullable: Optional[bool] = None,
+        existing_comment: Optional[str] = None,
+        modify_nullable: Optional[bool] = None,
+        modify_comment: Optional[Union[str, bool]] = False,
+        modify_server_default: Any = False,
+        modify_name: Optional[str] = None,
+        modify_type: Optional[Any] = None,
         **kw
-    ):
+    ) -> None:
         super(AlterColumnOp, self).__init__(table_name, schema=schema)
         self.column_name = column_name
         self.existing_type = existing_type
@@ -1415,7 +1622,7 @@ class AlterColumnOp(AlterTableOp):
         self.modify_type = modify_type
         self.kw = kw
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(self) -> Any:
         col_diff = []
         schema, tname, cname = self.schema, self.table_name, self.column_name
 
@@ -1495,7 +1702,7 @@ class AlterColumnOp(AlterTableOp):
 
         return col_diff
 
-    def has_changes(self):
+    def has_changes(self) -> bool:
         hc1 = (
             self.modify_nullable is not None
             or self.modify_server_default is not False
@@ -1510,7 +1717,7 @@ class AlterColumnOp(AlterTableOp):
         else:
             return False
 
-    def reverse(self):
+    def reverse(self) -> "AlterColumnOp":
 
         kw = self.kw.copy()
         kw["existing_type"] = self.existing_type
@@ -1546,21 +1753,25 @@ class AlterColumnOp(AlterTableOp):
     @classmethod
     def alter_column(
         cls,
-        operations,
-        table_name,
-        column_name,
-        nullable=None,
-        comment=False,
-        server_default=False,
-        new_column_name=None,
-        type_=None,
-        existing_type=None,
-        existing_server_default=False,
-        existing_nullable=None,
-        existing_comment=None,
-        schema=None,
+        operations: Operations,
+        table_name: str,
+        column_name: str,
+        nullable: Optional[bool] = None,
+        comment: Optional[Union[str, bool]] = False,
+        server_default: Any = False,
+        new_column_name: Optional[str] = None,
+        type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
+        existing_type: Optional[
+            Union["TypeEngine", Type["TypeEngine"]]
+        ] = None,
+        existing_server_default: Optional[
+            Union[str, bool, "Identity", "Computed"]
+        ] = False,
+        existing_nullable: Optional[bool] = None,
+        existing_comment: Optional[str] = None,
+        schema: Optional[str] = None,
         **kw
-    ):
+    ) -> Optional["Table"]:
         r"""Issue an "alter column" instruction using the
         current migration context.
 
@@ -1671,21 +1882,23 @@ class AlterColumnOp(AlterTableOp):
     @classmethod
     def batch_alter_column(
         cls,
-        operations,
-        column_name,
-        nullable=None,
-        comment=False,
-        server_default=False,
-        new_column_name=None,
-        type_=None,
-        existing_type=None,
-        existing_server_default=False,
-        existing_nullable=None,
-        existing_comment=None,
-        insert_before=None,
-        insert_after=None,
+        operations: BatchOperations,
+        column_name: str,
+        nullable: Optional[bool] = None,
+        comment: bool = False,
+        server_default: Union["Function", bool] = False,
+        new_column_name: Optional[str] = None,
+        type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
+        existing_type: Optional[
+            Union["TypeEngine", Type["TypeEngine"]]
+        ] = None,
+        existing_server_default: bool = False,
+        existing_nullable: None = None,
+        existing_comment: None = None,
+        insert_before: None = None,
+        insert_after: None = None,
         **kw
-    ):
+    ) -> Optional["Table"]:
         """Issue an "alter column" instruction using the current
         batch migration context.
 
@@ -1736,32 +1949,51 @@ class AlterColumnOp(AlterTableOp):
 class AddColumnOp(AlterTableOp):
     """Represent an add column operation."""
 
-    def __init__(self, table_name, column, schema=None, **kw):
+    def __init__(
+        self,
+        table_name: str,
+        column: "Column",
+        schema: Optional[str] = None,
+        **kw
+    ) -> None:
         super(AddColumnOp, self).__init__(table_name, schema=schema)
         self.column = column
         self.kw = kw
 
-    def reverse(self):
+    def reverse(self) -> "DropColumnOp":
         return DropColumnOp.from_column_and_tablename(
             self.schema, self.table_name, self.column
         )
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(
+        self,
+    ) -> Tuple[str, Optional[str], str, "Column"]:
         return ("add_column", self.schema, self.table_name, self.column)
 
-    def to_column(self):
+    def to_column(self) -> "Column":
         return self.column
 
     @classmethod
-    def from_column(cls, col):
+    def from_column(cls, col: "Column") -> "AddColumnOp":
         return cls(col.table.name, col, schema=col.table.schema)
 
     @classmethod
-    def from_column_and_tablename(cls, schema, tname, col):
+    def from_column_and_tablename(
+        cls,
+        schema: Optional[str],
+        tname: str,
+        col: "Column",
+    ) -> "AddColumnOp":
         return cls(tname, col, schema=schema)
 
     @classmethod
-    def add_column(cls, operations, table_name, column, schema=None):
+    def add_column(
+        cls,
+        operations: "Operations",
+        table_name: str,
+        column: "Column",
+        schema: Optional[str] = None,
+    ) -> Optional["Table"]:
         """Issue an "add column" instruction using the current
         migration context.
 
@@ -1816,8 +2048,12 @@ class AddColumnOp(AlterTableOp):
 
     @classmethod
     def batch_add_column(
-        cls, operations, column, insert_before=None, insert_after=None
-    ):
+        cls,
+        operations: "BatchOperations",
+        column: "Column",
+        insert_before: Optional[str] = None,
+        insert_after: Optional[str] = None,
+    ) -> Optional["Table"]:
         """Issue an "add column" instruction using the current
         batch migration context.
 
@@ -1848,14 +2084,21 @@ class DropColumnOp(AlterTableOp):
     """Represent a drop column operation."""
 
     def __init__(
-        self, table_name, column_name, schema=None, _reverse=None, **kw
-    ):
+        self,
+        table_name: str,
+        column_name: str,
+        schema: Optional[str] = None,
+        _reverse: Optional["AddColumnOp"] = None,
+        **kw
+    ) -> None:
         super(DropColumnOp, self).__init__(table_name, schema=schema)
         self.column_name = column_name
         self.kw = kw
         self._reverse = _reverse
 
-    def to_diff_tuple(self):
+    def to_diff_tuple(
+        self,
+    ) -> Tuple[str, Optional[str], str, "Column"]:
         return (
             "remove_column",
             self.schema,
@@ -1863,7 +2106,7 @@ class DropColumnOp(AlterTableOp):
             self.to_column(),
         )
 
-    def reverse(self):
+    def reverse(self) -> "AddColumnOp":
         if self._reverse is None:
             raise ValueError(
                 "operation is not reversible; "
@@ -1875,7 +2118,12 @@ class DropColumnOp(AlterTableOp):
         )
 
     @classmethod
-    def from_column_and_tablename(cls, schema, tname, col):
+    def from_column_and_tablename(
+        cls,
+        schema: Optional[str],
+        tname: str,
+        col: "Column",
+    ) -> "DropColumnOp":
         return cls(
             tname,
             col.name,
@@ -1883,7 +2131,9 @@ class DropColumnOp(AlterTableOp):
             _reverse=AddColumnOp.from_column_and_tablename(schema, tname, col),
         )
 
-    def to_column(self, migration_context=None):
+    def to_column(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> "Column":
         if self._reverse is not None:
             return self._reverse.column
         schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -1891,8 +2141,13 @@ class DropColumnOp(AlterTableOp):
 
     @classmethod
     def drop_column(
-        cls, operations, table_name, column_name, schema=None, **kw
-    ):
+        cls,
+        operations: "Operations",
+        table_name: str,
+        column_name: str,
+        schema: Optional[str] = None,
+        **kw
+    ) -> Optional["Table"]:
         """Issue a "drop column" instruction using the current
         migration context.
 
@@ -1934,7 +2189,9 @@ class DropColumnOp(AlterTableOp):
         return operations.invoke(op)
 
     @classmethod
-    def batch_drop_column(cls, operations, column_name, **kw):
+    def batch_drop_column(
+        cls, operations: "BatchOperations", column_name: str, **kw
+    ) -> Optional["Table"]:
         """Issue a "drop column" instruction using the current
         batch migration context.
 
@@ -1956,13 +2213,24 @@ class DropColumnOp(AlterTableOp):
 class BulkInsertOp(MigrateOperation):
     """Represent a bulk insert operation."""
 
-    def __init__(self, table, rows, multiinsert=True):
+    def __init__(
+        self,
+        table: Union["Table", "TableClause"],
+        rows: List[dict],
+        multiinsert: bool = True,
+    ) -> None:
         self.table = table
         self.rows = rows
         self.multiinsert = multiinsert
 
     @classmethod
-    def bulk_insert(cls, operations, table, rows, multiinsert=True):
+    def bulk_insert(
+        cls,
+        operations: Operations,
+        table: Union["Table", "TableClause"],
+        rows: List[dict],
+        multiinsert: bool = True,
+    ) -> None:
         """Issue a "bulk insert" operation using the current
         migration context.
 
@@ -2046,12 +2314,21 @@ class BulkInsertOp(MigrateOperation):
 class ExecuteSQLOp(MigrateOperation):
     """Represent an execute SQL operation."""
 
-    def __init__(self, sqltext, execution_options=None):
+    def __init__(
+        self,
+        sqltext: Union["Update", str, "Insert", "TextClause"],
+        execution_options: None = None,
+    ) -> None:
         self.sqltext = sqltext
         self.execution_options = execution_options
 
     @classmethod
-    def execute(cls, operations, sqltext, execution_options=None):
+    def execute(
+        cls,
+        operations: Operations,
+        sqltext: Union[str, "TextClause", "Update"],
+        execution_options: None = None,
+    ) -> Optional["Table"]:
         r"""Execute the given SQL using the current migration context.
 
         The given SQL can be a plain string, e.g.::
@@ -2140,20 +2417,22 @@ class ExecuteSQLOp(MigrateOperation):
 class OpContainer(MigrateOperation):
     """Represent a sequence of operations operation."""
 
-    def __init__(self, ops=()):
-        self.ops = ops
+    def __init__(self, ops: Sequence[MigrateOperation] = ()) -> None:
+        self.ops = list(ops)
 
-    def is_empty(self):
+    def is_empty(self) -> bool:
         return not self.ops
 
-    def as_diffs(self):
+    def as_diffs(self) -> Any:
         return list(OpContainer._ops_as_diffs(self))
 
     @classmethod
-    def _ops_as_diffs(cls, migrations):
+    def _ops_as_diffs(
+        cls, migrations: "OpContainer"
+    ) -> Iterator[Tuple[Any, ...]]:
         for op in migrations.ops:
             if hasattr(op, "ops"):
-                for sub_op in cls._ops_as_diffs(op):
+                for sub_op in cls._ops_as_diffs(cast("OpContainer", op)):
                     yield sub_op
             else:
                 yield op.to_diff_tuple()
@@ -2162,12 +2441,17 @@ class OpContainer(MigrateOperation):
 class ModifyTableOps(OpContainer):
     """Contains a sequence of operations that all apply to a single Table."""
 
-    def __init__(self, table_name, ops, schema=None):
+    def __init__(
+        self,
+        table_name: str,
+        ops: Sequence[MigrateOperation],
+        schema: Optional[str] = None,
+    ) -> None:
         super(ModifyTableOps, self).__init__(ops)
         self.table_name = table_name
         self.schema = schema
 
-    def reverse(self):
+    def reverse(self) -> "ModifyTableOps":
         return ModifyTableOps(
             self.table_name,
             ops=list(reversed([op.reverse() for op in self.ops])),
@@ -2185,17 +2469,21 @@ class UpgradeOps(OpContainer):
 
     """
 
-    def __init__(self, ops=(), upgrade_token="upgrades"):
+    def __init__(
+        self,
+        ops: Sequence[MigrateOperation] = (),
+        upgrade_token: str = "upgrades",
+    ) -> None:
         super(UpgradeOps, self).__init__(ops=ops)
         self.upgrade_token = upgrade_token
 
-    def reverse_into(self, downgrade_ops):
-        downgrade_ops.ops[:] = list(
+    def reverse_into(self, downgrade_ops: "DowngradeOps") -> "DowngradeOps":
+        downgrade_ops.ops[:] = list(  # type:ignore[index]
             reversed([op.reverse() for op in self.ops])
         )
         return downgrade_ops
 
-    def reverse(self):
+    def reverse(self) -> "DowngradeOps":
         return self.reverse_into(DowngradeOps(ops=[]))
 
 
@@ -2209,7 +2497,11 @@ class DowngradeOps(OpContainer):
 
     """
 
-    def __init__(self, ops=(), downgrade_token="downgrades"):
+    def __init__(
+        self,
+        ops: Sequence[MigrateOperation] = (),
+        downgrade_token: str = "downgrades",
+    ) -> None:
         super(DowngradeOps, self).__init__(ops=ops)
         self.downgrade_token = downgrade_token
 
@@ -2243,19 +2535,21 @@ class MigrationScript(MigrateOperation):
 
     """
 
+    _needs_render: Optional[bool]
+
     def __init__(
         self,
-        rev_id,
-        upgrade_ops,
-        downgrade_ops,
-        message=None,
-        imports=set(),
-        head=None,
-        splice=None,
-        branch_label=None,
-        version_path=None,
-        depends_on=None,
-    ):
+        rev_id: Optional[str],
+        upgrade_ops: "UpgradeOps",
+        downgrade_ops: "DowngradeOps",
+        message: Optional[str] = None,
+        imports: Set[str] = set(),
+        head: Optional[str] = None,
+        splice: Optional[bool] = None,
+        branch_label: Optional[str] = None,
+        version_path: Optional[str] = None,
+        depends_on: Optional[Union[str, Sequence[str]]] = None,
+    ) -> None:
         self.rev_id = rev_id
         self.message = message
         self.imports = imports
@@ -2318,7 +2612,7 @@ class MigrationScript(MigrateOperation):
             assert isinstance(elem, DowngradeOps)
 
     @property
-    def upgrade_ops_list(self):
+    def upgrade_ops_list(self) -> List["UpgradeOps"]:
         """A list of :class:`.UpgradeOps` instances.
 
         This is used in place of the :attr:`.MigrationScript.upgrade_ops`
@@ -2329,7 +2623,7 @@ class MigrationScript(MigrateOperation):
         return self._upgrade_ops
 
     @property
-    def downgrade_ops_list(self):
+    def downgrade_ops_list(self) -> List["DowngradeOps"]:
         """A list of :class:`.DowngradeOps` instances.
 
         This is used in place of the :attr:`.MigrationScript.downgrade_ops`
index adbffdc013707b716a610549d028041dc4af8b35..0d40dc7853d539ad3c06888338a4fdccfc6ae2ae 100644 (file)
@@ -1,3 +1,12 @@
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
 from sqlalchemy import schema as sa_schema
 from sqlalchemy.sql.schema import Column
 from sqlalchemy.sql.schema import Constraint
@@ -9,34 +18,59 @@ from .. import util
 from ..util import sqla_compat
 from ..util.compat import string_types
 
+if TYPE_CHECKING:
+    from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.elements import TextClause
+    from sqlalchemy.sql.schema import CheckConstraint
+    from sqlalchemy.sql.schema import ForeignKey
+    from sqlalchemy.sql.schema import ForeignKeyConstraint
+    from sqlalchemy.sql.schema import MetaData
+    from sqlalchemy.sql.schema import PrimaryKeyConstraint
+    from sqlalchemy.sql.schema import Table
+    from sqlalchemy.sql.schema import UniqueConstraint
+    from sqlalchemy.sql.type_api import TypeEngine
+
+    from ..runtime.migration import MigrationContext
+
 
 class SchemaObjects:
-    def __init__(self, migration_context=None):
+    def __init__(
+        self, migration_context: Optional["MigrationContext"] = None
+    ) -> None:
         self.migration_context = migration_context
 
-    def primary_key_constraint(self, name, table_name, cols, schema=None):
+    def primary_key_constraint(
+        self,
+        name: Optional[str],
+        table_name: str,
+        cols: Sequence[str],
+        schema: Optional[str] = None,
+        **dialect_kw
+    ) -> "PrimaryKeyConstraint":
         m = self.metadata()
         columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
         t = sa_schema.Table(table_name, m, *columns, schema=schema)
-        p = sa_schema.PrimaryKeyConstraint(*[t.c[n] for n in cols], name=name)
+        p = sa_schema.PrimaryKeyConstraint(
+            *[t.c[n] for n in cols], name=name, **dialect_kw
+        )
         return p
 
     def foreign_key_constraint(
         self,
-        name,
-        source,
-        referent,
-        local_cols,
-        remote_cols,
-        onupdate=None,
-        ondelete=None,
-        deferrable=None,
-        source_schema=None,
-        referent_schema=None,
-        initially=None,
-        match=None,
+        name: Optional[str],
+        source: str,
+        referent: str,
+        local_cols: List[str],
+        remote_cols: List[str],
+        onupdate: Optional[str] = None,
+        ondelete: Optional[str] = None,
+        deferrable: Optional[bool] = None,
+        source_schema: Optional[str] = None,
+        referent_schema: Optional[str] = None,
+        initially: Optional[str] = None,
+        match: Optional[str] = None,
         **dialect_kw
-    ):
+    ) -> "ForeignKeyConstraint":
         m = self.metadata()
         if source == referent and source_schema == referent_schema:
             t1_cols = local_cols + remote_cols
@@ -78,7 +112,14 @@ class SchemaObjects:
 
         return f
 
-    def unique_constraint(self, name, source, local_cols, schema=None, **kw):
+    def unique_constraint(
+        self,
+        name: Optional[str],
+        source: str,
+        local_cols: Sequence[str],
+        schema: Optional[str] = None,
+        **kw
+    ) -> "UniqueConstraint":
         t = sa_schema.Table(
             source,
             self.metadata(),
@@ -92,7 +133,14 @@ class SchemaObjects:
         t.append_constraint(uq)
         return uq
 
-    def check_constraint(self, name, source, condition, schema=None, **kw):
+    def check_constraint(
+        self,
+        name: Optional[str],
+        source: str,
+        condition: Union["TextClause", "ColumnElement[Any]"],
+        schema: Optional[str] = None,
+        **kw
+    ) -> Union["CheckConstraint"]:
         t = sa_schema.Table(
             source,
             self.metadata(),
@@ -103,9 +151,16 @@ class SchemaObjects:
         t.append_constraint(ck)
         return ck
 
-    def generic_constraint(self, name, table_name, type_, schema=None, **kw):
+    def generic_constraint(
+        self,
+        name: Optional[str],
+        table_name: str,
+        type_: Optional[str],
+        schema: Optional[str] = None,
+        **kw
+    ) -> Any:
         t = self.table(table_name, schema=schema)
-        types = {
+        types: Dict[Optional[str], Any] = {
             "foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
                 [], [], name=name
             ),
@@ -126,7 +181,7 @@ class SchemaObjects:
             t.append_constraint(const)
             return const
 
-    def metadata(self):
+    def metadata(self) -> "MetaData":
         kw = {}
         if (
             self.migration_context is not None
@@ -137,7 +192,7 @@ class SchemaObjects:
                 kw["naming_convention"] = mt.naming_convention
         return sa_schema.MetaData(**kw)
 
-    def table(self, name, *columns, **kw):
+    def table(self, name: str, *columns, **kw) -> "Table":
         m = self.metadata()
 
         cols = [
@@ -173,10 +228,17 @@ class SchemaObjects:
             self._ensure_table_for_fk(m, f)
         return t
 
-    def column(self, name, type_, **kw):
+    def column(self, name: str, type_: "TypeEngine", **kw) -> "Column":
         return sa_schema.Column(name, type_, **kw)
 
-    def index(self, name, tablename, columns, schema=None, **kw):
+    def index(
+        self,
+        name: str,
+        tablename: Optional[str],
+        columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
+        schema: Optional[str] = None,
+        **kw
+    ) -> "Index":
         t = sa_schema.Table(
             tablename or "no_table",
             self.metadata(),
@@ -190,23 +252,27 @@ class SchemaObjects:
         )
         return idx
 
-    def _parse_table_key(self, table_key):
+    def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]:
         if "." in table_key:
             tokens = table_key.split(".")
-            sname = ".".join(tokens[0:-1])
+            sname: Optional[str] = ".".join(tokens[0:-1])
             tname = tokens[-1]
         else:
             tname = table_key
             sname = None
         return (sname, tname)
 
-    def _ensure_table_for_fk(self, metadata, fk):
+    def _ensure_table_for_fk(
+        self, metadata: "MetaData", fk: "ForeignKey"
+    ) -> None:
         """create a placeholder Table object for the referent of a
         ForeignKey.
 
         """
-        if isinstance(fk._colspec, string_types):
-            table_key, cname = fk._colspec.rsplit(".", 1)
+        if isinstance(fk._colspec, string_types):  # type:ignore[attr-defined]
+            table_key, cname = fk._colspec.rsplit(  # type:ignore[attr-defined]
+                ".", 1
+            )
             sname, tname = self._parse_table_key(table_key)
             if table_key not in metadata.tables:
                 rel_t = sa_schema.Table(tname, metadata, schema=sname)
index 10a41e48c1016d3850997e25598d65ce376b3c71..f97983e66a523e862fb512253d6b490f1a2b85b6 100644 (file)
@@ -1,12 +1,19 @@
+from typing import TYPE_CHECKING
+
 from sqlalchemy import schema as sa_schema
 
 from . import ops
 from .base import Operations
 from ..util.sqla_compat import _copy
 
+if TYPE_CHECKING:
+    from sqlalchemy.sql.schema import Table
+
 
 @Operations.implementation_for(ops.AlterColumnOp)
-def alter_column(operations, operation):
+def alter_column(
+    operations: "Operations", operation: "ops.AlterColumnOp"
+) -> None:
 
     compiler = operations.impl.dialect.statement_compiler(
         operations.impl.dialect, None
@@ -68,14 +75,16 @@ def alter_column(operations, operation):
 
 
 @Operations.implementation_for(ops.DropTableOp)
-def drop_table(operations, operation):
+def drop_table(operations: "Operations", operation: "ops.DropTableOp") -> None:
     operations.impl.drop_table(
         operation.to_table(operations.migration_context)
     )
 
 
 @Operations.implementation_for(ops.DropColumnOp)
-def drop_column(operations, operation):
+def drop_column(
+    operations: "Operations", operation: "ops.DropColumnOp"
+) -> None:
     column = operation.to_column(operations.migration_context)
     operations.impl.drop_column(
         operation.table_name, column, schema=operation.schema, **operation.kw
@@ -83,46 +92,56 @@ def drop_column(operations, operation):
 
 
 @Operations.implementation_for(ops.CreateIndexOp)
-def create_index(operations, operation):
+def create_index(
+    operations: "Operations", operation: "ops.CreateIndexOp"
+) -> None:
     idx = operation.to_index(operations.migration_context)
     operations.impl.create_index(idx)
 
 
 @Operations.implementation_for(ops.DropIndexOp)
-def drop_index(operations, operation):
+def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None:
     operations.impl.drop_index(
         operation.to_index(operations.migration_context)
     )
 
 
 @Operations.implementation_for(ops.CreateTableOp)
-def create_table(operations, operation):
+def create_table(
+    operations: "Operations", operation: "ops.CreateTableOp"
+) -> "Table":
     table = operation.to_table(operations.migration_context)
     operations.impl.create_table(table)
     return table
 
 
 @Operations.implementation_for(ops.RenameTableOp)
-def rename_table(operations, operation):
+def rename_table(
+    operations: "Operations", operation: "ops.RenameTableOp"
+) -> None:
     operations.impl.rename_table(
         operation.table_name, operation.new_table_name, schema=operation.schema
     )
 
 
 @Operations.implementation_for(ops.CreateTableCommentOp)
-def create_table_comment(operations, operation):
+def create_table_comment(
+    operations: "Operations", operation: "ops.CreateTableCommentOp"
+) -> None:
     table = operation.to_table(operations.migration_context)
     operations.impl.create_table_comment(table)
 
 
 @Operations.implementation_for(ops.DropTableCommentOp)
-def drop_table_comment(operations, operation):
+def drop_table_comment(
+    operations: "Operations", operation: "ops.DropTableCommentOp"
+) -> None:
     table = operation.to_table(operations.migration_context)
     operations.impl.drop_table_comment(table)
 
 
 @Operations.implementation_for(ops.AddColumnOp)
-def add_column(operations, operation):
+def add_column(operations: "Operations", operation: "ops.AddColumnOp") -> None:
     table_name = operation.table_name
     column = operation.column
     schema = operation.schema
@@ -150,14 +169,18 @@ def add_column(operations, operation):
 
 
 @Operations.implementation_for(ops.AddConstraintOp)
-def create_constraint(operations, operation):
+def create_constraint(
+    operations: "Operations", operation: "ops.AddConstraintOp"
+) -> None:
     operations.impl.add_constraint(
         operation.to_constraint(operations.migration_context)
     )
 
 
 @Operations.implementation_for(ops.DropConstraintOp)
-def drop_constraint(operations, operation):
+def drop_constraint(
+    operations: "Operations", operation: "ops.DropConstraintOp"
+) -> None:
     operations.impl.drop_constraint(
         operations.schema_obj.generic_constraint(
             operation.constraint_name,
@@ -169,14 +192,18 @@ def drop_constraint(operations, operation):
 
 
 @Operations.implementation_for(ops.BulkInsertOp)
-def bulk_insert(operations, operation):
+def bulk_insert(
+    operations: "Operations", operation: "ops.BulkInsertOp"
+) -> None:
     operations.impl.bulk_insert(
         operation.table, operation.rows, multiinsert=operation.multiinsert
     )
 
 
 @Operations.implementation_for(ops.ExecuteSQLOp)
-def execute_sql(operations, operation):
+def execute_sql(
+    operations: "Operations", operation: "ops.ExecuteSQLOp"
+) -> None:
     operations.migration_context.impl.execute(
         operation.sqltext, execution_options=operation.execution_options
     )
index e4f4d42bb8ce26b3ebf81b567126f2e8105e7dd4..f3473de40c273626745f250c067220f8e543ec51 100644 (file)
@@ -1,7 +1,30 @@
+from typing import Callable
+from typing import ContextManager
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import TextIO
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
 from .migration import MigrationContext
 from .. import util
 from ..operations import Operations
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from sqlalchemy.engine.base import Connection
+    from sqlalchemy.sql.schema import MetaData
+
+    from .migration import _ProxyTransaction
+    from ..config import Config
+    from ..script.base import ScriptDirectory
+
+_RevNumber = Optional[Union[str, Tuple[str, ...]]]
+
 
 class EnvironmentContext(util.ModuleClsProxy):
 
@@ -66,21 +89,23 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     """
 
-    _migration_context = None
+    _migration_context: Optional["MigrationContext"] = None
 
-    config = None
+    config: "Config" = None  # type:ignore[assignment]
     """An instance of :class:`.Config` representing the
     configuration file contents as well as other variables
     set programmatically within it."""
 
-    script = None
+    script: "ScriptDirectory" = None  # type:ignore[assignment]
     """An instance of :class:`.ScriptDirectory` which provides
     programmatic access to version files within the ``versions/``
     directory.
 
     """
 
-    def __init__(self, config, script, **kw):
+    def __init__(
+        self, config: "Config", script: "ScriptDirectory", **kw
+    ) -> None:
         r"""Construct a new :class:`.EnvironmentContext`.
 
         :param config: a :class:`.Config` instance.
@@ -94,7 +119,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         self.script = script
         self.context_opts = kw
 
-    def __enter__(self):
+    def __enter__(self) -> "EnvironmentContext":
         """Establish a context which provides a
         :class:`.EnvironmentContext` object to
         env.py scripts.
@@ -106,10 +131,10 @@ class EnvironmentContext(util.ModuleClsProxy):
         self._install_proxy()
         return self
 
-    def __exit__(self, *arg, **kw):
+    def __exit__(self, *arg, **kw) -> None:
         self._remove_proxy()
 
-    def is_offline_mode(self):
+    def is_offline_mode(self) -> bool:
         """Return True if the current migrations environment
         is running in "offline mode".
 
@@ -136,10 +161,10 @@ class EnvironmentContext(util.ModuleClsProxy):
         """
         return self.get_context().impl.transactional_ddl
 
-    def requires_connection(self):
+    def requires_connection(self) -> bool:
         return not self.is_offline_mode()
 
-    def get_head_revision(self):
+    def get_head_revision(self) -> _RevNumber:
         """Return the hex identifier of the 'head' script revision.
 
         If the script directory has multiple heads, this
@@ -154,7 +179,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         """
         return self.script.as_revision_number("head")
 
-    def get_head_revisions(self):
+    def get_head_revisions(self) -> _RevNumber:
         """Return the hex identifier of the 'heads' script revision(s).
 
         This returns a tuple containing the version number of all
@@ -166,7 +191,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         """
         return self.script.as_revision_number("heads")
 
-    def get_starting_revision_argument(self):
+    def get_starting_revision_argument(self) -> _RevNumber:
         """Return the 'starting revision' argument,
         if the revision was passed using ``start:end``.
 
@@ -195,7 +220,7 @@ class EnvironmentContext(util.ModuleClsProxy):
                 "No starting revision argument is available."
             )
 
-    def get_revision_argument(self):
+    def get_revision_argument(self) -> _RevNumber:
         """Get the 'destination' revision argument.
 
         This is typically the argument passed to the
@@ -213,7 +238,7 @@ class EnvironmentContext(util.ModuleClsProxy):
             self.context_opts["destination_rev"]
         )
 
-    def get_tag_argument(self):
+    def get_tag_argument(self) -> Optional[str]:
         """Return the value passed for the ``--tag`` argument, if any.
 
         The ``--tag`` argument is not used directly by Alembic,
@@ -233,7 +258,19 @@ class EnvironmentContext(util.ModuleClsProxy):
         """
         return self.context_opts.get("tag", None)
 
-    def get_x_argument(self, as_dictionary=False):
+    @overload
+    def get_x_argument(  # type:ignore[misc]
+        self, as_dictionary: "Literal[False]" = ...
+    ) -> List[str]:
+        ...
+
+    @overload
+    def get_x_argument(  # type:ignore[misc]
+        self, as_dictionary: "Literal[True]" = ...
+    ) -> Dict[str, str]:
+        ...
+
+    def get_x_argument(self, as_dictionary: bool = False):
         """Return the value(s) passed for the ``-x`` argument, if any.
 
         The ``-x`` argument is an open ended flag that allows any user-defined
@@ -282,34 +319,34 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     def configure(
         self,
-        connection=None,
-        url=None,
-        dialect_name=None,
-        dialect_opts=None,
-        transactional_ddl=None,
-        transaction_per_migration=False,
-        output_buffer=None,
-        starting_rev=None,
-        tag=None,
-        template_args=None,
-        render_as_batch=False,
-        target_metadata=None,
-        include_name=None,
-        include_object=None,
-        include_schemas=False,
-        process_revision_directives=None,
-        compare_type=False,
-        compare_server_default=False,
-        render_item=None,
-        literal_binds=False,
-        upgrade_token="upgrades",
-        downgrade_token="downgrades",
-        alembic_module_prefix="op.",
-        sqlalchemy_module_prefix="sa.",
-        user_module_prefix=None,
-        on_version_apply=None,
+        connection: Optional["Connection"] = None,
+        url: Optional[str] = None,
+        dialect_name: Optional[str] = None,
+        dialect_opts: Optional[dict] = None,
+        transactional_ddl: Optional[bool] = None,
+        transaction_per_migration: bool = False,
+        output_buffer: Optional[TextIO] = None,
+        starting_rev: Optional[str] = None,
+        tag: Optional[str] = None,
+        template_args: Optional[dict] = None,
+        render_as_batch: bool = False,
+        target_metadata: Optional["MetaData"] = None,
+        include_name: Optional[Callable] = None,
+        include_object: Optional[Callable] = None,
+        include_schemas: bool = False,
+        process_revision_directives: Optional[Callable] = None,
+        compare_type: bool = False,
+        compare_server_default: bool = False,
+        render_item: Optional[Callable] = None,
+        literal_binds: bool = False,
+        upgrade_token: str = "upgrades",
+        downgrade_token: str = "downgrades",
+        alembic_module_prefix: str = "op.",
+        sqlalchemy_module_prefix: str = "sa.",
+        user_module_prefix: Optional[str] = None,
+        on_version_apply: Optional[Callable] = None,
         **kw
-    ):
+    ) -> None:
         """Configure a :class:`.MigrationContext` within this
         :class:`.EnvironmentContext` which will provide database
         connectivity and other configuration to a series of
@@ -789,7 +826,7 @@ class EnvironmentContext(util.ModuleClsProxy):
             opts=opts,
         )
 
-    def run_migrations(self, **kw):
+    def run_migrations(self, **kw) -> None:
         """Run migrations as determined by the current command line
         configuration
         as well as versioning information present (or not) in the current
@@ -809,6 +846,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         first been made available via :meth:`.configure`.
 
         """
+        assert self._migration_context is not None
         with Operations.context(self._migration_context):
             self.get_context().run_migrations(**kw)
 
@@ -837,7 +875,9 @@ class EnvironmentContext(util.ModuleClsProxy):
         """
         self.get_context().impl.static_output(text)
 
-    def begin_transaction(self):
+    def begin_transaction(
+        self,
+    ) -> Union["_ProxyTransaction", ContextManager]:
         """Return a context manager that will
         enclose an operation within a "transaction",
         as defined by the environment's offline
@@ -883,7 +923,7 @@ class EnvironmentContext(util.ModuleClsProxy):
 
         return self.get_context().begin_transaction()
 
-    def get_context(self):
+    def get_context(self) -> "MigrationContext":
         """Return the current :class:`.MigrationContext` object.
 
         If :meth:`.EnvironmentContext.configure` has not been
index 5ed2136da366e1fe26b29a97f74a0c3659649dba..c64e91f85eb1a38b2fb2ac0acde44032bbde7eda 100644 (file)
@@ -1,6 +1,18 @@
 from contextlib import contextmanager
 import logging
 import sys
+from typing import Any
+from typing import cast
+from typing import Collection
+from typing import ContextManager
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
 
 from sqlalchemy import Column
 from sqlalchemy import literal_column
@@ -17,29 +29,46 @@ from .. import util
 from ..util import sqla_compat
 from ..util.compat import EncodedIO
 
+if TYPE_CHECKING:
+    from sqlalchemy.engine import Dialect
+    from sqlalchemy.engine.base import Connection
+    from sqlalchemy.engine.base import Transaction
+    from sqlalchemy.engine.mock import MockConnection
+
+    from .environment import EnvironmentContext
+    from ..config import Config
+    from ..script.base import Script
+    from ..script.base import ScriptDirectory
+    from ..script.revision import Revision
+    from ..script.revision import RevisionMap
+
 log = logging.getLogger(__name__)
 
 
 class _ProxyTransaction:
-    def __init__(self, migration_context):
+    def __init__(self, migration_context: "MigrationContext") -> None:
         self.migration_context = migration_context
 
     @property
-    def _proxied_transaction(self):
+    def _proxied_transaction(self) -> Optional["Transaction"]:
         return self.migration_context._transaction
 
-    def rollback(self):
-        self._proxied_transaction.rollback()
+    def rollback(self) -> None:
+        t = self._proxied_transaction
+        assert t is not None
+        t.rollback()
         self.migration_context._transaction = None
 
-    def commit(self):
-        self._proxied_transaction.commit()
+    def commit(self) -> None:
+        t = self._proxied_transaction
+        assert t is not None
+        t.commit()
         self.migration_context._transaction = None
 
-    def __enter__(self):
+    def __enter__(self) -> "_ProxyTransaction":
         return self
 
-    def __exit__(self, type_, value, traceback):
+    def __exit__(self, type_: None, value: None, traceback: None) -> None:
         if self._proxied_transaction is not None:
             self._proxied_transaction.__exit__(type_, value, traceback)
             self.migration_context._transaction = None
@@ -92,21 +121,29 @@ class MigrationContext:
 
     """
 
-    def __init__(self, dialect, connection, opts, environment_context=None):
+    def __init__(
+        self,
+        dialect: "Dialect",
+        connection: Optional["Connection"],
+        opts: Dict[str, Any],
+        environment_context: Optional["EnvironmentContext"] = None,
+    ) -> None:
         self.environment_context = environment_context
         self.opts = opts
         self.dialect = dialect
-        self.script = opts.get("script")
-        as_sql = opts.get("as_sql", False)
+        self.script: Optional["ScriptDirectory"] = opts.get("script")
+        as_sql: bool = opts.get("as_sql", False)
         transactional_ddl = opts.get("transactional_ddl")
         self._transaction_per_migration = opts.get(
             "transaction_per_migration", False
         )
         self.on_version_apply_callbacks = opts.get("on_version_apply", ())
-        self._transaction = None
+        self._transaction: Optional["Transaction"] = None
 
         if as_sql:
-            self.connection = self._stdout_connection(connection)
+            self.connection = cast(
+                Optional["Connection"], self._stdout_connection(connection)
+            )
             assert self.connection is not None
             self._in_external_transaction = False
         else:
@@ -122,7 +159,8 @@ class MigrationContext:
 
         if "output_encoding" in opts:
             self.output_buffer = EncodedIO(
-                opts.get("output_buffer") or sys.stdout,
+                opts.get("output_buffer")
+                or sys.stdout,  # type:ignore[arg-type]
                 opts["output_encoding"],
             )
         else:
@@ -151,7 +189,7 @@ class MigrationContext:
                 )
             )
 
-        self._start_from_rev = opts.get("starting_rev")
+        self._start_from_rev: Optional[str] = opts.get("starting_rev")
         self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
             dialect,
             self.connection,
@@ -173,14 +211,14 @@ class MigrationContext:
     @classmethod
     def configure(
         cls,
-        connection=None,
-        url=None,
-        dialect_name=None,
-        dialect=None,
-        environment_context=None,
-        dialect_opts=None,
-        opts=None,
-    ):
+        connection: Optional["Connection"] = None,
+        url: Optional[str] = None,
+        dialect_name: Optional[str] = None,
+        dialect: Optional["Dialect"] = None,
+        environment_context: Optional["EnvironmentContext"] = None,
+        dialect_opts: Optional[Dict[str, str]] = None,
+        opts: Optional[Any] = None,
+    ) -> "MigrationContext":
         """Create a new :class:`.MigrationContext`.
 
         This is a factory method usually called
@@ -216,18 +254,18 @@ class MigrationContext:
 
             dialect = connection.dialect
         elif url:
-            url = sqla_url.make_url(url)
-            dialect = url.get_dialect()(**dialect_opts)
+            url_obj = sqla_url.make_url(url)
+            dialect = url_obj.get_dialect()(**dialect_opts)
         elif dialect_name:
-            url = sqla_url.make_url("%s://" % dialect_name)
-            dialect = url.get_dialect()(**dialect_opts)
+            url_obj = sqla_url.make_url("%s://" % dialect_name)
+            dialect = url_obj.get_dialect()(**dialect_opts)
         elif not dialect:
             raise Exception("Connection, url, or dialect_name is required.")
-
+        assert dialect is not None
         return MigrationContext(dialect, connection, opts, environment_context)
 
     @contextmanager
-    def autocommit_block(self):
+    def autocommit_block(self) -> Iterator[None]:
         """Enter an "autocommit" block, for databases that support AUTOCOMMIT
         isolation levels.
 
@@ -285,6 +323,7 @@ class MigrationContext:
             self._transaction = None
 
         if not self.as_sql:
+            assert self.connection is not None
             current_level = self.connection.get_isolation_level()
             base_connection = self.connection
 
@@ -300,6 +339,7 @@ class MigrationContext:
             yield
         finally:
             if not self.as_sql:
+                assert self.connection is not None
                 self.connection.execution_options(
                     isolation_level=current_level
                 )
@@ -309,9 +349,12 @@ class MigrationContext:
                 self.impl.emit_begin()
 
             elif _in_connection_transaction:
+                assert self.connection is not None
                 self._transaction = self.connection.begin()
 
-    def begin_transaction(self, _per_migration=False):
+    def begin_transaction(
+        self, _per_migration: bool = False
+    ) -> Union["_ProxyTransaction", ContextManager]:
         """Begin a logical transaction for migration operations.
 
         This method is used within an ``env.py`` script to demarcate where
@@ -390,6 +433,7 @@ class MigrationContext:
                 if in_transaction:
                     return do_nothing()
                 else:
+                    assert self.connection is not None
                     self._transaction = (
                         sqla_compat._safe_begin_connection_transaction(
                             self.connection
@@ -406,12 +450,13 @@ class MigrationContext:
 
             return begin_commit()
         else:
+            assert self.connection is not None
             self._transaction = sqla_compat._safe_begin_connection_transaction(
                 self.connection
             )
             return _ProxyTransaction(self)
 
-    def get_current_revision(self):
+    def get_current_revision(self) -> Optional[str]:
         """Return the current revision, usually that which is present
         in the ``alembic_version`` table in the database.
 
@@ -438,7 +483,7 @@ class MigrationContext:
         else:
             return heads[0]
 
-    def get_current_heads(self):
+    def get_current_heads(self) -> Tuple[str, ...]:
         """Return a tuple of the current 'head versions' that are represented
         in the target database.
 
@@ -457,7 +502,7 @@ class MigrationContext:
 
         """
         if self.as_sql:
-            start_from_rev = self._start_from_rev
+            start_from_rev: Any = self._start_from_rev
             if start_from_rev == "base":
                 start_from_rev = None
             elif start_from_rev is not None and self.script:
@@ -476,22 +521,27 @@ class MigrationContext:
                 )
             if not self._has_version_table():
                 return ()
+        assert self.connection is not None
         return tuple(
             row[0] for row in self.connection.execute(self._version.select())
         )
 
-    def _ensure_version_table(self, purge=False):
+    def _ensure_version_table(self, purge: bool = False) -> None:
         with sqla_compat._ensure_scope_for_ddl(self.connection):
             self._version.create(self.connection, checkfirst=True)
             if purge:
+                assert self.connection is not None
                 self.connection.execute(self._version.delete())
 
-    def _has_version_table(self):
+    def _has_version_table(self) -> bool:
+        assert self.connection is not None
         return sqla_compat._connectable_has_table(
             self.connection, self.version_table, self.version_table_schema
         )
 
-    def stamp(self, script_directory, revision):
+    def stamp(
+        self, script_directory: "ScriptDirectory", revision: str
+    ) -> None:
         """Stamp the version table with a specific revision.
 
         This method calculates those branches to which the given revision
@@ -507,7 +557,7 @@ class MigrationContext:
         for step in script_directory._stamp_revs(revision, heads):
             head_maintainer.update_to_step(step)
 
-    def run_migrations(self, **kw):
+    def run_migrations(self, **kw) -> None:
         r"""Run the migration scripts established for this
         :class:`.MigrationContext`, if any.
 
@@ -530,6 +580,7 @@ class MigrationContext:
         """
         self.impl.start_migrations()
 
+        heads: Tuple[str, ...]
         if self.purge:
             if self.as_sql:
                 raise util.CommandError("Can't use --purge with --sql mode")
@@ -545,6 +596,7 @@ class MigrationContext:
 
         head_maintainer = HeadMaintainer(self, heads)
 
+        assert self._migrations_fn is not None
         for step in self._migrations_fn(heads, self):
             with self.begin_transaction(_per_migration=True):
 
@@ -576,15 +628,15 @@ class MigrationContext:
         if self.as_sql and not head_maintainer.heads:
             self._version.drop(self.connection)
 
-    def _in_connection_transaction(self):
+    def _in_connection_transaction(self) -> bool:
         try:
-            meth = self.connection.in_transaction
+            meth = self.connection.in_transaction  # type:ignore[union-attr]
         except AttributeError:
             return False
         else:
             return meth()
 
-    def execute(self, sql, execution_options=None):
+    def execute(self, sql: str, execution_options: None = None) -> None:
         """Execute a SQL construct or string statement.
 
         The underlying execution mechanics are used, that is
@@ -595,14 +647,16 @@ class MigrationContext:
         """
         self.impl._exec(sql, execution_options)
 
-    def _stdout_connection(self, connection):
+    def _stdout_connection(
+        self, connection: Optional["Connection"]
+    ) -> "MockConnection":
         def dump(construct, *multiparams, **params):
             self.impl._exec(construct)
 
         return MockEngineStrategy.MockConnection(self.dialect, dump)
 
     @property
-    def bind(self):
+    def bind(self) -> Optional["Connection"]:
         """Return the current "bind".
 
         In online mode, this is an instance of
@@ -623,7 +677,7 @@ class MigrationContext:
         return self.connection
 
     @property
-    def config(self):
+    def config(self) -> Optional["Config"]:
         """Return the :class:`.Config` used by the current environment,
         if any."""
 
@@ -632,7 +686,9 @@ class MigrationContext:
         else:
             return None
 
-    def _compare_type(self, inspector_column, metadata_column):
+    def _compare_type(
+        self, inspector_column: "Column", metadata_column: "Column"
+    ) -> bool:
         if self._user_compare_type is False:
             return False
 
@@ -651,11 +707,11 @@ class MigrationContext:
 
     def _compare_server_default(
         self,
-        inspector_column,
-        metadata_column,
-        rendered_metadata_default,
-        rendered_column_default,
-    ):
+        inspector_column: "Column",
+        metadata_column: "Column",
+        rendered_metadata_default: Optional[str],
+        rendered_column_default: Optional[str],
+    ) -> bool:
 
         if self._user_compare_server_default is False:
             return False
@@ -681,11 +737,11 @@ class MigrationContext:
 
 
 class HeadMaintainer:
-    def __init__(self, context, heads):
+    def __init__(self, context: "MigrationContext", heads: Any) -> None:
         self.context = context
         self.heads = set(heads)
 
-    def _insert_version(self, version):
+    def _insert_version(self, version: str) -> None:
         assert version not in self.heads
         self.heads.add(version)
 
@@ -695,7 +751,7 @@ class HeadMaintainer:
             )
         )
 
-    def _delete_version(self, version):
+    def _delete_version(self, version: str) -> None:
         self.heads.remove(version)
 
         ret = self.context.impl._exec(
@@ -716,7 +772,7 @@ class HeadMaintainer:
                 % (version, self.context.version_table, ret.rowcount)
             )
 
-    def _update_version(self, from_, to_):
+    def _update_version(self, from_: str, to_: str) -> None:
         assert to_ not in self.heads
         self.heads.remove(from_)
         self.heads.add(to_)
@@ -741,7 +797,7 @@ class HeadMaintainer:
                 % (from_, to_, self.context.version_table, ret.rowcount)
             )
 
-    def update_to_step(self, step):
+    def update_to_step(self, step: Union["RevisionStep", "StampStep"]) -> None:
         if step.should_delete_branch(self.heads):
             vers = step.delete_version_num
             log.debug("branch delete %s", vers)
@@ -796,15 +852,15 @@ class MigrationInfo:
 
     """
 
-    is_upgrade = None
+    is_upgrade: bool = None  # type:ignore[assignment]
     """True/False: indicates whether this operation ascends or descends the
     version tree."""
 
-    is_stamp = None
+    is_stamp: bool = None  # type:ignore[assignment]
     """True/False: indicates whether this operation is a stamp (i.e. whether
     it results in any actual database operations)."""
 
-    up_revision_id = None
+    up_revision_id: Optional[str] = None
     """Version string corresponding to :attr:`.Revision.revision`.
 
     In the case of a stamp operation, it is advised to use the
@@ -818,7 +874,7 @@ class MigrationInfo:
 
     """
 
-    up_revision_ids = None
+    up_revision_ids: Tuple[str, ...] = None  # type:ignore[assignment]
     """Tuple of version strings corresponding to :attr:`.Revision.revision`.
 
     In the majority of cases, this tuple will be a single value, synonymous
@@ -829,7 +885,7 @@ class MigrationInfo:
 
     """
 
-    down_revision_ids = None
+    down_revision_ids: Tuple[str, ...] = None  # type:ignore[assignment]
     """Tuple of strings representing the base revisions of this migration step.
 
     If empty, this represents a root revision; otherwise, the first item
@@ -837,12 +893,17 @@ class MigrationInfo:
     from dependencies.
     """
 
-    revision_map = None
+    revision_map: "RevisionMap" = None  # type:ignore[assignment]
     """The revision map inside of which this operation occurs."""
 
     def __init__(
-        self, revision_map, is_upgrade, is_stamp, up_revisions, down_revisions
-    ):
+        self,
+        revision_map: "RevisionMap",
+        is_upgrade: bool,
+        is_stamp: bool,
+        up_revisions: Union[str, Tuple[str, ...]],
+        down_revisions: Union[str, Tuple[str, ...]],
+    ) -> None:
         self.revision_map = revision_map
         self.is_upgrade = is_upgrade
         self.is_stamp = is_stamp
@@ -857,7 +918,7 @@ class MigrationInfo:
         self.down_revision_ids = util.to_tuple(down_revisions, default=())
 
     @property
-    def is_migration(self):
+    def is_migration(self) -> bool:
         """True/False: indicates whether this operation is a migration.
 
         At present this is true if and only if the migration is not a stamp.
@@ -867,21 +928,21 @@ class MigrationInfo:
         return not self.is_stamp
 
     @property
-    def source_revision_ids(self):
+    def source_revision_ids(self) -> Tuple[str, ...]:
         """Active revisions before this migration step is applied."""
         return (
             self.down_revision_ids if self.is_upgrade else self.up_revision_ids
         )
 
     @property
-    def destination_revision_ids(self):
+    def destination_revision_ids(self) -> Tuple[str, ...]:
         """Active revisions after this migration step is applied."""
         return (
             self.up_revision_ids if self.is_upgrade else self.down_revision_ids
         )
 
     @property
-    def up_revision(self):
+    def up_revision(self) -> "Revision":
         """Get :attr:`~.MigrationInfo.up_revision_id` as
         a :class:`.Revision`.
 
@@ -889,49 +950,59 @@ class MigrationInfo:
         return self.revision_map.get_revision(self.up_revision_id)
 
     @property
-    def up_revisions(self):
+    def up_revisions(self) -> Tuple["Revision", ...]:
         """Get :attr:`~.MigrationInfo.up_revision_ids` as a
         :class:`.Revision`."""
         return self.revision_map.get_revisions(self.up_revision_ids)
 
     @property
-    def down_revisions(self):
+    def down_revisions(self) -> Tuple["Revision", ...]:
         """Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of
         :class:`Revisions <.Revision>`."""
         return self.revision_map.get_revisions(self.down_revision_ids)
 
     @property
-    def source_revisions(self):
+    def source_revisions(self) -> Tuple["Revision", ...]:
         """Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of
         :class:`Revisions <.Revision>`."""
         return self.revision_map.get_revisions(self.source_revision_ids)
 
     @property
-    def destination_revisions(self):
+    def destination_revisions(self) -> Tuple["Revision", ...]:
         """Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of
         :class:`Revisions <.Revision>`."""
         return self.revision_map.get_revisions(self.destination_revision_ids)
 
 
 class MigrationStep:
+
+    from_revisions_no_deps: Tuple[str, ...]
+    to_revisions_no_deps: Tuple[str, ...]
+    is_upgrade: bool
+    migration_fn: Any
+
     @property
-    def name(self):
+    def name(self) -> str:
         return self.migration_fn.__name__
 
     @classmethod
-    def upgrade_from_script(cls, revision_map, script):
+    def upgrade_from_script(
+        cls, revision_map: "RevisionMap", script: "Script"
+    ) -> "RevisionStep":
         return RevisionStep(revision_map, script, True)
 
     @classmethod
-    def downgrade_from_script(cls, revision_map, script):
+    def downgrade_from_script(
+        cls, revision_map: "RevisionMap", script: "Script"
+    ) -> "RevisionStep":
         return RevisionStep(revision_map, script, False)
 
     @property
-    def is_downgrade(self):
+    def is_downgrade(self) -> bool:
         return not self.is_upgrade
 
     @property
-    def short_log(self):
+    def short_log(self) -> str:
         return "%s %s -> %s" % (
             self.name,
             util.format_as_comma(self.from_revisions_no_deps),
@@ -951,14 +1022,20 @@ class MigrationStep:
 
 
 class RevisionStep(MigrationStep):
-    def __init__(self, revision_map, revision, is_upgrade):
+    def __init__(
+        self, revision_map: "RevisionMap", revision: "Script", is_upgrade: bool
+    ) -> None:
         self.revision_map = revision_map
         self.revision = revision
         self.is_upgrade = is_upgrade
         if is_upgrade:
-            self.migration_fn = revision.module.upgrade
+            self.migration_fn = (
+                revision.module.upgrade  # type:ignore[attr-defined]
+            )
         else:
-            self.migration_fn = revision.module.downgrade
+            self.migration_fn = (
+                revision.module.downgrade  # type:ignore[attr-defined]
+            )
 
     def __repr__(self):
         return "RevisionStep(%r, is_upgrade=%r)" % (
@@ -966,7 +1043,7 @@ class RevisionStep(MigrationStep):
             self.is_upgrade,
         )
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return (
             isinstance(other, RevisionStep)
             and other.revision == self.revision
@@ -978,38 +1055,42 @@ class RevisionStep(MigrationStep):
         return self.revision.doc
 
     @property
-    def from_revisions(self):
+    def from_revisions(self) -> Tuple[str, ...]:
         if self.is_upgrade:
             return self.revision._normalized_down_revisions
         else:
             return (self.revision.revision,)
 
     @property
-    def from_revisions_no_deps(self):
+    def from_revisions_no_deps(  # type:ignore[override]
+        self,
+    ) -> Tuple[str, ...]:
         if self.is_upgrade:
             return self.revision._versioned_down_revisions
         else:
             return (self.revision.revision,)
 
     @property
-    def to_revisions(self):
+    def to_revisions(self) -> Tuple[str, ...]:
         if self.is_upgrade:
             return (self.revision.revision,)
         else:
             return self.revision._normalized_down_revisions
 
     @property
-    def to_revisions_no_deps(self):
+    def to_revisions_no_deps(  # type:ignore[override]
+        self,
+    ) -> Tuple[str, ...]:
         if self.is_upgrade:
             return (self.revision.revision,)
         else:
             return self.revision._versioned_down_revisions
 
     @property
-    def _has_scalar_down_revision(self):
+    def _has_scalar_down_revision(self) -> bool:
         return len(self.revision._normalized_down_revisions) == 1
 
-    def should_delete_branch(self, heads):
+    def should_delete_branch(self, heads: Set[str]) -> bool:
         """A delete is when we are a. in a downgrade and b.
         we are going to the "base" or we are going to a version that
         is implied as a dependency on another version that is remaining.
@@ -1032,7 +1113,9 @@ class RevisionStep(MigrationStep):
             to_revisions = self._unmerge_to_revisions(heads)
             return not to_revisions
 
-    def merge_branch_idents(self, heads):
+    def merge_branch_idents(
+        self, heads: Set[str]
+    ) -> Tuple[List[str], str, str]:
         other_heads = set(heads).difference(self.from_revisions)
 
         if other_heads:
@@ -1055,7 +1138,7 @@ class RevisionStep(MigrationStep):
             self.to_revisions[0],
         )
 
-    def _unmerge_to_revisions(self, heads):
+    def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]:
         other_heads = set(heads).difference([self.revision.revision])
         if other_heads:
             ancestors = set(
@@ -1064,11 +1147,13 @@ class RevisionStep(MigrationStep):
                     self.revision_map.get_revisions(other_heads), check=False
                 )
             )
-            return list(set(self.to_revisions).difference(ancestors))
+            return tuple(set(self.to_revisions).difference(ancestors))
         else:
             return self.to_revisions
 
-    def unmerge_branch_idents(self, heads):
+    def unmerge_branch_idents(
+        self, heads: Collection[str]
+    ) -> Tuple[str, str, Tuple[str, ...]]:
         to_revisions = self._unmerge_to_revisions(heads)
 
         return (
@@ -1078,7 +1163,7 @@ class RevisionStep(MigrationStep):
             to_revisions[0:-1],
         )
 
-    def should_create_branch(self, heads):
+    def should_create_branch(self, heads: Set[str]) -> bool:
         if not self.is_upgrade:
             return False
 
@@ -1097,7 +1182,7 @@ class RevisionStep(MigrationStep):
             else:
                 return False
 
-    def should_merge_branches(self, heads):
+    def should_merge_branches(self, heads: Set[str]) -> bool:
         if not self.is_upgrade:
             return False
 
@@ -1108,7 +1193,7 @@ class RevisionStep(MigrationStep):
 
         return False
 
-    def should_unmerge_branches(self, heads):
+    def should_unmerge_branches(self, heads: Set[str]) -> bool:
         if not self.is_downgrade:
             return False
 
@@ -1119,7 +1204,7 @@ class RevisionStep(MigrationStep):
 
         return False
 
-    def update_version_num(self, heads):
+    def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
         if not self._has_scalar_down_revision:
             downrev = heads.intersection(
                 self.revision._normalized_down_revisions
@@ -1137,15 +1222,15 @@ class RevisionStep(MigrationStep):
             return self.revision.revision, down_revision
 
     @property
-    def delete_version_num(self):
+    def delete_version_num(self) -> str:
         return self.revision.revision
 
     @property
-    def insert_version_num(self):
+    def insert_version_num(self) -> str:
         return self.revision.revision
 
     @property
-    def info(self):
+    def info(self) -> "MigrationInfo":
         return MigrationInfo(
             revision_map=self.revision_map,
             up_revisions=self.revision.revision,
@@ -1156,9 +1241,16 @@ class RevisionStep(MigrationStep):
 
 
 class StampStep(MigrationStep):
-    def __init__(self, from_, to_, is_upgrade, branch_move, revision_map=None):
-        self.from_ = util.to_tuple(from_, default=())
-        self.to_ = util.to_tuple(to_, default=())
+    def __init__(
+        self,
+        from_: Optional[Union[str, Collection[str]]],
+        to_: Optional[Union[str, Collection[str]]],
+        is_upgrade: bool,
+        branch_move: bool,
+        revision_map: Optional["RevisionMap"] = None,
+    ) -> None:
+        self.from_: Tuple[str, ...] = util.to_tuple(from_, default=())
+        self.to_: Tuple[str, ...] = util.to_tuple(to_, default=())
         self.is_upgrade = is_upgrade
         self.branch_move = branch_move
         self.migration_fn = self.stamp_revision
@@ -1166,7 +1258,7 @@ class StampStep(MigrationStep):
 
     doc = None
 
-    def stamp_revision(self, **kw):
+    def stamp_revision(self, **kw) -> None:
         return None
 
     def __eq__(self, other):
@@ -1183,33 +1275,39 @@ class StampStep(MigrationStep):
         return self.from_
 
     @property
-    def to_revisions(self):
+    def to_revisions(self) -> Tuple[str, ...]:
         return self.to_
 
     @property
-    def from_revisions_no_deps(self):
+    def from_revisions_no_deps(  # type:ignore[override]
+        self,
+    ) -> Tuple[str, ...]:
         return self.from_
 
     @property
-    def to_revisions_no_deps(self):
+    def to_revisions_no_deps(  # type:ignore[override]
+        self,
+    ) -> Tuple[str, ...]:
         return self.to_
 
     @property
-    def delete_version_num(self):
+    def delete_version_num(self) -> str:
         assert len(self.from_) == 1
         return self.from_[0]
 
     @property
-    def insert_version_num(self):
+    def insert_version_num(self) -> str:
         assert len(self.to_) == 1
         return self.to_[0]
 
-    def update_version_num(self, heads):
+    def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
         assert len(self.from_) == 1
         assert len(self.to_) == 1
         return self.from_[0], self.to_[0]
 
-    def merge_branch_idents(self, heads):
+    def merge_branch_idents(
+        self, heads: Union[Set[str], List[str]]
+    ) -> Union[Tuple[List[Any], str, str], Tuple[List[str], str, str]]:
         return (
             # delete revs, update from rev, update to rev
             list(self.from_[0:-1]),
@@ -1217,7 +1315,9 @@ class StampStep(MigrationStep):
             self.to_[0],
         )
 
-    def unmerge_branch_idents(self, heads):
+    def unmerge_branch_idents(
+        self, heads: Set[str]
+    ) -> Tuple[str, str, List[str]]:
         return (
             # update from rev, update to rev, insert revs
             self.from_[0],
@@ -1225,32 +1325,33 @@ class StampStep(MigrationStep):
             list(self.to_[0:-1]),
         )
 
-    def should_delete_branch(self, heads):
+    def should_delete_branch(self, heads: Set[str]) -> bool:
         # TODO: we probably need to look for self.to_ inside of heads,
         # in a similar manner as should_create_branch, however we have
         # no tests for this yet (stamp downgrades w/ branches)
         return self.is_downgrade and self.branch_move
 
-    def should_create_branch(self, heads):
+    def should_create_branch(self, heads: Set[str]) -> Union[Set[str], bool]:
         return (
             self.is_upgrade
             and (self.branch_move or set(self.from_).difference(heads))
             and set(self.to_).difference(heads)
         )
 
-    def should_merge_branches(self, heads):
+    def should_merge_branches(self, heads: Set[str]) -> bool:
         return len(self.from_) > 1
 
-    def should_unmerge_branches(self, heads):
+    def should_unmerge_branches(self, heads: Set[str]) -> bool:
         return len(self.to_) > 1
 
     @property
-    def info(self):
+    def info(self) -> "MigrationInfo":
         up, down = (
             (self.to_, self.from_)
             if self.is_upgrade
             else (self.from_, self.to_)
         )
+        assert self.revision_map is not None
         return MigrationInfo(
             revision_map=self.revision_map,
             up_revisions=up,
index d0500c4e5e8676ce0a6b251d3aaf8747d3c506ee..ef0fd52a1f87f9f0ca718d6e9a20cc080e2ad78f 100644 (file)
@@ -4,16 +4,35 @@ import os
 import re
 import shutil
 import sys
+from types import ModuleType
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
 
 from . import revision
 from . import write_hooks
 from .. import util
 from ..runtime import migration
 
+if TYPE_CHECKING:
+    from ..config import Config
+    from ..runtime.migration import RevisionStep
+    from ..runtime.migration import StampStep
+
 try:
     from dateutil import tz
 except ImportError:
-    tz = None  # noqa
+    tz = None  # type: ignore[assignment]
+
+_RevIdType = Union[str, Sequence[str]]
 
 _sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
 _only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")
@@ -49,15 +68,15 @@ class ScriptDirectory:
 
     def __init__(
         self,
-        dir,  # noqa
-        file_template=_default_file_template,
-        truncate_slug_length=40,
-        version_locations=None,
-        sourceless=False,
-        output_encoding="utf-8",
-        timezone=None,
-        hook_config=None,
-    ):
+        dir: str,  # noqa
+        file_template: str = _default_file_template,
+        truncate_slug_length: Optional[int] = 40,
+        version_locations: Optional[List[str]] = None,
+        sourceless: bool = False,
+        output_encoding: str = "utf-8",
+        timezone: Optional[str] = None,
+        hook_config: Optional[Dict[str, str]] = None,
+    ) -> None:
         self.dir = dir
         self.file_template = file_template
         self.version_locations = version_locations
@@ -76,7 +95,7 @@ class ScriptDirectory:
             )
 
     @property
-    def versions(self):
+    def versions(self) -> str:
         loc = self._version_locations
         if len(loc) > 1:
             raise util.CommandError("Multiple version_locations present")
@@ -93,7 +112,7 @@ class ScriptDirectory:
         else:
             return (os.path.abspath(os.path.join(self.dir, "versions")),)
 
-    def _load_revisions(self):
+    def _load_revisions(self) -> Iterator["Script"]:
         if self.version_locations:
             paths = [
                 vers
@@ -120,7 +139,7 @@ class ScriptDirectory:
                 yield script
 
     @classmethod
-    def from_config(cls, config):
+    def from_config(cls, config: "Config") -> "ScriptDirectory":
         """Produce a new :class:`.ScriptDirectory` given a :class:`.Config`
         instance.
 
@@ -133,7 +152,9 @@ class ScriptDirectory:
             raise util.CommandError(
                 "No 'script_location' key " "found in configuration."
             )
-        truncate_slug_length = config.get_main_option("truncate_slug_length")
+        truncate_slug_length = cast(
+            Optional[int], config.get_main_option("truncate_slug_length")
+        )
         if truncate_slug_length is not None:
             truncate_slug_length = int(truncate_slug_length)
 
@@ -162,13 +183,17 @@ class ScriptDirectory:
             else:
                 if split_char is None:
                     # legacy behaviour for backwards compatibility
-                    version_locations = _split_on_space_comma.split(
-                        version_locations
+                    vl = _split_on_space_comma.split(
+                        cast(str, version_locations)
                     )
+                    version_locations: List[str] = vl  # type: ignore[no-redef]
                 else:
-                    version_locations = [
-                        x for x in version_locations.split(split_char) if x
+                    vl = [
+                        x
+                        for x in cast(str, version_locations).split(split_char)
+                        if x
                     ]
+                    version_locations: List[str] = vl  # type: ignore[no-redef]
 
         prepend_sys_path = config.get_main_option("prepend_sys_path")
         if prepend_sys_path:
@@ -184,7 +209,7 @@ class ScriptDirectory:
             truncate_slug_length=truncate_slug_length,
             sourceless=config.get_main_option("sourceless") == "true",
             output_encoding=config.get_main_option("output_encoding", "utf-8"),
-            version_locations=version_locations,
+            version_locations=cast("Optional[List[str]]", version_locations),
             timezone=config.get_main_option("timezone"),
             hook_config=config.get_section("post_write_hooks", {}),
         )
@@ -192,19 +217,19 @@ class ScriptDirectory:
     @contextmanager
     def _catch_revision_errors(
         self,
-        ancestor=None,
-        multiple_heads=None,
-        start=None,
-        end=None,
-        resolution=None,
-    ):
+        ancestor: Optional[str] = None,
+        multiple_heads: Optional[str] = None,
+        start: Optional[str] = None,
+        end: Optional[str] = None,
+        resolution: Optional[str] = None,
+    ) -> Iterator[None]:
         try:
             yield
         except revision.RangeNotAncestorError as rna:
             if start is None:
-                start = rna.lower
+                start = cast(Any, rna.lower)
             if end is None:
-                end = rna.upper
+                end = cast(Any, rna.upper)
             if not ancestor:
                 ancestor = (
                     "Requested range %(start)s:%(end)s does not refer to "
@@ -235,7 +260,9 @@ class ScriptDirectory:
         except revision.RevisionError as err:
             raise util.CommandError(err.args[0]) from err
 
-    def walk_revisions(self, base="base", head="heads"):
+    def walk_revisions(
+        self, base: str = "base", head: str = "heads"
+    ) -> Iterator["Script"]:
         """Iterate through all revisions.
 
         :param base: the base revision, or "base" to start from the
@@ -250,28 +277,36 @@ class ScriptDirectory:
             for rev in self.revision_map.iterate_revisions(
                 head, base, inclusive=True, assert_relative_length=False
             ):
-                yield rev
+                yield cast(Script, rev)
 
-    def get_revisions(self, id_):
+    def get_revisions(self, id_: _RevIdType) -> Tuple["Script", ...]:
         """Return the :class:`.Script` instance with the given rev identifier,
         symbolic name, or sequence of identifiers.
 
         """
         with self._catch_revision_errors():
-            return self.revision_map.get_revisions(id_)
+            return cast(
+                "Tuple[Script, ...]", self.revision_map.get_revisions(id_)
+            )
 
-    def get_all_current(self, id_):
+    def get_all_current(self, id_: Tuple[str, ...]) -> Set["Script"]:
         with self._catch_revision_errors():
-            top_revs = set(self.revision_map.get_revisions(id_))
+            top_revs = cast(
+                "Set[Script]",
+                set(self.revision_map.get_revisions(id_)),
+            )
             top_revs.update(
-                self.revision_map._get_ancestor_nodes(
-                    list(top_revs), include_dependencies=True
+                cast(
+                    "Iterator[Script]",
+                    self.revision_map._get_ancestor_nodes(
+                        list(top_revs), include_dependencies=True
+                    ),
                 )
             )
             top_revs = self.revision_map._filter_into_branch_heads(top_revs)
             return top_revs
 
-    def get_revision(self, id_):
+    def get_revision(self, id_: str) -> "Script":
         """Return the :class:`.Script` instance with the given rev id.
 
         .. seealso::
@@ -281,9 +316,11 @@ class ScriptDirectory:
         """
 
         with self._catch_revision_errors():
-            return self.revision_map.get_revision(id_)
+            return cast(Script, self.revision_map.get_revision(id_))
 
-    def as_revision_number(self, id_):
+    def as_revision_number(
+        self, id_: Optional[str]
+    ) -> Optional[Union[str, Tuple[str, ...]]]:
         """Convert a symbolic revision, i.e. 'head' or 'base', into
         an actual revision number."""
 
@@ -340,7 +377,7 @@ class ScriptDirectory:
         ):
             return self.revision_map.get_current_head()
 
-    def get_heads(self):
+    def get_heads(self) -> List[str]:
         """Return all "versioned head" revisions as strings.
 
         This is normally a list of length one,
@@ -353,7 +390,7 @@ class ScriptDirectory:
         """
         return list(self.revision_map.heads)
 
-    def get_base(self):
+    def get_base(self) -> Optional[str]:
         """Return the "base" revision as a string.
 
         This is the revision number of the script that
@@ -375,7 +412,7 @@ class ScriptDirectory:
         else:
             return None
 
-    def get_bases(self):
+    def get_bases(self) -> List[str]:
         """return all "base" revisions as strings.
 
         This is the revision number of all scripts that
@@ -384,7 +421,9 @@ class ScriptDirectory:
         """
         return list(self.revision_map.bases)
 
-    def _upgrade_revs(self, destination, current_rev):
+    def _upgrade_revs(
+        self, destination: str, current_rev: str
+    ) -> List["RevisionStep"]:
         with self._catch_revision_errors(
             ancestor="Destination %(end)s is not a valid upgrade "
             "target from current head(s)",
@@ -393,15 +432,16 @@ class ScriptDirectory:
             revs = self.revision_map.iterate_revisions(
                 destination, current_rev, implicit_base=True
             )
-            revs = list(revs)
             return [
                 migration.MigrationStep.upgrade_from_script(
-                    self.revision_map, script
+                    self.revision_map, cast(Script, script)
                 )
                 for script in reversed(list(revs))
             ]
 
-    def _downgrade_revs(self, destination, current_rev):
+    def _downgrade_revs(
+        self, destination: str, current_rev: Optional[str]
+    ) -> List["RevisionStep"]:
         with self._catch_revision_errors(
             ancestor="Destination %(end)s is not a valid downgrade "
             "target from current head(s)",
@@ -412,30 +452,32 @@ class ScriptDirectory:
             )
             return [
                 migration.MigrationStep.downgrade_from_script(
-                    self.revision_map, script
+                    self.revision_map, cast(Script, script)
                 )
                 for script in revs
             ]
 
-    def _stamp_revs(self, revision, heads):
+    def _stamp_revs(
+        self, revision: _RevIdType, heads: _RevIdType
+    ) -> List["StampStep"]:
         with self._catch_revision_errors(
             multiple_heads="Multiple heads are present; please specify a "
             "single target revision"
         ):
 
-            heads = self.get_revisions(heads)
+            heads_revs = self.get_revisions(heads)
 
             steps = []
 
             if not revision:
                 revision = "base"
 
-            filtered_heads = []
+            filtered_heads: List["Script"] = []
             for rev in util.to_tuple(revision):
                 if rev:
                     filtered_heads.extend(
                         self.revision_map.filter_for_lineage(
-                            heads, rev, include_dependencies=True
+                            heads_revs, rev, include_dependencies=True
                         )
                     )
             filtered_heads = util.unique_list(filtered_heads)
@@ -509,7 +551,7 @@ class ScriptDirectory:
 
             return steps
 
-    def run_env(self):
+    def run_env(self) -> None:
         """Run the script environment.
 
         This basically runs the ``env.py`` script present
@@ -524,7 +566,7 @@ class ScriptDirectory:
     def env_py_location(self):
         return os.path.abspath(os.path.join(self.dir, "env.py"))
 
-    def _generate_template(self, src, dest, **kw):
+    def _generate_template(self, src: str, dest: str, **kw: Any) -> None:
         util.status(
             "Generating %s" % os.path.abspath(dest),
             util.template_to_file,
@@ -534,17 +576,17 @@ class ScriptDirectory:
             **kw
         )
 
-    def _copy_file(self, src, dest):
+    def _copy_file(self, src: str, dest: str) -> None:
         util.status(
             "Generating %s" % os.path.abspath(dest), shutil.copy, src, dest
         )
 
-    def _ensure_directory(self, path):
+    def _ensure_directory(self, path: str) -> None:
         path = os.path.abspath(path)
         if not os.path.exists(path):
             util.status("Creating directory %s" % path, os.makedirs, path)
 
-    def _generate_create_date(self):
+    def _generate_create_date(self) -> "datetime.datetime":
         if self.timezone is not None:
             if tz is None:
                 raise util.CommandError(
@@ -571,16 +613,16 @@ class ScriptDirectory:
 
     def generate_revision(
         self,
-        revid,
-        message,
-        head=None,
-        refresh=False,
-        splice=False,
-        branch_labels=None,
-        version_path=None,
-        depends_on=None,
-        **kw
-    ):
+        revid: str,
+        message: Optional[str],
+        head: Optional[str] = None,
+        refresh: bool = False,
+        splice: Optional[bool] = False,
+        branch_labels: Optional[str] = None,
+        version_path: Optional[str] = None,
+        depends_on: Optional[_RevIdType] = None,
+        **kw: Any
+    ) -> Optional["Script"]:
         """Generate a new revision file.
 
         This runs the ``script.py.mako`` template, given
@@ -623,9 +665,10 @@ class ScriptDirectory:
 
         if version_path is None:
             if len(self._version_locations) > 1:
-                for head in heads:
-                    if head is not None:
-                        version_path = os.path.dirname(head.path)
+                for head_ in heads:
+                    if head_ is not None:
+                        assert isinstance(head_, Script)
+                        version_path = os.path.dirname(head_.path)
                         break
                 else:
                     raise util.CommandError(
@@ -651,12 +694,12 @@ class ScriptDirectory:
         path = self._rev_path(version_path, revid, message, create_date)
 
         if not splice:
-            for head in heads:
-                if head is not None and not head.is_head:
+            for head_ in heads:
+                if head_ is not None and not head_.is_head:
                     raise util.CommandError(
                         "Revision %s is not a head revision; please specify "
                         "--splice to create a new branch from this revision"
-                        % head.revision
+                        % head_.revision
                     )
 
         if depends_on:
@@ -679,7 +722,9 @@ class ScriptDirectory:
                 tuple(h.revision if h is not None else None for h in heads)
             ),
             branch_labels=util.to_tuple(branch_labels),
-            depends_on=revision.tuple_rev_as_scalar(depends_on),
+            depends_on=revision.tuple_rev_as_scalar(
+                cast("Optional[List[str]]", depends_on)
+            ),
             create_date=create_date,
             comma=util.format_as_comma,
             message=message if message is not None else ("empty message"),
@@ -694,6 +739,8 @@ class ScriptDirectory:
             script = Script._from_path(self, path)
         except revision.RevisionError as err:
             raise util.CommandError(err.args[0]) from err
+        if script is None:
+            return None
         if branch_labels and not script.branch_labels:
             raise util.CommandError(
                 "Version %s specified branch_labels %s, however the "
@@ -702,11 +749,16 @@ class ScriptDirectory:
                 "'branch_labels' section?"
                 % (script.revision, branch_labels, script.path)
             )
-
         self.revision_map.add_revision(script)
         return script
 
-    def _rev_path(self, path, rev_id, message, create_date):
+    def _rev_path(
+        self,
+        path: str,
+        rev_id: str,
+        message: Optional[str],
+        create_date: "datetime.datetime",
+    ) -> str:
         slug = "_".join(_slug_re.findall(message or "")).lower()
         if len(slug) > self.truncate_slug_length:
             slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_"
@@ -735,12 +787,12 @@ class Script(revision.Revision):
 
     """
 
-    def __init__(self, module, rev_id, path):
+    def __init__(self, module: ModuleType, rev_id: str, path: str):
         self.module = module
         self.path = path
         super(Script, self).__init__(
             rev_id,
-            module.down_revision,
+            module.down_revision,  # type: ignore[attr-defined]
             branch_labels=util.to_tuple(
                 getattr(module, "branch_labels", None), default=()
             ),
@@ -749,10 +801,10 @@ class Script(revision.Revision):
             ),
         )
 
-    module = None
+    module: ModuleType = None  # type: ignore[assignment]
     """The Python module representing the actual script itself."""
 
-    path = None
+    path: str = None  # type: ignore[assignment]
     """Filesystem path of the script."""
 
     _db_current_indicator = None
@@ -760,25 +812,27 @@ class Script(revision.Revision):
     this is a "current" version in some database"""
 
     @property
-    def doc(self):
+    def doc(self) -> str:
         """Return the docstring given in the script."""
 
         return re.split("\n\n", self.longdoc)[0]
 
     @property
-    def longdoc(self):
+    def longdoc(self) -> str:
         """Return the docstring given in the script."""
 
         doc = self.module.__doc__
         if doc:
             if hasattr(self.module, "_alembic_source_encoding"):
-                doc = doc.decode(self.module._alembic_source_encoding)
-            return doc.strip()
+                doc = doc.decode(  # type: ignore[attr-defined]
+                    self.module._alembic_source_encoding  # type: ignore[attr-defined] # noqa
+                )
+            return doc.strip()  # type: ignore[union-attr]
         else:
             return ""
 
     @property
-    def log_entry(self):
+    def log_entry(self) -> str:
         entry = "Rev: %s%s%s%s%s\n" % (
             self.revision,
             " (head)" if self.is_head else "",
@@ -825,12 +879,12 @@ class Script(revision.Revision):
 
     def _head_only(
         self,
-        include_branches=False,
-        include_doc=False,
-        include_parents=False,
-        tree_indicators=True,
-        head_indicators=True,
-    ):
+        include_branches: bool = False,
+        include_doc: bool = False,
+        include_parents: bool = False,
+        tree_indicators: bool = True,
+        head_indicators: bool = True,
+    ) -> str:
         text = self.revision
         if include_parents:
             if self.dependencies:
@@ -841,6 +895,7 @@ class Script(revision.Revision):
                 )
             else:
                 text = "%s -> %s" % (self._format_down_revision(), text)
+        assert text is not None
         if include_branches and self.branch_labels:
             text += " (%s)" % util.format_as_comma(self.branch_labels)
         if head_indicators or tree_indicators:
@@ -862,12 +917,12 @@ class Script(revision.Revision):
 
     def cmd_format(
         self,
-        verbose,
-        include_branches=False,
-        include_doc=False,
-        include_parents=False,
-        tree_indicators=True,
-    ):
+        verbose: bool,
+        include_branches: bool = False,
+        include_doc: bool = False,
+        include_parents: bool = False,
+        tree_indicators: bool = True,
+    ) -> str:
         if verbose:
             return self.log_entry
         else:
@@ -875,19 +930,21 @@ class Script(revision.Revision):
                 include_branches, include_doc, include_parents, tree_indicators
             )
 
-    def _format_down_revision(self):
+    def _format_down_revision(self) -> str:
         if not self.down_revision:
             return "<base>"
         else:
             return util.format_as_comma(self._versioned_down_revisions)
 
     @classmethod
-    def _from_path(cls, scriptdir, path):
+    def _from_path(
+        cls, scriptdir: ScriptDirectory, path: str
+    ) -> Optional["Script"]:
         dir_, filename = os.path.split(path)
         return cls._from_filename(scriptdir, dir_, filename)
 
     @classmethod
-    def _list_py_dir(cls, scriptdir, path):
+    def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]:
         if scriptdir.sourceless:
             # read files in version path, e.g. pyc or pyo files
             # in the immediate path
@@ -910,7 +967,9 @@ class Script(revision.Revision):
             return os.listdir(path)
 
     @classmethod
-    def _from_filename(cls, scriptdir, dir_, filename):
+    def _from_filename(
+        cls, scriptdir: ScriptDirectory, dir_: str, filename: str
+    ) -> Optional["Script"]:
         if scriptdir.sourceless:
             py_match = _sourceless_rev_file.match(filename)
         else:
index bdae805db03e72d03e0864d6ba509ddba67d24b8..eccb98ec8a8b1cce8c36922b6e7d48e67f0622e9 100644 (file)
@@ -1,11 +1,40 @@
 import collections
 import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Deque
+from typing import Dict
+from typing import FrozenSet
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
 
 from sqlalchemy import util as sqlautil
 
 from .. import util
 from ..util import compat
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from .base import Script
+
+_RevIdType = Union[str, Sequence[str]]
+_RevisionIdentifierType = Union[str, Tuple[str, ...], None]
+_RevisionOrStr = Union["Revision", str]
+_RevisionOrBase = Union["Revision", "Literal['base']"]
+_InterimRevisionMapType = Dict[str, "Revision"]
+_RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]]
+_T = TypeVar("_T", bound=Union[str, "Revision"])
+
 _relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
 _revision_illegal_chars = ["@", "-", "+"]
 
@@ -15,7 +44,9 @@ class RevisionError(Exception):
 
 
 class RangeNotAncestorError(RevisionError):
-    def __init__(self, lower, upper):
+    def __init__(
+        self, lower: _RevisionIdentifierType, upper: _RevisionIdentifierType
+    ) -> None:
         self.lower = lower
         self.upper = upper
         super(RangeNotAncestorError, self).__init__(
@@ -25,7 +56,7 @@ class RangeNotAncestorError(RevisionError):
 
 
 class MultipleHeads(RevisionError):
-    def __init__(self, heads, argument):
+    def __init__(self, heads: Sequence[str], argument: Optional[str]) -> None:
         self.heads = heads
         self.argument = argument
         super(MultipleHeads, self).__init__(
@@ -35,7 +66,7 @@ class MultipleHeads(RevisionError):
 
 
 class ResolutionError(RevisionError):
-    def __init__(self, message, argument):
+    def __init__(self, message: str, argument: str) -> None:
         super(ResolutionError, self).__init__(message)
         self.argument = argument
 
@@ -43,7 +74,7 @@ class ResolutionError(RevisionError):
 class CycleDetected(RevisionError):
     kind = "Cycle"
 
-    def __init__(self, revisions):
+    def __init__(self, revisions: Sequence[str]) -> None:
         self.revisions = revisions
         super(CycleDetected, self).__init__(
             "%s is detected in revisions (%s)"
@@ -54,21 +85,21 @@ class CycleDetected(RevisionError):
 class DependencyCycleDetected(CycleDetected):
     kind = "Dependency cycle"
 
-    def __init__(self, revisions):
+    def __init__(self, revisions: Sequence[str]) -> None:
         super(DependencyCycleDetected, self).__init__(revisions)
 
 
 class LoopDetected(CycleDetected):
     kind = "Self-loop"
 
-    def __init__(self, revision):
+    def __init__(self, revision: str) -> None:
         super(LoopDetected, self).__init__([revision])
 
 
 class DependencyLoopDetected(DependencyCycleDetected, LoopDetected):
     kind = "Dependency self-loop"
 
-    def __init__(self, revision):
+    def __init__(self, revision: Sequence[str]) -> None:
         super(DependencyLoopDetected, self).__init__(revision)
 
 
@@ -81,7 +112,7 @@ class RevisionMap:
 
     """
 
-    def __init__(self, generator):
+    def __init__(self, generator: Callable[[], Iterator["Revision"]]) -> None:
         """Construct a new :class:`.RevisionMap`.
 
         :param generator: a zero-arg callable that will generate an iterable
@@ -92,7 +123,7 @@ class RevisionMap:
         self._generator = generator
 
     @util.memoized_property
-    def heads(self):
+    def heads(self) -> Tuple[str, ...]:
         """All "head" revisions as strings.
 
         This is normally a tuple of length one,
@@ -105,7 +136,7 @@ class RevisionMap:
         return self.heads
 
     @util.memoized_property
-    def bases(self):
+    def bases(self) -> Tuple[str, ...]:
         """All "base" revisions as strings.
 
         These are revisions that have a ``down_revision`` of None,
@@ -118,7 +149,7 @@ class RevisionMap:
         return self.bases
 
     @util.memoized_property
-    def _real_heads(self):
+    def _real_heads(self) -> Tuple[str, ...]:
         """All "real" head revisions as strings.
 
         :return: a tuple of string revision numbers.
@@ -128,7 +159,7 @@ class RevisionMap:
         return self._real_heads
 
     @util.memoized_property
-    def _real_bases(self):
+    def _real_bases(self) -> Tuple[str, ...]:
         """All "real" base revisions as strings.
 
         :return: a tuple of string revision numbers.
@@ -138,19 +169,19 @@ class RevisionMap:
         return self._real_bases
 
     @util.memoized_property
-    def _revision_map(self):
+    def _revision_map(self) -> _RevisionMapType:
         """memoized attribute, initializes the revision map from the
         initial collection.
 
         """
         # Ordering required for some tests to pass (but not required in
         # general)
-        map_ = sqlautil.OrderedDict()
+        map_: _InterimRevisionMapType = sqlautil.OrderedDict()
 
-        heads = sqlautil.OrderedSet()
-        _real_heads = sqlautil.OrderedSet()
-        bases = ()
-        _real_bases = ()
+        heads: Set["Revision"] = sqlautil.OrderedSet()
+        _real_heads: Set["Revision"] = sqlautil.OrderedSet()
+        bases: Tuple["Revision", ...] = ()
+        _real_bases: Tuple["Revision", ...] = ()
 
         has_branch_labels = set()
         all_revisions = set()
@@ -176,11 +207,13 @@ class RevisionMap:
         # add the branch_labels to the map_.  We'll need these
         # to resolve the dependencies.
         rev_map = map_.copy()
-        self._map_branch_labels(has_branch_labels, map_)
+        self._map_branch_labels(
+            has_branch_labels, cast(_RevisionMapType, map_)
+        )
 
         # resolve dependency names from branch labels and symbolic
         # names
-        self._add_depends_on(all_revisions, map_)
+        self._add_depends_on(all_revisions, cast(_RevisionMapType, map_))
 
         for rev in map_.values():
             for downrev in rev._all_down_revisions:
@@ -198,32 +231,44 @@ class RevisionMap:
         # once the map has downrevisions populated, the dependencies
         # can be further refined to include only those which are not
         # already ancestors
-        self._normalize_depends_on(all_revisions, map_)
+        self._normalize_depends_on(all_revisions, cast(_RevisionMapType, map_))
         self._detect_cycles(rev_map, heads, bases, _real_heads, _real_bases)
 
-        map_[None] = map_[()] = None
+        revision_map: _RevisionMapType = dict(map_.items())
+        revision_map[None] = revision_map[()] = None
         self.heads = tuple(rev.revision for rev in heads)
         self._real_heads = tuple(rev.revision for rev in _real_heads)
         self.bases = tuple(rev.revision for rev in bases)
         self._real_bases = tuple(rev.revision for rev in _real_bases)
 
-        self._add_branches(has_branch_labels, map_)
-        return map_
+        self._add_branches(has_branch_labels, revision_map)
+        return revision_map
 
-    def _detect_cycles(self, rev_map, heads, bases, _real_heads, _real_bases):
+    def _detect_cycles(
+        self,
+        rev_map: _InterimRevisionMapType,
+        heads: Set["Revision"],
+        bases: Tuple["Revision", ...],
+        _real_heads: Set["Revision"],
+        _real_bases: Tuple["Revision", ...],
+    ) -> None:
         if not rev_map:
             return
         if not heads or not bases:
-            raise CycleDetected(rev_map.keys())
+            raise CycleDetected(list(rev_map))
         total_space = {
             rev.revision
             for rev in self._iterate_related_revisions(
-                lambda r: r._versioned_down_revisions, heads, map_=rev_map
+                lambda r: r._versioned_down_revisions,
+                heads,
+                map_=cast(_RevisionMapType, rev_map),
             )
         }.intersection(
             rev.revision
             for rev in self._iterate_related_revisions(
-                lambda r: r.nextrev, bases, map_=rev_map
+                lambda r: r.nextrev,
+                bases,
+                map_=cast(_RevisionMapType, rev_map),
             )
         )
         deleted_revs = set(rev_map.keys()) - total_space
@@ -231,39 +276,50 @@ class RevisionMap:
             raise CycleDetected(sorted(deleted_revs))
 
         if not _real_heads or not _real_bases:
-            raise DependencyCycleDetected(rev_map.keys())
+            raise DependencyCycleDetected(list(rev_map))
         total_space = {
             rev.revision
             for rev in self._iterate_related_revisions(
-                lambda r: r._all_down_revisions, _real_heads, map_=rev_map
+                lambda r: r._all_down_revisions,
+                _real_heads,
+                map_=cast(_RevisionMapType, rev_map),
             )
         }.intersection(
             rev.revision
             for rev in self._iterate_related_revisions(
-                lambda r: r._all_nextrev, _real_bases, map_=rev_map
+                lambda r: r._all_nextrev,
+                _real_bases,
+                map_=cast(_RevisionMapType, rev_map),
             )
         )
         deleted_revs = set(rev_map.keys()) - total_space
         if deleted_revs:
             raise DependencyCycleDetected(sorted(deleted_revs))
 
-    def _map_branch_labels(self, revisions, map_):
+    def _map_branch_labels(
+        self, revisions: Collection["Revision"], map_: _RevisionMapType
+    ) -> None:
         for revision in revisions:
             if revision.branch_labels:
+                assert revision._orig_branch_labels is not None
                 for branch_label in revision._orig_branch_labels:
                     if branch_label in map_:
+                        map_rev = map_[branch_label]
+                        assert map_rev is not None
                         raise RevisionError(
                             "Branch name '%s' in revision %s already "
                             "used by revision %s"
                             % (
                                 branch_label,
                                 revision.revision,
-                                map_[branch_label].revision,
+                                map_rev.revision,
                             )
                         )
                     map_[branch_label] = revision
 
-    def _add_branches(self, revisions, map_):
+    def _add_branches(
+        self, revisions: Collection["Revision"], map_: _RevisionMapType
+    ) -> None:
         for revision in revisions:
             if revision.branch_labels:
                 revision.branch_labels.update(revision.branch_labels)
@@ -285,7 +341,9 @@ class RevisionMap:
                     else:
                         break
 
-    def _add_depends_on(self, revisions, map_):
+    def _add_depends_on(
+        self, revisions: Collection["Revision"], map_: _RevisionMapType
+    ) -> None:
         """Resolve the 'dependencies' for each revision in a collection
         in terms of actual revision ids, as opposed to branch labels or other
         symbolic names.
@@ -301,12 +359,14 @@ class RevisionMap:
                     map_[dep] for dep in util.to_tuple(revision.dependencies)
                 ]
                 revision._resolved_dependencies = tuple(
-                    [d.revision for d in deps]
+                    [d.revision for d in deps if d is not None]
                 )
             else:
                 revision._resolved_dependencies = ()
 
-    def _normalize_depends_on(self, revisions, map_):
+    def _normalize_depends_on(
+        self, revisions: Collection["Revision"], map_: _RevisionMapType
+    ) -> None:
         """Create a collection of "dependencies" that omits dependencies
         that are already ancestor nodes for each revision in a given
         collection.
@@ -327,7 +387,9 @@ class RevisionMap:
             if revision._resolved_dependencies:
                 normalized_resolved = set(revision._resolved_dependencies)
                 for rev in self._get_ancestor_nodes(
-                    [revision], include_dependencies=False, map_=map_
+                    [revision],
+                    include_dependencies=False,
+                    map_=cast(_RevisionMapType, map_),
                 ):
                     if rev is revision:
                         continue
@@ -342,7 +404,9 @@ class RevisionMap:
             else:
                 revision._normalized_resolved_dependencies = ()
 
-    def add_revision(self, revision, _replace=False):
+    def add_revision(
+        self, revision: "Revision", _replace: bool = False
+    ) -> None:
         """add a single revision to an existing map.
 
         This method is for single-revision use cases, it's not
@@ -375,7 +439,7 @@ class RevisionMap:
                     "Revision %s referenced from %s is not present"
                     % (downrev, revision)
                 )
-            map_[downrev].add_nextrev(revision)
+            cast("Revision", map_[downrev]).add_nextrev(revision)
 
         self._normalize_depends_on(revisions, map_)
 
@@ -398,7 +462,9 @@ class RevisionMap:
                 )
             ) + (revision.revision,)
 
-    def get_current_head(self, branch_label=None):
+    def get_current_head(
+        self, branch_label: Optional[str] = None
+    ) -> Optional[str]:
         """Return the current head revision.
 
         If the script directory has multiple heads
@@ -416,7 +482,7 @@ class RevisionMap:
             :meth:`.ScriptDirectory.get_heads`
 
         """
-        current_heads = self.heads
+        current_heads: Sequence[str] = self.heads
         if branch_label:
             current_heads = self.filter_for_lineage(
                 current_heads, branch_label
@@ -432,10 +498,12 @@ class RevisionMap:
         else:
             return None
 
-    def _get_base_revisions(self, identifier):
+    def _get_base_revisions(self, identifier: str) -> Tuple[str, ...]:
         return self.filter_for_lineage(self.bases, identifier)
 
-    def get_revisions(self, id_):
+    def get_revisions(
+        self, id_: Union[str, Collection[str], None]
+    ) -> Tuple["Revision", ...]:
         """Return the :class:`.Revision` instances with the given rev id
         or identifiers.
 
@@ -456,7 +524,9 @@ class RevisionMap:
         if isinstance(id_, (list, tuple, set, frozenset)):
             return sum([self.get_revisions(id_elem) for id_elem in id_], ())
         else:
-            resolved_id, branch_label = self._resolve_revision_number(id_)
+            resolved_id, branch_label = self._resolve_revision_number(
+                id_  # type:ignore [arg-type]
+            )
             if len(resolved_id) == 1:
                 try:
                     rint = int(resolved_id[0])
@@ -464,11 +534,11 @@ class RevisionMap:
                         # branch@-n -> walk down from heads
                         select_heads = self.get_revisions("heads")
                         if branch_label is not None:
-                            select_heads = [
+                            select_heads = tuple(
                                 head
                                 for head in select_heads
                                 if branch_label in head.branch_labels
-                            ]
+                            )
                         return tuple(
                             self._walk(head, steps=rint)
                             for head in select_heads
@@ -481,7 +551,7 @@ class RevisionMap:
                 for rev_id in resolved_id
             )
 
-    def get_revision(self, id_):
+    def get_revision(self, id_: Optional[str]) -> "Revision":
         """Return the :class:`.Revision` instance with the given rev id.
 
         If a symbolic name such as "head" or "base" is given, resolves
@@ -499,11 +569,11 @@ class RevisionMap:
         if len(resolved_id) > 1:
             raise MultipleHeads(resolved_id, id_)
         elif resolved_id:
-            resolved_id = resolved_id[0]
+            resolved_id = resolved_id[0]  # type:ignore[assignment]
 
-        return self._revision_for_ident(resolved_id, branch_label)
+        return self._revision_for_ident(cast(str, resolved_id), branch_label)
 
-    def _resolve_branch(self, branch_label):
+    def _resolve_branch(self, branch_label: str) -> "Revision":
         try:
             branch_rev = self._revision_map[branch_label]
         except KeyError:
@@ -517,19 +587,24 @@ class RevisionMap:
             else:
                 return nonbranch_rev
         else:
-            return branch_rev
+            return cast("Revision", branch_rev)
 
-    def _revision_for_ident(self, resolved_id, check_branch=None):
+    def _revision_for_ident(
+        self, resolved_id: str, check_branch: Optional[str] = None
+    ) -> "Revision":
+        branch_rev: Optional["Revision"]
         if check_branch:
             branch_rev = self._resolve_branch(check_branch)
         else:
             branch_rev = None
 
+        revision: Union["Revision", "Literal[False]"]
         try:
-            revision = self._revision_map[resolved_id]
+            revision = cast("Revision", self._revision_map[resolved_id])
         except KeyError:
             # break out to avoid misleading py3k stack traces
             revision = False
+        revs: Sequence[str]
         if revision is False:
             # do a partial lookup
             revs = [
@@ -562,9 +637,11 @@ class RevisionMap:
                     resolved_id,
                 )
             else:
-                revision = self._revision_map[revs[0]]
+                revision = cast("Revision", self._revision_map[revs[0]])
 
+        revision = cast("Revision", revision)
         if check_branch and revision is not None:
+            assert branch_rev is not None
             if not self._shares_lineage(
                 revision.revision, branch_rev.revision
             ):
@@ -575,7 +652,9 @@ class RevisionMap:
                 )
         return revision
 
-    def _filter_into_branch_heads(self, targets):
+    def _filter_into_branch_heads(
+        self, targets: Set["Script"]
+    ) -> Set["Script"]:
         targets = set(targets)
 
         for rev in list(targets):
@@ -586,8 +665,11 @@ class RevisionMap:
         return targets
 
     def filter_for_lineage(
-        self, targets, check_against, include_dependencies=False
-    ):
+        self,
+        targets: Sequence[_T],
+        check_against: Optional[str],
+        include_dependencies: bool = False,
+    ) -> Tuple[_T, ...]:
         id_, branch_label = self._resolve_revision_number(check_against)
 
         shares = []
@@ -596,17 +678,20 @@ class RevisionMap:
         if id_:
             shares.extend(id_)
 
-        return [
+        return tuple(
             tg
             for tg in targets
             if self._shares_lineage(
                 tg, shares, include_dependencies=include_dependencies
             )
-        ]
+        )
 
     def _shares_lineage(
-        self, target, test_against_revs, include_dependencies=False
-    ):
+        self,
+        target: _RevisionOrStr,
+        test_against_revs: Sequence[_RevisionOrStr],
+        include_dependencies: bool = False,
+    ) -> bool:
         if not test_against_revs:
             return True
         if not isinstance(target, Revision):
@@ -635,7 +720,10 @@ class RevisionMap:
             .intersection(test_against_revs)
         )
 
-    def _resolve_revision_number(self, id_):
+    def _resolve_revision_number(
+        self, id_: Optional[str]
+    ) -> Tuple[Tuple[str, ...], Optional[str]]:
+        branch_label: Optional[str]
         if isinstance(id_, compat.string_types) and "@" in id_:
             branch_label, id_ = id_.split("@", 1)
 
@@ -678,13 +766,13 @@ class RevisionMap:
 
     def iterate_revisions(
         self,
-        upper,
-        lower,
-        implicit_base=False,
-        inclusive=False,
-        assert_relative_length=True,
-        select_for_downgrade=False,
-    ):
+        upper: _RevisionIdentifierType,
+        lower: _RevisionIdentifierType,
+        implicit_base: bool = False,
+        inclusive: bool = False,
+        assert_relative_length: bool = True,
+        select_for_downgrade: bool = False,
+    ) -> Iterator["Revision"]:
         """Iterate through script revisions, starting at the given
         upper revision identifier and ending at the lower.
 
@@ -696,6 +784,7 @@ class RevisionMap:
         The iterator yields :class:`.Revision` objects.
 
         """
+        fn: Callable
         if select_for_downgrade:
             fn = self._collect_downgrade_revisions
         else:
@@ -714,12 +803,12 @@ class RevisionMap:
 
     def _get_descendant_nodes(
         self,
-        targets,
-        map_=None,
-        check=False,
-        omit_immediate_dependencies=False,
-        include_dependencies=True,
-    ):
+        targets: Collection["Revision"],
+        map_: Optional[_RevisionMapType] = None,
+        check: bool = False,
+        omit_immediate_dependencies: bool = False,
+        include_dependencies: bool = True,
+    ) -> Iterator[Any]:
 
         if omit_immediate_dependencies:
 
@@ -744,8 +833,12 @@ class RevisionMap:
         )
 
     def _get_ancestor_nodes(
-        self, targets, map_=None, check=False, include_dependencies=True
-    ):
+        self,
+        targets: Collection["Revision"],
+        map_: Optional[_RevisionMapType] = None,
+        check: bool = False,
+        include_dependencies: bool = True,
+    ) -> Iterator["Revision"]:
 
         if include_dependencies:
 
@@ -761,12 +854,18 @@ class RevisionMap:
             fn, targets, map_=map_, check=check
         )
 
-    def _iterate_related_revisions(self, fn, targets, map_, check=False):
+    def _iterate_related_revisions(
+        self,
+        fn: Callable,
+        targets: Collection["Revision"],
+        map_: Optional[_RevisionMapType],
+        check: bool = False,
+    ) -> Iterator["Revision"]:
         if map_ is None:
             map_ = self._revision_map
 
         seen = set()
-        todo = collections.deque()
+        todo: Deque["Revision"] = collections.deque()
         for target in targets:
 
             todo.append(target)
@@ -784,6 +883,7 @@ class RevisionMap:
                 # Check for map errors before collecting.
                 for rev_id in fn(rev):
                     next_rev = map_[rev_id]
+                    assert next_rev is not None
                     if next_rev.revision != rev_id:
                         raise RevisionError(
                             "Dependency resolution failed; broken map"
@@ -804,7 +904,11 @@ class RevisionMap:
                         )
                     )
 
-    def _topological_sort(self, revisions, heads):
+    def _topological_sort(
+        self,
+        revisions: Collection["Revision"],
+        heads: Any,
+    ) -> List[str]:
         """Yield revision ids of a collection of Revision objects in
         topological sorted order (i.e. revisions always come after their
         down_revisions and dependencies). Uses the order of keys in
@@ -860,6 +964,7 @@ class RevisionMap:
                 # now update the heads with our ancestors.
 
                 candidate_rev = id_to_rev[candidate]
+                assert candidate_rev is not None
 
                 heads_to_add = [
                     r
@@ -873,7 +978,6 @@ class RevisionMap:
                     del ancestors_by_idx[current_candidate_idx]
                     current_candidate_idx = max(current_candidate_idx - 1, 0)
                 else:
-
                     if (
                         not candidate_rev._normalized_resolved_dependencies
                         and len(candidate_rev._versioned_down_revisions) == 1
@@ -905,7 +1009,13 @@ class RevisionMap:
         assert not todo
         return output
 
-    def _walk(self, start, steps, branch_label=None, no_overwalk=True):
+    def _walk(
+        self,
+        start: Optional[Union[str, "Revision"]],
+        steps: int,
+        branch_label: Optional[str] = None,
+        no_overwalk: bool = True,
+    ) -> "Revision":
         """
         Walk the requested number of :steps up (steps > 0) or down (steps < 0)
         the revision tree.
@@ -918,44 +1028,55 @@ class RevisionMap:
         A RevisionError is raised if there is no unambiguous revision to
         walk to.
         """
-
+        initial: Optional[_RevisionOrBase]
         if isinstance(start, compat.string_types):
-            start = self.get_revision(start)
+            initial = self.get_revision(start)
+        else:
+            initial = start
 
+        children: Sequence[_RevisionOrBase]
         for _ in range(abs(steps)):
             if steps > 0:
                 # Walk up
                 children = [
                     rev
                     for rev in self.get_revisions(
-                        self.bases if start is None else start.nextrev
+                        self.bases
+                        if initial is None
+                        else cast("Revision", initial).nextrev
                     )
                 ]
                 if branch_label:
                     children = self.filter_for_lineage(children, branch_label)
             else:
                 # Walk down
-                if start == "base":
-                    children = tuple()
+                if initial == "base":
+                    children = ()
                 else:
                     children = self.get_revisions(
-                        self.heads if start is None else start.down_revision
+                        self.heads
+                        if initial is None
+                        else initial.down_revision
                     )
                     if not children:
-                        children = ("base",)
+                        children = cast("Tuple[Literal['base']]", ("base",))
             if not children:
                 # This will return an invalid result if no_overwalk, otherwise
                 # further steps will stay where we are.
-                return None if no_overwalk else start
+                ret = None if no_overwalk else initial
+                return ret  # type:ignore[return-value]
             elif len(children) > 1:
                 raise RevisionError("Ambiguous walk")
-            start = children[0]
+            initial = children[0]
 
-        return start
+        return cast("Revision", initial)
 
     def _parse_downgrade_target(
-        self, current_revisions, target, assert_relative_length
-    ):
+        self,
+        current_revisions: _RevisionIdentifierType,
+        target: _RevisionIdentifierType,
+        assert_relative_length: bool,
+    ) -> Tuple[Optional[str], Optional[_RevisionOrBase]]:
         """
         Parse downgrade command syntax :target to retrieve the target revision
         and branch label (if any) given the :current_revisons stamp of the
@@ -999,11 +1120,11 @@ class RevisionMap:
                 if relative_revision:
                     # Find target revision relative to current state.
                     if branch_label:
-                        symbol = self.filter_for_lineage(
+                        symbol_list = self.filter_for_lineage(
                             util.to_tuple(current_revisions), branch_label
                         )
-                        assert len(symbol) == 1
-                        symbol = symbol[0]
+                        assert len(symbol_list) == 1
+                        symbol = symbol_list[0]
                     else:
                         current_revisions = util.to_tuple(current_revisions)
                         if not current_revisions:
@@ -1045,12 +1166,15 @@ class RevisionMap:
         # No relative destination given, revision specified is absolute.
         branch_label, _, symbol = target.rpartition("@")
         if not branch_label:
-            branch_label = None
+            branch_label = None  # type:ignore[assignment]
         return branch_label, self.get_revision(symbol)
 
     def _parse_upgrade_target(
-        self, current_revisions, target, assert_relative_length
-    ):
+        self,
+        current_revisions: _RevisionIdentifierType,
+        target: _RevisionIdentifierType,
+        assert_relative_length: bool,
+    ) -> Tuple["Revision", ...]:
         """
         Parse upgrade command syntax :target to retrieve the target revision
         and given the :current_revisons stamp of the database.
@@ -1070,9 +1194,8 @@ class RevisionMap:
 
         current_revisions = util.to_tuple(current_revisions)
 
-        branch_label, symbol, relative = match.groups()
-        relative_str = relative
-        relative = int(relative)
+        branch_label, symbol, relative_str = match.groups()
+        relative = int(relative_str)
         if relative > 0:
             if symbol is None:
                 if not current_revisions:
@@ -1151,8 +1274,13 @@ class RevisionMap:
             )
 
     def _collect_downgrade_revisions(
-        self, upper, target, inclusive, implicit_base, assert_relative_length
-    ):
+        self,
+        upper: _RevisionIdentifierType,
+        target: _RevisionIdentifierType,
+        inclusive: bool,
+        implicit_base: bool,
+        assert_relative_length: bool,
+    ) -> Any:
         """
         Compute the set of current revisions specified by :upper, and the
         downgrade target specified by :target. Return all dependents of target
@@ -1244,8 +1372,13 @@ class RevisionMap:
         return downgrade_revisions, heads
 
     def _collect_upgrade_revisions(
-        self, upper, lower, inclusive, implicit_base, assert_relative_length
-    ):
+        self,
+        upper: _RevisionIdentifierType,
+        lower: _RevisionIdentifierType,
+        inclusive: bool,
+        implicit_base: bool,
+        assert_relative_length: bool,
+    ) -> Tuple[Set["Revision"], Tuple[Optional[_RevisionOrBase]]]:
         """
         Compute the set of required revisions specified by :upper, and the
         current set of active revisions specified by :lower. Find the
@@ -1257,14 +1390,13 @@ class RevisionMap:
         of the current/lower revisions. Dependencies from branches with
         different bases will not be included.
         """
-        targets = self._parse_upgrade_target(
+        targets: Collection["Revision"] = self._parse_upgrade_target(
             current_revisions=lower,
             target=upper,
             assert_relative_length=assert_relative_length,
         )
 
-        assert targets is not None
-        assert type(targets) is tuple, "targets should be a tuple"
+        # assert type(targets) is tuple, "targets should be a tuple"
 
         # Handled named bases (e.g. branch@... -> heads should only produce
         # targets on the given branch)
@@ -1332,7 +1464,7 @@ class RevisionMap:
             )
             needs.intersection_update(lower_descendents)
 
-        return needs, targets
+        return needs, tuple(targets)  # type:ignore[return-value]
 
 
 class Revision:
@@ -1346,15 +1478,15 @@ class Revision:
 
     """
 
-    nextrev = frozenset()
+    nextrev: FrozenSet[str] = frozenset()
     """following revisions, based on down_revision only."""
 
-    _all_nextrev = frozenset()
+    _all_nextrev: FrozenSet[str] = frozenset()
 
-    revision = None
+    revision: str = None  # type: ignore[assignment]
     """The string revision number."""
 
-    down_revision = None
+    down_revision: Optional[_RevIdType] = None
     """The ``down_revision`` identifier(s) within the migration script.
 
     Note that the total set of "down" revisions is
@@ -1362,7 +1494,7 @@ class Revision:
 
     """
 
-    dependencies = None
+    dependencies: Optional[_RevIdType] = None
     """Additional revisions which this revision is dependent on.
 
     From a migration standpoint, these dependencies are added to the
@@ -1372,12 +1504,15 @@ class Revision:
 
     """
 
-    branch_labels = None
+    branch_labels: Set[str] = None  # type: ignore[assignment]
     """Optional string/tuple of symbolic names to apply to this
     revision's branch"""
 
+    _resolved_dependencies: Tuple[str, ...]
+    _normalized_resolved_dependencies: Tuple[str, ...]
+
     @classmethod
-    def verify_rev_id(cls, revision):
+    def verify_rev_id(cls, revision: str) -> None:
         illegal_chars = set(revision).intersection(_revision_illegal_chars)
         if illegal_chars:
             raise RevisionError(
@@ -1386,8 +1521,12 @@ class Revision:
             )
 
     def __init__(
-        self, revision, down_revision, dependencies=None, branch_labels=None
-    ):
+        self,
+        revision: str,
+        down_revision: Optional[Union[str, Tuple[str, ...]]],
+        dependencies: Optional[Tuple[str, ...]] = None,
+        branch_labels: Optional[Tuple[str, ...]] = None,
+    ) -> None:
         if down_revision and revision in util.to_tuple(down_revision):
             raise LoopDetected(revision)
         elif dependencies is not None and revision in util.to_tuple(
@@ -1402,7 +1541,7 @@ class Revision:
         self._orig_branch_labels = util.to_tuple(branch_labels, default=())
         self.branch_labels = set(self._orig_branch_labels)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         args = [repr(self.revision), repr(self.down_revision)]
         if self.dependencies:
             args.append("dependencies=%r" % (self.dependencies,))
@@ -1410,20 +1549,20 @@ class Revision:
             args.append("branch_labels=%r" % (self.branch_labels,))
         return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
 
-    def add_nextrev(self, revision):
+    def add_nextrev(self, revision: "Revision") -> None:
         self._all_nextrev = self._all_nextrev.union([revision.revision])
         if self.revision in revision._versioned_down_revisions:
             self.nextrev = self.nextrev.union([revision.revision])
 
     @property
-    def _all_down_revisions(self):
+    def _all_down_revisions(self) -> Tuple[str, ...]:
         return util.dedupe_tuple(
             util.to_tuple(self.down_revision, default=())
             + self._resolved_dependencies
         )
 
     @property
-    def _normalized_down_revisions(self):
+    def _normalized_down_revisions(self) -> Tuple[str, ...]:
         """return immediate down revisions for a rev, omitting dependencies
         that are still dependencies of ancestors.
 
@@ -1434,11 +1573,11 @@ class Revision:
         )
 
     @property
-    def _versioned_down_revisions(self):
+    def _versioned_down_revisions(self) -> Tuple[str, ...]:
         return util.to_tuple(self.down_revision, default=())
 
     @property
-    def is_head(self):
+    def is_head(self) -> bool:
         """Return True if this :class:`.Revision` is a 'head' revision.
 
         This is determined based on whether any other :class:`.Script`
@@ -1449,17 +1588,17 @@ class Revision:
         return not bool(self.nextrev)
 
     @property
-    def _is_real_head(self):
+    def _is_real_head(self) -> bool:
         return not bool(self._all_nextrev)
 
     @property
-    def is_base(self):
+    def is_base(self) -> bool:
         """Return True if this :class:`.Revision` is a 'base' revision."""
 
         return self.down_revision is None
 
     @property
-    def _is_real_base(self):
+    def _is_real_base(self) -> bool:
         """Return True if this :class:`.Revision` is a "real" base revision,
         e.g. that it has no dependencies either."""
 
@@ -1469,7 +1608,7 @@ class Revision:
         return self.down_revision is None and self.dependencies is None
 
     @property
-    def is_branch_point(self):
+    def is_branch_point(self) -> bool:
         """Return True if this :class:`.Script` is a branch point.
 
         A branchpoint is defined as a :class:`.Script` which is referred
@@ -1481,7 +1620,7 @@ class Revision:
         return len(self.nextrev) > 1
 
     @property
-    def _is_real_branch_point(self):
+    def _is_real_branch_point(self) -> bool:
         """Return True if this :class:`.Script` is a 'real' branch point,
         taking into account dependencies as well.
 
@@ -1489,13 +1628,15 @@ class Revision:
         return len(self._all_nextrev) > 1
 
     @property
-    def is_merge_point(self):
+    def is_merge_point(self) -> bool:
         """Return True if this :class:`.Script` is a merge point."""
 
         return len(self._versioned_down_revisions) > 1
 
 
-def tuple_rev_as_scalar(rev):
+def tuple_rev_as_scalar(
+    rev: Optional[Sequence[str]],
+) -> Optional[Union[str, Sequence[str]]]:
     if not rev:
         return None
     elif len(rev) == 1:
index 8cd3dccc20f749e51c56a7a945cdde3aec00fb90..8f9e35efc8406ceb9f16ef3ec67427e55aa818a7 100644 (file)
@@ -1,6 +1,11 @@
 import shlex
 import subprocess
 import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Union
 
 from .. import util
 from ..util import compat
@@ -11,7 +16,7 @@ REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
 _registry = {}
 
 
-def register(name):
+def register(name: str) -> Callable:
     """A function decorator that will register that function as a write hook.
 
     See the documentation linked below for an example.
@@ -31,7 +36,9 @@ def register(name):
     return decorate
 
 
-def _invoke(name, revision, options):
+def _invoke(
+    name: str, revision: str, options: Dict[str, Union[str, int]]
+) -> Any:
     """Invokes the formatter registered for the given name.
 
     :param name: The name of a formatter in the registry
@@ -50,7 +57,7 @@ def _invoke(name, revision, options):
         return hook(revision, options)
 
 
-def _run_hooks(path, hook_config):
+def _run_hooks(path: str, hook_config: Dict[str, str]) -> None:
     """Invoke hooks for a generated revision."""
 
     from .base import _split_on_space_comma
@@ -83,7 +90,7 @@ def _run_hooks(path, hook_config):
             )
 
 
-def _parse_cmdline_options(cmdline_options_str, path):
+def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]:
     """Parse options from a string into a list.
 
     Also substitutes the revision script token with the actual filename of
index e22ac6b6727401738fc47bd4c70f15ba37d80dcf..ed532062d1b81c54669f6023b384375c9f090ba9 100644 (file)
@@ -1,8 +1,8 @@
-from __future__ import absolute_import
-
 import contextlib
 import re
 import sys
+from typing import Any
+from typing import Dict
 
 from sqlalchemy import exc as sa_exc
 from sqlalchemy import util
@@ -114,7 +114,7 @@ def eq_ignore_whitespace(a, b, msg=None):
     assert a == b, msg or "%r != %r" % (a, b)
 
 
-_dialect_mods = {}
+_dialect_mods: Dict[Any, Any] = {}
 
 
 def _get_dialect(name):
index cccc3822623209fbbf91bd84c09b15336bb4e433..c273665f0ae8ac09cd8683261c5f6498441598c1 100644 (file)
@@ -3,6 +3,8 @@ import configparser
 from contextlib import contextmanager
 import io
 import re
+from typing import Any
+from typing import Dict
 
 from sqlalchemy import Column
 from sqlalchemy import inspect
@@ -61,7 +63,7 @@ if sqla_14:
     from sqlalchemy.testing.fixtures import FutureEngineMixin
 else:
 
-    class FutureEngineMixin:
+    class FutureEngineMixin:  # type:ignore[no-redef]
         __requires__ = ("sqlalchemy_14",)
 
 
@@ -78,7 +80,7 @@ def capture_db(dialect="postgresql://"):
     return engine, buf
 
 
-_engs = {}
+_engs: Dict[Any, Any] = {}
 
 
 @contextmanager
index f1792f8973b85a8ea48c167585727175b0d8091b..37780ab0213c70cfa7e082686de041141d0e5874 100644 (file)
@@ -1,5 +1,3 @@
-import sys
-
 from sqlalchemy.testing.requirements import Requirements
 
 from alembic import util
@@ -85,12 +83,6 @@ class SuiteRequirements(Requirements):
             "SQLAlchemy 1.4 or greater required",
         )
 
-    @property
-    def python3(self):
-        return exclusions.skip_if(
-            lambda: sys.version_info < (3,), "Python version 3.xx is required."
-        )
-
     @property
     def comments(self):
         return exclusions.only_if(
index 44fc24f334f71de0f120e15b24e3832dce214511..ea1957aed3eb82fda3aae255022cb573fb95dd3a 100644 (file)
@@ -1,3 +1,6 @@
+from typing import Any
+from typing import Dict
+
 from sqlalchemy import CHAR
 from sqlalchemy import CheckConstraint
 from sqlalchemy import Column
@@ -211,7 +214,7 @@ class AutogenTest(_ComparesFKs):
     def _get_bind(cls):
         return config.db
 
-    configure_opts = {}
+    configure_opts: Dict[Any, Any] = {}
 
     @classmethod
     def setup_class(cls):
index 0fdd86dcd02021a4de1debbbf6388c23bb2ffc4b..a07813c12aa71b8407113fb8bc58bb28e65d23cf 100644 (file)
@@ -2,6 +2,12 @@ import collections
 import inspect
 import io
 import os
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Type
 
 is_posix = os.name == "posix"
 
@@ -10,11 +16,11 @@ ArgSpec = collections.namedtuple(
 )
 
 
-def inspect_getargspec(func):
+def inspect_getargspec(func: Callable) -> ArgSpec:
     """getargspec based on fully vendored getfullargspec from Python 3.3."""
 
     if inspect.ismethod(func):
-        func = func.__func__
+        func = func.__func__  # type: ignore
     if not inspect.isfunction(func):
         raise TypeError("{!r} is not a Python function".format(func))
 
@@ -36,7 +42,7 @@ def inspect_getargspec(func):
     if co.co_flags & inspect.CO_VARKEYWORDS:
         varkw = co.co_varnames[nargs]
 
-    return ArgSpec(args, varargs, varkw, func.__defaults__)
+    return ArgSpec(args, varargs, varkw, func.__defaults__)  # type: ignore
 
 
 string_types = (str,)
@@ -57,20 +63,20 @@ def _formatannotation(annotation, base_module=None):
 
 
 def inspect_formatargspec(
-    args,
-    varargs=None,
-    varkw=None,
-    defaults=None,
-    kwonlyargs=(),
-    kwonlydefaults={},
-    annotations={},
-    formatarg=str,
-    formatvarargs=lambda name: "*" + name,
-    formatvarkw=lambda name: "**" + name,
-    formatvalue=lambda value: "=" + repr(value),
-    formatreturns=lambda text: " -> " + text,
-    formatannotation=_formatannotation,
-):
+    args: List[str],
+    varargs: Optional[str] = None,
+    varkw: Optional[str] = None,
+    defaults: Optional[Any] = None,
+    kwonlyargs: tuple = (),
+    kwonlydefaults: Dict[Any, Any] = {},
+    annotations: Dict[Any, Any] = {},
+    formatarg: Type[str] = str,
+    formatvarargs: Callable = lambda name: "*" + name,
+    formatvarkw: Callable = lambda name: "**" + name,
+    formatvalue: Callable = lambda value: "=" + repr(value),
+    formatreturns: Callable = lambda text: " -> " + text,
+    formatannotation: Callable = _formatannotation,
+) -> str:
     """Copy formatargspec from python 3.7 standard library.
 
     Python 3 has deprecated formatargspec and requested that Signature
@@ -118,5 +124,5 @@ def inspect_formatargspec(
 # into a given buffer, but doesn't close it.
 # not sure of a more idiomatic approach to this.
 class EncodedIO(io.TextIOWrapper):
-    def close(self):
+    def close(self) -> None:
         pass
index c27f0f36b2eaaa83352fc824050b6e1ccf06f607..ba376c0793068302e8ff1572b7637ddbcde5372f 100644 (file)
@@ -3,12 +3,18 @@ from os.path import exists
 from os.path import join
 from os.path import splitext
 from subprocess import check_call
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
 
 from .compat import is_posix
 from .exc import CommandError
 
 
-def open_in_editor(filename, environ=None):
+def open_in_editor(
+    filename: str, environ: Optional[Dict[str, str]] = None
+) -> None:
     """
     Opens the given file in a text editor. If the environment variable
     ``EDITOR`` is set, this is taken as preference.
@@ -22,15 +28,15 @@ def open_in_editor(filename, environ=None):
     :param environ: An optional drop-in replacement for ``os.environ``. Used
         mainly for testing.
     """
-
+    env = os.environ if environ is None else environ
     try:
-        editor = _find_editor(environ)
+        editor = _find_editor(env)
         check_call([editor, filename])
     except Exception as exc:
         raise CommandError("Error executing editor (%s)" % (exc,)) from exc
 
 
-def _find_editor(environ=None):
+def _find_editor(environ: Mapping[str, str]) -> str:
     candidates = _default_editors()
     for i, var in enumerate(("EDITOR", "VISUAL")):
         if var in environ:
@@ -50,7 +56,9 @@ def _find_editor(environ=None):
     )
 
 
-def _find_executable(candidate, environ):
+def _find_executable(
+    candidate: str, environ: Mapping[str, str]
+) -> Optional[str]:
     # Assuming this is on the PATH, we need to determine its absolute
     # location. Otherwise, ``check_call`` will fail
     if not is_posix and splitext(candidate)[1] != ".exe":
@@ -62,7 +70,7 @@ def _find_executable(candidate, environ):
     return None
 
 
-def _default_editors():
+def _default_editors() -> List[str]:
     # Look for an editor. Prefer the user's choice by env-var, fall back to
     # most commonly installed editor (nano/vim)
     if is_posix:
index dbd1f217a0dee55ebe4c98126eb37a02aada2402..87a9aca35fcedd45bb504f459ba7d49744f505c1 100644 (file)
@@ -1,6 +1,16 @@
 import collections
 from collections.abc import Iterable
 import textwrap
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
 import uuid
 import warnings
 
@@ -14,10 +24,13 @@ from .compat import inspect_getargspec
 from .compat import string_types
 
 
+_T = TypeVar("_T")
+
+
 class _ModuleClsMeta(type):
-    def __setattr__(cls, key, value):
+    def __setattr__(cls, key: str, value: Callable) -> None:
         super(_ModuleClsMeta, cls).__setattr__(key, value)
-        cls._update_module_proxies(key)
+        cls._update_module_proxies(key)  # type: ignore
 
 
 class ModuleClsProxy(metaclass=_ModuleClsMeta):
@@ -29,22 +42,24 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta):
 
     """
 
-    _setups = collections.defaultdict(lambda: (set(), []))
+    _setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
+        lambda: (set(), [])
+    )
 
     @classmethod
-    def _update_module_proxies(cls, name):
+    def _update_module_proxies(cls, name: str) -> None:
         attr_names, modules = cls._setups[cls]
         for globals_, locals_ in modules:
             cls._add_proxied_attribute(name, globals_, locals_, attr_names)
 
-    def _install_proxy(self):
+    def _install_proxy(self) -> None:
         attr_names, modules = self._setups[self.__class__]
         for globals_, locals_ in modules:
             globals_["_proxy"] = self
             for attr_name in attr_names:
                 globals_[attr_name] = getattr(self, attr_name)
 
-    def _remove_proxy(self):
+    def _remove_proxy(self) -> None:
         attr_names, modules = self._setups[self.__class__]
         for globals_, locals_ in modules:
             globals_["_proxy"] = None
@@ -171,10 +186,25 @@ def _with_legacy_names(translations):
     return decorate
 
 
-def rev_id():
+def rev_id() -> str:
     return uuid.uuid4().hex[-12:]
 
 
+@overload
+def to_tuple(x: Any, default: tuple) -> tuple:
+    ...
+
+
+@overload
+def to_tuple(x: None, default: _T = None) -> _T:
+    ...
+
+
+@overload
+def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
+    ...
+
+
 def to_tuple(x, default=None):
     if x is None:
         return default
@@ -186,16 +216,18 @@ def to_tuple(x, default=None):
         return (x,)
 
 
-def dedupe_tuple(tup):
+def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
     return tuple(unique_list(tup))
 
 
 class Dispatcher:
-    def __init__(self, uselist=False):
-        self._registry = {}
+    def __init__(self, uselist: bool = False) -> None:
+        self._registry: Dict[tuple, Any] = {}
         self.uselist = uselist
 
-    def dispatch_for(self, target, qualifier="default"):
+    def dispatch_for(
+        self, target: Any, qualifier: str = "default"
+    ) -> Callable:
         def decorate(fn):
             if self.uselist:
                 self._registry.setdefault((target, qualifier), []).append(fn)
@@ -206,10 +238,10 @@ class Dispatcher:
 
         return decorate
 
-    def dispatch(self, obj, qualifier="default"):
+    def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
 
         if isinstance(obj, string_types):
-            targets = [obj]
+            targets: Sequence = [obj]
         elif isinstance(obj, type):
             targets = obj.__mro__
         else:
@@ -223,7 +255,9 @@ class Dispatcher:
         else:
             raise ValueError("no dispatch function for object: %s" % obj)
 
-    def _fn_or_list(self, fn_or_list):
+    def _fn_or_list(
+        self, fn_or_list: Union[List[Callable], Callable]
+    ) -> Callable:
         if self.uselist:
 
             def go(*arg, **kw):
@@ -232,9 +266,9 @@ class Dispatcher:
 
             return go
         else:
-            return fn_or_list
+            return fn_or_list  # type: ignore
 
-    def branch(self):
+    def branch(self) -> "Dispatcher":
         """Return a copy of this dispatcher that is independently
         writable."""
 
index 70c9128288399323534b21954ba4837d1965278d..062890a32ee92c9464870b5d40c57eddbf278f36 100644 (file)
@@ -2,6 +2,11 @@ from collections.abc import Iterable
 import logging
 import sys
 import textwrap
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TextIO
+from typing import Union
 import warnings
 
 from sqlalchemy.engine import url
@@ -29,7 +34,7 @@ except (ImportError, IOError):
     TERMWIDTH = None
 
 
-def write_outstream(stream, *text):
+def write_outstream(stream: TextIO, *text) -> None:
     encoding = getattr(stream, "encoding", "ascii") or "ascii"
     for t in text:
         if not isinstance(t, binary_type):
@@ -44,7 +49,7 @@ def write_outstream(stream, *text):
             break
 
 
-def status(_statmsg, fn, *arg, **kw):
+def status(_statmsg: str, fn: Callable, *arg, **kw) -> Any:
     newline = kw.pop("newline", False)
     msg(_statmsg + " ...", newline, True)
     try:
@@ -56,27 +61,27 @@ def status(_statmsg, fn, *arg, **kw):
         raise
 
 
-def err(message):
+def err(message: str):
     log.error(message)
     msg("FAILED: %s" % message)
     sys.exit(-1)
 
 
-def obfuscate_url_pw(u):
-    u = url.make_url(u)
+def obfuscate_url_pw(input_url: str) -> str:
+    u = url.make_url(input_url)
     if u.password:
         if sqla_compat.sqla_14:
             u = u.set(password="XXXXX")
         else:
-            u.password = "XXXXX"
+            u.password = "XXXXX"  # type: ignore[misc]
     return str(u)
 
 
-def warn(msg, stacklevel=2):
+def warn(msg: str, stacklevel: int = 2) -> None:
     warnings.warn(msg, UserWarning, stacklevel=stacklevel)
 
 
-def msg(msg, newline=True, flush=False):
+def msg(msg: str, newline: bool = True, flush: bool = False) -> None:
     if TERMWIDTH is None:
         write_outstream(sys.stdout, msg)
         if newline:
@@ -92,7 +97,7 @@ def msg(msg, newline=True, flush=False):
         sys.stdout.flush()
 
 
-def format_as_comma(value):
+def format_as_comma(value: Optional[Union[str, "Iterable[str]"]]) -> str:
     if value is None:
         return ""
     elif isinstance(value, string_types):
index 53cc3cce44c21b7bb2d9aa719c330d05d7e65339..7eb582eff8533a4b61d35b40e387ba3934cbede2 100644 (file)
@@ -4,6 +4,7 @@ import importlib.util
 import os
 import re
 import tempfile
+from typing import Optional
 
 from mako import exceptions
 from mako.template import Template
@@ -11,7 +12,9 @@ from mako.template import Template
 from .exc import CommandError
 
 
-def template_to_file(template_file, dest, output_encoding, **kw):
+def template_to_file(
+    template_file: str, dest: str, output_encoding: str, **kw
+) -> None:
     template = Template(filename=template_file)
     try:
         output = template.render_unicode(**kw).encode(output_encoding)
@@ -32,7 +35,7 @@ def template_to_file(template_file, dest, output_encoding, **kw):
             f.write(output)
 
 
-def coerce_resource_to_filename(fname):
+def coerce_resource_to_filename(fname: str) -> str:
     """Interpret a filename as either a filesystem location or as a package
     resource.
 
@@ -47,7 +50,7 @@ def coerce_resource_to_filename(fname):
     return fname
 
 
-def pyc_file_from_path(path):
+def pyc_file_from_path(path: str) -> Optional[str]:
     """Given a python source path, locate the .pyc."""
 
     candidate = importlib.util.cache_from_source(path)
@@ -64,7 +67,7 @@ def pyc_file_from_path(path):
         return None
 
 
-def load_python_file(dir_, filename):
+def load_python_file(dir_: str, filename: str):
     """Load a file from the given path as a Python module."""
 
     module_id = re.sub(r"\W", "_", filename)
@@ -78,21 +81,15 @@ def load_python_file(dir_, filename):
             if pyc_path is None:
                 raise ImportError("Can't find Python file %s" % path)
             else:
-                module = load_module_pyc(module_id, pyc_path)
+                module = load_module_py(module_id, pyc_path)
     elif ext in (".pyc", ".pyo"):
-        module = load_module_pyc(module_id, path)
+        module = load_module_py(module_id, path)
     return module
 
 
-def load_module_py(module_id, path):
+def load_module_py(module_id: str, path: str):
     spec = importlib.util.spec_from_file_location(module_id, path)
+    assert spec
     module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
-    return module
-
-
-def load_module_pyc(module_id, path):
-    spec = importlib.util.spec_from_file_location(module_id, path)
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
+    spec.loader.exec_module(module)  # type: ignore
     return module
index a04ab2e9ce0625c27adc49e15194f7925152912c..e1ccd415e2d0922bdbe1e9f5126e389f0f51d6a9 100644 (file)
@@ -1,5 +1,11 @@
 import contextlib
 import re
+from typing import Iterator
+from typing import Mapping
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
 
 from sqlalchemy import __version__
 from sqlalchemy import inspect
@@ -12,15 +18,34 @@ from sqlalchemy.schema import CheckConstraint
 from sqlalchemy.schema import Column
 from sqlalchemy.schema import ForeignKeyConstraint
 from sqlalchemy.sql import visitors
+from sqlalchemy.sql.elements import BindParameter
 from sqlalchemy.sql.elements import quoted_name
-from sqlalchemy.sql.expression import _BindParamClause
-from sqlalchemy.sql.expression import _TextClause as TextClause
+from sqlalchemy.sql.elements import TextClause
 from sqlalchemy.sql.visitors import traverse
 
 from . import compat
 
-
-def _safe_int(value):
+if TYPE_CHECKING:
+    from sqlalchemy import Index
+    from sqlalchemy import Table
+    from sqlalchemy.engine import Connection
+    from sqlalchemy.engine import Dialect
+    from sqlalchemy.engine import Transaction
+    from sqlalchemy.engine.reflection import Inspector
+    from sqlalchemy.sql.base import ColumnCollection
+    from sqlalchemy.sql.compiler import SQLCompiler
+    from sqlalchemy.sql.dml import Insert
+    from sqlalchemy.sql.elements import ColumnClause
+    from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.schema import Constraint
+    from sqlalchemy.sql.schema import SchemaItem
+    from sqlalchemy.sql.selectable import Select
+    from sqlalchemy.sql.selectable import TableClause
+
+_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
+
+
+def _safe_int(value: str) -> Union[int, str]:
     try:
         return int(value)
     except:
@@ -36,6 +61,7 @@ sqla_14 = _vers >= (1, 4)
 try:
     from sqlalchemy import Computed  # noqa
 except ImportError:
+    Computed = None  # type: ignore
     has_computed = False
     has_computed_reflection = False
 else:
@@ -45,6 +71,7 @@ else:
 try:
     from sqlalchemy import Identity  # noqa
 except ImportError:
+    Identity = None  # type: ignore
     has_identity = False
 else:
     # attributes common to Identity and Sequence
@@ -67,21 +94,26 @@ AUTOINCREMENT_DEFAULT = "auto"
 
 
 @contextlib.contextmanager
-def _ensure_scope_for_ddl(connection):
+def _ensure_scope_for_ddl(
+    connection: Optional["Connection"],
+) -> Iterator[None]:
     try:
-        in_transaction = connection.in_transaction
+        in_transaction = connection.in_transaction  # type: ignore[union-attr]
     except AttributeError:
-        # catch for MockConnection
+        # catch for MockConnection, None
         yield
     else:
         if not in_transaction():
+            assert connection is not None
             with connection.begin():
                 yield
         else:
             yield
 
 
-def _safe_begin_connection_transaction(connection):
+def _safe_begin_connection_transaction(
+    connection: "Connection",
+) -> "Transaction":
     transaction = _get_connection_transaction(connection)
     if transaction:
         return transaction
@@ -89,9 +121,9 @@ def _safe_begin_connection_transaction(connection):
         return connection.begin()
 
 
-def _get_connection_in_transaction(connection):
+def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool:
     try:
-        in_transaction = connection.in_transaction
+        in_transaction = connection.in_transaction  # type: ignore
     except AttributeError:
         # catch for MockConnection
         return False
@@ -99,28 +131,33 @@ def _get_connection_in_transaction(connection):
         return in_transaction()
 
 
-def _copy(schema_item, **kw):
+def _copy(schema_item: _CE, **kw) -> _CE:
     if hasattr(schema_item, "_copy"):
         return schema_item._copy(**kw)
     else:
         return schema_item.copy(**kw)
 
 
-def _get_connection_transaction(connection):
+def _get_connection_transaction(
+    connection: "Connection",
+) -> Optional["Transaction"]:
     if sqla_14:
         return connection.get_transaction()
     else:
-        return connection._root._Connection__transaction
+        r = connection._root  # type: ignore[attr-defined]
+        return r._Connection__transaction
 
 
-def _create_url(*arg, **kw):
+def _create_url(*arg, **kw) -> url.URL:
     if hasattr(url.URL, "create"):
         return url.URL.create(*arg, **kw)
     else:
         return url.URL(*arg, **kw)
 
 
-def _connectable_has_table(connectable, tablename, schemaname):
+def _connectable_has_table(
+    connectable: "Connection", tablename: str, schemaname: Union[str, None]
+) -> bool:
     if sqla_14:
         return inspect(connectable).has_table(tablename, schemaname)
     else:
@@ -148,23 +185,25 @@ def _nullability_might_be_unset(metadata_column):
         )
 
 
-def _server_default_is_computed(*server_default):
+def _server_default_is_computed(*server_default) -> bool:
     if not has_computed:
         return False
     else:
         return any(isinstance(sd, Computed) for sd in server_default)
 
 
-def _server_default_is_identity(*server_default):
+def _server_default_is_identity(*server_default) -> bool:
     if not sqla_14:
         return False
     else:
         return any(isinstance(sd, Identity) for sd in server_default)
 
 
-def _table_for_constraint(constraint):
+def _table_for_constraint(constraint: "Constraint") -> "Table":
     if isinstance(constraint, ForeignKeyConstraint):
-        return constraint.parent
+        table = constraint.parent
+        assert table is not None
+        return table
     else:
         return constraint.table
 
@@ -178,7 +217,9 @@ def _columns_for_constraint(constraint):
         return list(constraint.columns)
 
 
-def _reflect_table(inspector, table, include_cols):
+def _reflect_table(
+    inspector: "Inspector", table: "Table", include_cols: None
+) -> None:
     if sqla_14:
         return inspector.reflect_table(table, None)
     else:
@@ -213,19 +254,20 @@ def _fk_spec(constraint):
     )
 
 
-def _fk_is_self_referential(constraint):
-    spec = constraint.elements[0]._get_colspec()
+def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool:
+    spec = constraint.elements[0]._get_colspec()  # type: ignore[attr-defined]
     tokens = spec.split(".")
     tokens.pop(-1)  # colname
     tablekey = ".".join(tokens)
+    assert constraint.parent is not None
     return tablekey == constraint.parent.key
 
 
-def _is_type_bound(constraint):
+def _is_type_bound(constraint: "Constraint") -> bool:
     # this deals with SQLAlchemy #3260, don't copy CHECK constraints
     # that will be generated by the type.
     # new feature added for #3260
-    return constraint._type_bound
+    return constraint._type_bound  # type: ignore[attr-defined]
 
 
 def _find_columns(clause):
@@ -236,16 +278,21 @@ def _find_columns(clause):
     return cols
 
 
-def _remove_column_from_collection(collection, column):
+def _remove_column_from_collection(
+    collection: "ColumnCollection", column: Union["Column", "ColumnClause"]
+) -> None:
     """remove a column from a ColumnCollection."""
 
     # workaround for older SQLAlchemy, remove the
     # same object that's present
+    assert column.key is not None
     to_remove = collection[column.key]
     collection.remove(to_remove)
 
 
-def _textual_index_column(table, text_):
+def _textual_index_column(
+    table: "Table", text_: Union[str, "TextClause", "ColumnElement"]
+) -> Union["ColumnElement", "Column"]:
     """a workaround for the Index construct's severe lack of flexibility"""
     if isinstance(text_, compat.string_types):
         c = Column(text_, sqltypes.NULLTYPE)
@@ -259,7 +306,7 @@ def _textual_index_column(table, text_):
         raise ValueError("String or text() construct expected")
 
 
-def _copy_expression(expression, target_table):
+def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
     def replace(col):
         if (
             isinstance(col, Column)
@@ -296,7 +343,7 @@ class _textual_index_element(sql.ColumnElement):
 
     __visit_name__ = "_textual_idx_element"
 
-    def __init__(self, table, text):
+    def __init__(self, table: "Table", text: "TextClause") -> None:
         self.table = table
         self.text = text
         self.key = text.text
@@ -308,16 +355,20 @@ class _textual_index_element(sql.ColumnElement):
 
 
 @compiles(_textual_index_element)
-def _render_textual_index_column(element, compiler, **kw):
+def _render_textual_index_column(
+    element: _textual_index_element, compiler: "SQLCompiler", **kw
+) -> str:
     return compiler.process(element.text, **kw)
 
 
-class _literal_bindparam(_BindParamClause):
+class _literal_bindparam(BindParameter):
     pass
 
 
 @compiles(_literal_bindparam)
-def _render_literal_bindparam(element, compiler, **kw):
+def _render_literal_bindparam(
+    element: _literal_bindparam, compiler: "SQLCompiler", **kw
+) -> str:
     return compiler.render_literal_bindparam(element, **kw)
 
 
@@ -329,17 +380,20 @@ def _get_index_column_names(idx):
     return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
 
 
-def _column_kwargs(col):
+def _column_kwargs(col: "Column") -> Mapping:
     if sqla_13:
         return col.kwargs
     else:
         return {}
 
 
-def _get_constraint_final_name(constraint, dialect):
+def _get_constraint_final_name(
+    constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"]
+) -> Optional[str]:
     if constraint.name is None:
         return None
-    elif sqla_14:
+    assert dialect is not None
+    if sqla_14:
         # for SQLAlchemy 1.4 we would like to have the option to expand
         # the use of "deferred" names for constraints as well as to have
         # some flexibility with "None" name and similar; make use of new
@@ -355,7 +409,7 @@ def _get_constraint_final_name(constraint, dialect):
         if hasattr(constraint.name, "quote"):
             # might be quoted_name, might be truncated_name, keep it the
             # same
-            quoted_name_cls = type(constraint.name)
+            quoted_name_cls: type = type(constraint.name)
         else:
             quoted_name_cls = quoted_name
 
@@ -364,7 +418,8 @@ def _get_constraint_final_name(constraint, dialect):
 
         if isinstance(constraint, schema.Index):
             # name should not be quoted.
-            return dialect.ddl_compiler(dialect, None)._prepared_index_name(
+            d = dialect.ddl_compiler(dialect, None)
+            return d._prepared_index_name(  # type: ignore[attr-defined]
                 constraint
             )
         else:
@@ -372,10 +427,13 @@ def _get_constraint_final_name(constraint, dialect):
             return dialect.identifier_preparer.format_constraint(constraint)
 
 
-def _constraint_is_named(constraint, dialect):
+def _constraint_is_named(
+    constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"]
+) -> bool:
     if sqla_14:
         if constraint.name is None:
             return False
+        assert dialect is not None
         name = dialect.identifier_preparer.format_constraint(
             constraint, _alembic_quote=False
         )
@@ -384,18 +442,21 @@ def _constraint_is_named(constraint, dialect):
         return constraint.name is not None
 
 
-def _is_mariadb(mysql_dialect):
+def _is_mariadb(mysql_dialect: "Dialect") -> bool:
     if sqla_14:
-        return mysql_dialect.is_mariadb
+        return mysql_dialect.is_mariadb  # type: ignore[attr-defined]
     else:
-        return mysql_dialect.server_version_info and mysql_dialect._is_mariadb
+        return bool(
+            mysql_dialect.server_version_info
+            and mysql_dialect._is_mariadb  # type: ignore[attr-defined]
+        )
 
 
 def _mariadb_normalized_version_info(mysql_dialect):
     return mysql_dialect._mariadb_normalized_version_info
 
 
-def _insert_inline(table):
+def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
     if sqla_14:
         return table.insert().inline()
     else:
@@ -408,10 +469,10 @@ if sqla_14:
 else:
     from sqlalchemy import create_engine
 
-    def create_mock_engine(url, executor):
+    def create_mock_engine(url, executor, **kw):  # type: ignore[misc]
         return create_engine(
             "postgresql://", strategy="mock", executor=executor
         )
 
-    def _select(*columns):
-        return sql.select(list(columns))
+    def _select(*columns, **kw) -> "Select":
+        return sql.select(list(columns), **kw)
diff --git a/docs/build/unreleased/py3_typing.rst b/docs/build/unreleased/py3_typing.rst
new file mode 100644 (file)
index 0000000..7f8aa6c
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: feature, general
+
+    pep-484 type annotations have been added throughout the library. This
+    should be helpful in providing Mypy and IDE support, however there is not
+    full support for Alembic's dynamically modified "op" namespace as of yet; a
+    future release will likely modify the approach used for importing this
+    namespace to be better compatible with pep-484 capabilities.
\ No newline at end of file
index 7514d8bf2d735270de093937a7e8c67b15b6b174..025a93ae8ea3f45f9e856cceac47dc1286c61265 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -91,6 +91,7 @@ import-order-style = google
 application-import-names = alembic,tests
 per-file-ignores =
                 **/__init__.py:F401
+max-line-length = 79
 
 [sqla_testing]
 requirement_cls=tests.requirements:DefaultRequirements
@@ -115,4 +116,12 @@ oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
 addopts= --tb native -v -r sfxX -p no:warnings -p no:logging --maxfail=25
 python_files=tests/test_*.py
 
+[mypy]
+show_error_codes = True
+allow_redefinition = True
 
+[mypy-mako.*]
+ignore_missing_imports = True
+
+[mypy-sqlalchemy.testing.*]
+ignore_missing_imports = True
index c2c1410fac6bb575f4c46789c91d1af73fdce9f1..61998bf8fd684a94cafd5c64933dbead3110ae14 100644 (file)
@@ -9,14 +9,13 @@ from alembic.script.revision import Revision
 from alembic.script.revision import RevisionError
 from alembic.script.revision import RevisionMap
 from alembic.testing import assert_raises_message
-from alembic.testing import config
 from alembic.testing import eq_
+from alembic.testing import expect_raises_message
 from alembic.testing.fixtures import TestBase
 from . import _large_map
 
 
 class APITest(TestBase):
-    @config.requirements.python3
     def test_invalid_datatype(self):
         map_ = RevisionMap(
             lambda: [
@@ -25,29 +24,26 @@ class APITest(TestBase):
                 Revision("c", ("b",)),
             ]
         )
-        assert_raises_message(
+        with expect_raises_message(
             RevisionError,
             "revision identifier b'12345' is not a string; "
             "ensure database driver settings are correct",
-            map_.get_revisions,
-            b"12345",
-        )
+        ):
+            map_.get_revisions(b"12345")
 
-        assert_raises_message(
+        with expect_raises_message(
             RevisionError,
             "revision identifier b'12345' is not a string; "
             "ensure database driver settings are correct",
-            map_.get_revision,
-            b"12345",
-        )
+        ):
+            map_.get_revision(b"12345")
 
-        assert_raises_message(
+        with expect_raises_message(
             RevisionError,
             r"revision identifier \(b'12345',\) is not a string; "
             "ensure database driver settings are correct",
-            map_.get_revision,
-            (b"12345",),
-        )
+        ):
+            map_.get_revision((b"12345",))
 
         map_.get_revision(("a",))
         map_.get_revision("a")
@@ -310,12 +306,12 @@ class LabeledBranchTest(DownIterateTest):
         c1 = map_.get_revision("c1")
         c2 = map_.get_revision("c2")
         d = map_.get_revision("d")
-        eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), [c1, c2, d])
+        eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), (c1, c2, d))
 
     def test_filter_for_lineage_heads(self):
         eq_(
             self.map.filter_for_lineage([self.map.get_revision("f")], "heads"),
-            [self.map.get_revision("f")],
+            (self.map.get_revision("f"),),
         )
 
     def setUp(self):
@@ -333,13 +329,13 @@ class LabeledBranchTest(DownIterateTest):
         )
 
     def test_get_base_revisions_labeled(self):
-        eq_(self.map._get_base_revisions("somelongername@base"), ["a"])
+        eq_(self.map._get_base_revisions("somelongername@base"), ("a",))
 
     def test_get_current_named_rev(self):
         eq_(self.map.get_revision("ebranch@head"), self.map.get_revision("f"))
 
     def test_get_base_revisions(self):
-        eq_(self.map._get_base_revisions("base"), ["a", "d"])
+        eq_(self.map._get_base_revisions("base"), ("a", "d"))
 
     def test_iterate_head_to_named_base(self):
         self._assert_iteration(
diff --git a/tox.ini b/tox.ini
index f5456d8ea1660710f052ca74fb335a6d9838a43a..20e55abbb6664c1fe7824be07671f21409b39a36 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -55,6 +55,19 @@ commands=
   {oracle,mssql}: python reap_dbs.py db_idents.txt
 
 
+[testenv:mypy]
+basepython = python3
+deps=
+    mypy
+    sqlalchemy>=1.4.0
+    sqlalchemy2-stubs
+    mako
+    types-pkg-resources
+    types-python-dateutil
+    # pytest is imported in alembic/testing and mypy complains if it's not installed.
+    pytest
+commands = mypy ./alembic/ --exclude alembic/templates
+
 [testenv:pep8]
 basepython = python3
 deps=