]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
run pyupgrade
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 25 Nov 2022 17:29:40 +0000 (12:29 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Nov 2022 14:18:00 +0000 (09:18 -0500)
command is:

find alembic -name "*.py" | xargs pyupgrade --py37-plus --keep-runtime-typing --keep-percent-format

I'm having some weird fighting with the tools/write_pyi, where
in different runtime contexts it keeps losing "MigrationContext"
and also Callable drops the args, but it's not consistent.
For whatever reason, under py311 things *do* work every time.
I'm working with clean tox environments so not really sure what the
change is.  anyway, let's at least fix the quoting up
around the types.

This is towards getting the "*" in the op signatures for #1130.

Change-Id: I9175905d3b4325e03a97d6752356b70be20e9fad

41 files changed:
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/autogenerate/rewriter.py
alembic/command.py
alembic/config.py
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/mysql.py
alembic/ddl/oracle.py
alembic/ddl/postgresql.py
alembic/ddl/sqlite.py
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/script/base.py
alembic/script/revision.py
alembic/testing/env.py
alembic/testing/fixtures.py
alembic/testing/suite/_autogen_fixtures.py
alembic/testing/warnings.py
alembic/util/langhelpers.py
alembic/util/messaging.py
alembic/util/sqla_compat.py
tests/test_autogen_composition.py
tests/test_autogen_diffs.py
tests/test_autogen_indexes.py
tests/test_autogen_render.py
tests/test_batch.py
tests/test_command.py
tests/test_config.py
tests/test_environment.py
tests/test_external_dialect.py
tests/test_postgresql.py
tests/test_script_consumption.py
tests/test_script_production.py
tests/test_version_traversal.py

index cbd64e18c387b07ff4ef73a54bc685a88fcd645c..d7a0913d69f82a7511f371088a3976932e9d505c 100644 (file)
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
     from alembic.script.base import ScriptDirectory
 
 
-def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any:
+def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
     """Compare a database schema to that given in a
     :class:`~sqlalchemy.schema.MetaData` instance.
 
@@ -136,8 +136,8 @@ def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any:
 
 
 def produce_migrations(
-    context: "MigrationContext", metadata: "MetaData"
-) -> "MigrationScript":
+    context: MigrationContext, metadata: MetaData
+) -> MigrationScript:
     """Produce a :class:`.MigrationScript` structure based on schema
     comparison.
 
@@ -167,13 +167,13 @@ def produce_migrations(
 
 
 def render_python_code(
-    up_or_down_op: "UpgradeOps",
+    up_or_down_op: UpgradeOps,
     sqlalchemy_module_prefix: str = "sa.",
     alembic_module_prefix: str = "op.",
     render_as_batch: bool = False,
     imports: Tuple[str, ...] = (),
     render_item: None = None,
-    migration_context: Optional["MigrationContext"] = None,
+    migration_context: Optional[MigrationContext] = None,
 ) -> str:
     """Render Python code given an :class:`.UpgradeOps` or
     :class:`.DowngradeOps` object.
@@ -205,7 +205,7 @@ def render_python_code(
 
 
 def _render_migration_diffs(
-    context: "MigrationContext", template_args: Dict[Any, Any]
+    context: MigrationContext, template_args: Dict[Any, Any]
 ) -> None:
     """legacy, used by test_autogen_composition at the moment"""
 
@@ -229,7 +229,7 @@ class AutogenContext:
     """Maintains configuration and state that's specific to an
     autogenerate operation."""
 
-    metadata: Optional["MetaData"] = None
+    metadata: Optional[MetaData] = None
     """The :class:`~sqlalchemy.schema.MetaData` object
     representing the destination.
 
@@ -247,7 +247,7 @@ class AutogenContext:
 
     """
 
-    connection: Optional["Connection"] = None
+    connection: Optional[Connection] = None
     """The :class:`~sqlalchemy.engine.base.Connection` object currently
     connected to the database backend being compared.
 
@@ -256,7 +256,7 @@ class AutogenContext:
 
     """
 
-    dialect: Optional["Dialect"] = None
+    dialect: Optional[Dialect] = None
     """The :class:`~sqlalchemy.engine.Dialect` object currently in use.
 
     This is normally obtained from the
@@ -278,13 +278,13 @@ class AutogenContext:
 
     """
 
-    migration_context: "MigrationContext" = None  # type: ignore[assignment]
+    migration_context: MigrationContext = None  # type: ignore[assignment]
     """The :class:`.MigrationContext` established by the ``env.py`` script."""
 
     def __init__(
         self,
-        migration_context: "MigrationContext",
-        metadata: Optional["MetaData"] = None,
+        migration_context: MigrationContext,
+        metadata: Optional[MetaData] = None,
         opts: Optional[dict] = None,
         autogenerate: bool = True,
     ) -> None:
@@ -342,7 +342,7 @@ class AutogenContext:
         self._has_batch: bool = False
 
     @util.memoized_property
-    def inspector(self) -> "Inspector":
+    def inspector(self) -> Inspector:
         if self.connection is None:
             raise TypeError(
                 "can't return inspector as this "
@@ -397,18 +397,16 @@ class AutogenContext:
     def run_object_filters(
         self,
         object_: Union[
-            "Table",
-            "Index",
-            "Column",
-            "UniqueConstraint",
-            "ForeignKeyConstraint",
+            Table,
+            Index,
+            Column,
+            UniqueConstraint,
+            ForeignKeyConstraint,
         ],
         name: Optional[str],
         type_: str,
         reflected: bool,
-        compare_to: Optional[
-            Union["Table", "Index", "Column", "UniqueConstraint"]
-        ],
+        compare_to: Optional[Union[Table, Index, Column, UniqueConstraint]],
     ) -> bool:
         """Run the context's object filters and return True if the targets
         should be part of the autogenerate operation.
@@ -476,8 +474,8 @@ class RevisionContext:
 
     def __init__(
         self,
-        config: "Config",
-        script_directory: "ScriptDirectory",
+        config: Config,
+        script_directory: ScriptDirectory,
         command_args: Dict[str, Any],
         process_revision_directives: Optional[Callable] = None,
     ) -> None:
@@ -492,8 +490,8 @@ class RevisionContext:
         self.generated_revisions = [self._default_revision()]
 
     def _to_script(
-        self, migration_script: "MigrationScript"
-    ) -> Optional["Script"]:
+        self, migration_script: MigrationScript
+    ) -> Optional[Script]:
         template_args: Dict[str, Any] = self.template_args.copy()
 
         if getattr(migration_script, "_needs_render", False):
@@ -522,19 +520,19 @@ class RevisionContext:
         )
 
     def run_autogenerate(
-        self, rev: tuple, migration_context: "MigrationContext"
+        self, rev: tuple, migration_context: MigrationContext
     ) -> None:
         self._run_environment(rev, migration_context, True)
 
     def run_no_autogenerate(
-        self, rev: tuple, migration_context: "MigrationContext"
+        self, rev: tuple, migration_context: MigrationContext
     ) -> None:
         self._run_environment(rev, migration_context, False)
 
     def _run_environment(
         self,
         rev: tuple,
-        migration_context: "MigrationContext",
+        migration_context: MigrationContext,
         autogenerate: bool,
     ) -> None:
         if autogenerate:
@@ -587,7 +585,7 @@ class RevisionContext:
         for migration_script in self.generated_revisions:
             migration_script._needs_render = True
 
-    def _default_revision(self) -> "MigrationScript":
+    def _default_revision(self) -> MigrationScript:
         command_args: Dict[str, Any] = self.command_args
         op = ops.MigrationScript(
             rev_id=command_args["rev_id"] or util.rev_id(),
@@ -602,6 +600,6 @@ class RevisionContext:
         )
         return op
 
-    def generate_scripts(self) -> Iterator[Optional["Script"]]:
+    def generate_scripts(self) -> Iterator[Optional[Script]]:
         for generated_revision in self.generated_revisions:
             yield self._to_script(generated_revision)
index c32ab4d9bb9f946af31ab32035888bb1c03922fe..c9971ea321fe09eb9f1b1271e43eb203cfb6f451 100644 (file)
@@ -47,7 +47,7 @@ log = logging.getLogger(__name__)
 
 
 def _populate_migration_script(
-    autogen_context: "AutogenContext", migration_script: "MigrationScript"
+    autogen_context: AutogenContext, migration_script: MigrationScript
 ) -> None:
     upgrade_ops = migration_script.upgrade_ops_list[-1]
     downgrade_ops = migration_script.downgrade_ops_list[-1]
@@ -60,14 +60,14 @@ comparators = util.Dispatcher(uselist=True)
 
 
 def _produce_net_changes(
-    autogen_context: "AutogenContext", upgrade_ops: "UpgradeOps"
+    autogen_context: AutogenContext, upgrade_ops: UpgradeOps
 ) -> None:
 
     connection = autogen_context.connection
     assert connection is not None
     include_schemas = autogen_context.opts.get("include_schemas", False)
 
-    inspector: "Inspector" = inspect(connection)
+    inspector: Inspector = inspect(connection)
 
     default_schema = connection.dialect.default_schema_name
     schemas: Set[Optional[str]]
@@ -93,8 +93,8 @@ def _produce_net_changes(
 
 @comparators.dispatch_for("schema")
 def _autogen_for_tables(
-    autogen_context: "AutogenContext",
-    upgrade_ops: "UpgradeOps",
+    autogen_context: AutogenContext,
+    upgrade_ops: UpgradeOps,
     schemas: Union[Set[None], Set[Optional[str]]],
 ) -> None:
     inspector = autogen_context.inspector
@@ -135,11 +135,11 @@ def _autogen_for_tables(
 
 
 def _compare_tables(
-    conn_table_names: "set",
-    metadata_table_names: "set",
-    inspector: "Inspector",
-    upgrade_ops: "UpgradeOps",
-    autogen_context: "AutogenContext",
+    conn_table_names: set,
+    metadata_table_names: set,
+    inspector: Inspector,
+    upgrade_ops: UpgradeOps,
+    autogen_context: AutogenContext,
 ) -> None:
 
     default_schema = inspector.bind.dialect.default_schema_name
@@ -159,17 +159,14 @@ def _compare_tables(
     # to adjust for the MetaData collection storing the tables either
     # as "schemaname.tablename" or just "tablename", create a new lookup
     # which will match the "non-default-schema" keys to the Table object.
-    tname_to_table = dict(
-        (
-            no_dflt_schema,
-            autogen_context.table_key_to_table[
-                sa_schema._get_table_key(tname, schema)
-            ],
-        )
+    tname_to_table = {
+        no_dflt_schema: autogen_context.table_key_to_table[
+            sa_schema._get_table_key(tname, schema)
+        ]
         for no_dflt_schema, (schema, tname) in zip(
             metadata_table_names_no_dflt_schema, metadata_table_names
         )
-    )
+    }
     metadata_table_names = metadata_table_names_no_dflt_schema
 
     for s, tname in metadata_table_names.difference(conn_table_names):
@@ -279,9 +276,7 @@ def _compare_tables(
                 upgrade_ops.ops.append(modify_table_ops)
 
 
-def _make_index(
-    params: Dict[str, Any], conn_table: "Table"
-) -> Optional["Index"]:
+def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]:
     exprs = []
     for col_name in params["column_names"]:
         if col_name is None:
@@ -302,8 +297,8 @@ def _make_index(
 
 
 def _make_unique_constraint(
-    params: Dict[str, Any], conn_table: "Table"
-) -> "UniqueConstraint":
+    params: Dict[str, Any], conn_table: Table
+) -> UniqueConstraint:
     uq = sa_schema.UniqueConstraint(
         *[conn_table.c[cname] for cname in params["column_names"]],
         name=params["name"],
@@ -315,8 +310,8 @@ def _make_unique_constraint(
 
 
 def _make_foreign_key(
-    params: Dict[str, Any], conn_table: "Table"
-) -> "ForeignKeyConstraint":
+    params: Dict[str, Any], conn_table: Table
+) -> ForeignKeyConstraint:
     tname = params["referred_table"]
     if params["referred_schema"]:
         tname = "%s.%s" % (params["referred_schema"], tname)
@@ -340,12 +335,12 @@ def _make_foreign_key(
 @contextlib.contextmanager
 def _compare_columns(
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    conn_table: "Table",
-    metadata_table: "Table",
-    modify_table_ops: "ModifyTableOps",
-    autogen_context: "AutogenContext",
-    inspector: "Inspector",
+    tname: Union[quoted_name, str],
+    conn_table: Table,
+    metadata_table: Table,
+    modify_table_ops: ModifyTableOps,
+    autogen_context: AutogenContext,
+    inspector: Inspector,
 ) -> Iterator[None]:
     name = "%s.%s" % (schema, tname) if schema else tname
     metadata_col_names = OrderedSet(
@@ -411,9 +406,9 @@ def _compare_columns(
 
 
 class _constraint_sig:
-    const: Union["UniqueConstraint", "ForeignKeyConstraint", "Index"]
+    const: Union[UniqueConstraint, ForeignKeyConstraint, Index]
 
-    def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]:
+    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
         return sqla_compat._get_constraint_final_name(
             self.const, context.dialect
         )
@@ -432,7 +427,7 @@ class _uq_constraint_sig(_constraint_sig):
     is_index = False
     is_unique = True
 
-    def __init__(self, const: "UniqueConstraint") -> None:
+    def __init__(self, const: UniqueConstraint) -> None:
         self.const = const
         self.name = const.name
         self.sig = tuple(sorted([col.name for col in const.columns]))
@@ -445,25 +440,25 @@ class _uq_constraint_sig(_constraint_sig):
 class _ix_constraint_sig(_constraint_sig):
     is_index = True
 
-    def __init__(self, const: "Index") -> None:
+    def __init__(self, const: Index) -> None:
         self.const = const
         self.name = const.name
         self.sig = tuple(sorted([col.name for col in const.columns]))
         self.is_unique = bool(const.unique)
 
-    def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]:
+    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
         return sqla_compat._get_constraint_final_name(
             self.const, context.dialect
         )
 
     @property
-    def column_names(self) -> Union[List["quoted_name"], List[None]]:
+    def column_names(self) -> Union[List[quoted_name], List[None]]:
         return sqla_compat._get_index_column_names(self.const)
 
 
 class _fk_constraint_sig(_constraint_sig):
     def __init__(
-        self, const: "ForeignKeyConstraint", include_options: bool = False
+        self, const: ForeignKeyConstraint, include_options: bool = False
     ) -> None:
         self.const = const
         self.name = const.name
@@ -508,12 +503,12 @@ class _fk_constraint_sig(_constraint_sig):
 
 @comparators.dispatch_for("table")
 def _compare_indexes_and_uniques(
-    autogen_context: "AutogenContext",
-    modify_ops: "ModifyTableOps",
+    autogen_context: AutogenContext,
+    modify_ops: ModifyTableOps,
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    conn_table: Optional["Table"],
-    metadata_table: Optional["Table"],
+    tname: Union[quoted_name, str],
+    conn_table: Optional[Table],
+    metadata_table: Optional[Table],
 ) -> None:
 
     inspector = autogen_context.inspector
@@ -522,11 +517,11 @@ def _compare_indexes_and_uniques(
 
     # 1a. get raw indexes and unique constraints from metadata ...
     if metadata_table is not None:
-        metadata_unique_constraints = set(
+        metadata_unique_constraints = {
             uq
             for uq in metadata_table.constraints
             if isinstance(uq, sa_schema.UniqueConstraint)
-        )
+        }
         metadata_indexes = set(metadata_table.indexes)
     else:
         metadata_unique_constraints = set()
@@ -589,16 +584,16 @@ def _compare_indexes_and_uniques(
             # for DROP TABLE uniques are inline, don't need them
             conn_uniques = set()  # type:ignore[assignment]
         else:
-            conn_uniques = set(  # type:ignore[assignment]
+            conn_uniques = {  # type:ignore[assignment]
                 _make_unique_constraint(uq_def, conn_table)
                 for uq_def in conn_uniques
-            )
+            }
 
-        conn_indexes = set(  # type:ignore[assignment]
+        conn_indexes = {  # type:ignore[assignment]
             index
             for index in (_make_index(ix, conn_table) for ix in conn_indexes)
             if index is not None
-        )
+        }
 
     # 2a. if the dialect dupes unique indexes as unique constraints
     # (mysql and oracle), correct for that
@@ -626,63 +621,59 @@ def _compare_indexes_and_uniques(
     # _constraint_sig() objects provide a consistent facade over both
     # Index and UniqueConstraint so we can easily work with them
     # interchangeably
-    metadata_unique_constraints_sig = set(
+    metadata_unique_constraints_sig = {
         _uq_constraint_sig(uq) for uq in metadata_unique_constraints
-    )
+    }
 
-    metadata_indexes_sig = set(
-        _ix_constraint_sig(ix) for ix in metadata_indexes
-    )
+    metadata_indexes_sig = {_ix_constraint_sig(ix) for ix in metadata_indexes}
 
-    conn_unique_constraints = set(
-        _uq_constraint_sig(uq) for uq in conn_uniques
-    )
+    conn_unique_constraints = {_uq_constraint_sig(uq) for uq in conn_uniques}
 
-    conn_indexes_sig = set(_ix_constraint_sig(ix) for ix in conn_indexes)
+    conn_indexes_sig = {_ix_constraint_sig(ix) for ix in conn_indexes}
 
     # 5. index things by name, for those objects that have names
-    metadata_names = dict(
-        (cast(str, c.md_name_to_sql_name(autogen_context)), c)
+    metadata_names = {
+        cast(str, c.md_name_to_sql_name(autogen_context)): c
         for c in metadata_unique_constraints_sig.union(
             metadata_indexes_sig  # type:ignore[arg-type]
         )
         if isinstance(c, _ix_constraint_sig)
         or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
-    )
+    }
 
-    conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints)
-    conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = dict(
-        (c.name, c) for c in conn_indexes_sig
-    )
-    conn_names = dict(
-        (c.name, c)
+    conn_uniques_by_name = {c.name: c for c in conn_unique_constraints}
+    conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = {
+        c.name: c for c in conn_indexes_sig
+    }
+    conn_names = {
+        c.name: c
         for c in conn_unique_constraints.union(
             conn_indexes_sig  # type:ignore[arg-type]
         )
         if c.name is not None
-    )
+    }
 
-    doubled_constraints = dict(
-        (name, (conn_uniques_by_name[name], conn_indexes_by_name[name]))
+    doubled_constraints = {
+        name: (conn_uniques_by_name[name], conn_indexes_by_name[name])
         for name in set(conn_uniques_by_name).intersection(
             conn_indexes_by_name
         )
-    )
+    }
 
     # 6. index things by "column signature", to help with unnamed unique
     # constraints.
-    conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints)
-    metadata_uniques_by_sig = dict(
-        (uq.sig, uq) for uq in metadata_unique_constraints_sig
-    )
-    metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes_sig)
-    unnamed_metadata_uniques = dict(
-        (uq.sig, uq)
+    conn_uniques_by_sig = {uq.sig: uq for uq in conn_unique_constraints}
+    metadata_uniques_by_sig = {
+        uq.sig: uq for uq in metadata_unique_constraints_sig
+    }
+    metadata_indexes_by_sig = {ix.sig: ix for ix in metadata_indexes_sig}
+    unnamed_metadata_uniques = {
+        uq.sig: uq
         for uq in metadata_unique_constraints_sig
         if not sqla_compat._constraint_is_named(
             uq.const, autogen_context.dialect
         )
-    )
+    }
 
     # assumptions:
     # 1. a unique constraint or an index from the connection *always*
@@ -864,37 +855,31 @@ def _correct_for_uq_duplicates_uix(
         for cons in metadata_unique_constraints
     ]
 
-    metadata_uq_names = set(
+    metadata_uq_names = {
         name for name, cons in metadata_cons_names if name is not None
-    )
+    }
 
-    unnamed_metadata_uqs = set(
-        [
-            _uq_constraint_sig(cons).sig
-            for name, cons in metadata_cons_names
-            if name is None
-        ]
-    )
+    unnamed_metadata_uqs = {
+        _uq_constraint_sig(cons).sig
+        for name, cons in metadata_cons_names
+        if name is None
+    }
 
-    metadata_ix_names = set(
-        [
-            sqla_compat._get_constraint_final_name(cons, dialect)
-            for cons in metadata_indexes
-            if cons.unique
-        ]
-    )
+    metadata_ix_names = {
+        sqla_compat._get_constraint_final_name(cons, dialect)
+        for cons in metadata_indexes
+        if cons.unique
+    }
 
     # for reflection side, names are in their final database form
     # already since they're from the database
-    conn_ix_names = dict(
-        (cons.name, cons) for cons in conn_indexes if cons.unique
-    )
+    conn_ix_names = {cons.name: cons for cons in conn_indexes if cons.unique}
 
-    uqs_dupe_indexes = dict(
-        (cons.name, cons)
+    uqs_dupe_indexes = {
+        cons.name: cons
         for cons in conn_unique_constraints
         if cons.info["duplicates_index"]
-    )
+    }
 
     for overlap in uqs_dupe_indexes:
         if overlap not in metadata_uq_names:
@@ -910,13 +895,13 @@ def _correct_for_uq_duplicates_uix(
 
 @comparators.dispatch_for("column")
 def _compare_nullable(
-    autogen_context: "AutogenContext",
-    alter_column_op: "AlterColumnOp",
+    autogen_context: AutogenContext,
+    alter_column_op: AlterColumnOp,
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    cname: Union["quoted_name", str],
-    conn_col: "Column",
-    metadata_col: "Column",
+    tname: Union[quoted_name, str],
+    cname: Union[quoted_name, str],
+    conn_col: Column,
+    metadata_col: Column,
 ) -> None:
 
     metadata_col_nullable = metadata_col.nullable
@@ -952,13 +937,13 @@ def _compare_nullable(
 
 @comparators.dispatch_for("column")
 def _setup_autoincrement(
-    autogen_context: "AutogenContext",
-    alter_column_op: "AlterColumnOp",
+    autogen_context: AutogenContext,
+    alter_column_op: AlterColumnOp,
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    cname: "quoted_name",
-    conn_col: "Column",
-    metadata_col: "Column",
+    tname: Union[quoted_name, str],
+    cname: quoted_name,
+    conn_col: Column,
+    metadata_col: Column,
 ) -> None:
 
     if metadata_col.table._autoincrement_column is metadata_col:
@@ -971,13 +956,13 @@ def _setup_autoincrement(
 
 @comparators.dispatch_for("column")
 def _compare_type(
-    autogen_context: "AutogenContext",
-    alter_column_op: "AlterColumnOp",
+    autogen_context: AutogenContext,
+    alter_column_op: AlterColumnOp,
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    cname: Union["quoted_name", str],
-    conn_col: "Column",
-    metadata_col: "Column",
+    tname: Union[quoted_name, str],
+    cname: Union[quoted_name, str],
+    conn_col: Column,
+    metadata_col: Column,
 ) -> None:
 
     conn_type = conn_col.type
@@ -1015,8 +1000,8 @@ def _compare_type(
 
 def _render_server_default_for_compare(
     metadata_default: Optional[Any],
-    metadata_col: "Column",
-    autogen_context: "AutogenContext",
+    metadata_col: Column,
+    autogen_context: AutogenContext,
 ) -> Optional[str]:
     rendered = _user_defined_render(
         "server_default", metadata_default, autogen_context
@@ -1055,13 +1040,13 @@ def _normalize_computed_default(sqltext: str) -> str:
 
 
 def _compare_computed_default(
-    autogen_context: "AutogenContext",
-    alter_column_op: "AlterColumnOp",
+    autogen_context: AutogenContext,
+    alter_column_op: AlterColumnOp,
     schema: Optional[str],
-    tname: "str",
-    cname: "str",
-    conn_col: "Column",
-    metadata_col: "Column",
+    tname: str,
+    cname: str,
+    conn_col: Column,
+    metadata_col: Column,
 ) -> None:
     rendered_metadata_default = str(
         cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile(
@@ -1121,13 +1106,13 @@ def _compare_identity_default(
 
 @comparators.dispatch_for("column")
 def _compare_server_default(
-    autogen_context: "AutogenContext",
-    alter_column_op: "AlterColumnOp",
+    autogen_context: AutogenContext,
+    alter_column_op: AlterColumnOp,
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    cname: Union["quoted_name", str],
-    conn_col: "Column",
-    metadata_col: "Column",
+    tname: Union[quoted_name, str],
+    cname: Union[quoted_name, str],
+    conn_col: Column,
+    metadata_col: Column,
 ) -> Optional[bool]:
 
     metadata_default = metadata_col.server_default
@@ -1210,14 +1195,14 @@ def _compare_server_default(
 
 @comparators.dispatch_for("column")
 def _compare_column_comment(
-    autogen_context: "AutogenContext",
-    alter_column_op: "AlterColumnOp",
+    autogen_context: AutogenContext,
+    alter_column_op: AlterColumnOp,
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    cname: "quoted_name",
-    conn_col: "Column",
-    metadata_col: "Column",
-) -> Optional["Literal[False]"]:
+    tname: Union[quoted_name, str],
+    cname: quoted_name,
+    conn_col: Column,
+    metadata_col: Column,
+) -> Optional[Literal[False]]:
 
     assert autogen_context.dialect is not None
     if not autogen_context.dialect.supports_comments:
@@ -1239,12 +1224,12 @@ def _compare_column_comment(
 
 @comparators.dispatch_for("table")
 def _compare_foreign_keys(
-    autogen_context: "AutogenContext",
-    modify_table_ops: "ModifyTableOps",
+    autogen_context: AutogenContext,
+    modify_table_ops: ModifyTableOps,
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    conn_table: Optional["Table"],
-    metadata_table: Optional["Table"],
+    tname: Union[quoted_name, str],
+    conn_table: Optional[Table],
+    metadata_table: Optional[Table],
 ) -> None:
 
     # if we're doing CREATE TABLE, all FKs are created
@@ -1253,11 +1238,11 @@ def _compare_foreign_keys(
         return
 
     inspector = autogen_context.inspector
-    metadata_fks = set(
+    metadata_fks = {
         fk
         for fk in metadata_table.constraints
         if isinstance(fk, sa_schema.ForeignKeyConstraint)
-    )
+    }
 
     conn_fks_list = [
         fk
@@ -1273,10 +1258,10 @@ def _compare_foreign_keys(
         conn_fks_list and "options" in conn_fks_list[0]
     )
 
-    conn_fks = set(
+    conn_fks = {
         _make_foreign_key(const, conn_table)  # type: ignore[arg-type]
         for const in conn_fks_list
-    )
+    }
 
     # give the dialect a chance to correct the FKs to match more
     # closely
@@ -1284,25 +1269,23 @@ def _compare_foreign_keys(
         conn_fks, metadata_fks
     )
 
-    metadata_fks_sig = set(
+    metadata_fks_sig = {
         _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
         for fk in metadata_fks
-    )
+    }
 
-    conn_fks_sig = set(
+    conn_fks_sig = {
         _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
         for fk in conn_fks
-    )
+    }
 
-    conn_fks_by_sig = dict((c.sig, c) for c in conn_fks_sig)
-    metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks_sig)
+    conn_fks_by_sig = {c.sig: c for c in conn_fks_sig}
+    metadata_fks_by_sig = {c.sig: c for c in metadata_fks_sig}
 
-    metadata_fks_by_name = dict(
-        (c.name, c) for c in metadata_fks_sig if c.name is not None
-    )
-    conn_fks_by_name = dict(
-        (c.name, c) for c in conn_fks_sig if c.name is not None
-    )
+    metadata_fks_by_name = {
+        c.name: c for c in metadata_fks_sig if c.name is not None
+    }
+    conn_fks_by_name = {c.name: c for c in conn_fks_sig if c.name is not None}
 
     def _add_fk(obj, compare_to):
         if autogen_context.run_object_filters(
@@ -1361,12 +1344,12 @@ def _compare_foreign_keys(
 
 @comparators.dispatch_for("table")
 def _compare_table_comment(
-    autogen_context: "AutogenContext",
-    modify_table_ops: "ModifyTableOps",
+    autogen_context: AutogenContext,
+    modify_table_ops: ModifyTableOps,
     schema: Optional[str],
-    tname: Union["quoted_name", str],
-    conn_table: Optional["Table"],
-    metadata_table: Optional["Table"],
+    tname: Union[quoted_name, str],
+    conn_table: Optional[Table],
+    metadata_table: Optional[Table],
 ) -> None:
 
     assert autogen_context.dialect is not None
index 1ac6753d9f1efa842f626ba71691313d5d2f835e..41903d81ec8444f29b426c058e978eed365c199e 100644 (file)
@@ -54,9 +54,9 @@ MAX_PYTHON_ARGS = 255
 
 
 def _render_gen_name(
-    autogen_context: "AutogenContext",
-    name: Optional[Union["quoted_name", str]],
-) -> Optional[Union["quoted_name", str, "_f_name"]]:
+    autogen_context: AutogenContext,
+    name: Optional[Union[quoted_name, str]],
+) -> Optional[Union[quoted_name, str, _f_name]]:
     if isinstance(name, conv):
         return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
     else:
@@ -70,9 +70,9 @@ def _indent(text: str) -> str:
 
 
 def _render_python_into_templatevars(
-    autogen_context: "AutogenContext",
-    migration_script: "MigrationScript",
-    template_args: Dict[str, Union[str, "Config"]],
+    autogen_context: AutogenContext,
+    migration_script: MigrationScript,
+    template_args: Dict[str, Union[str, Config]],
 ) -> None:
     imports = autogen_context.imports
 
@@ -92,8 +92,8 @@ default_renderers = renderers = util.Dispatcher()
 
 
 def _render_cmd_body(
-    op_container: "ops.OpContainer",
-    autogen_context: "AutogenContext",
+    op_container: ops.OpContainer,
+    autogen_context: AutogenContext,
 ) -> str:
 
     buf = StringIO()
@@ -120,7 +120,7 @@ def _render_cmd_body(
 
 
 def render_op(
-    autogen_context: "AutogenContext", op: "ops.MigrateOperation"
+    autogen_context: AutogenContext, op: ops.MigrateOperation
 ) -> List[str]:
     renderer = renderers.dispatch(op)
     lines = util.to_list(renderer(autogen_context, op))
@@ -128,14 +128,14 @@ def render_op(
 
 
 def render_op_text(
-    autogen_context: "AutogenContext", op: "ops.MigrateOperation"
+    autogen_context: AutogenContext, op: ops.MigrateOperation
 ) -> str:
     return "\n".join(render_op(autogen_context, op))
 
 
 @renderers.dispatch_for(ops.ModifyTableOps)
 def _render_modify_table(
-    autogen_context: "AutogenContext", op: "ModifyTableOps"
+    autogen_context: AutogenContext, op: ModifyTableOps
 ) -> List[str]:
     opts = autogen_context.opts
     render_as_batch = opts.get("render_as_batch", False)
@@ -164,7 +164,7 @@ def _render_modify_table(
 
 @renderers.dispatch_for(ops.CreateTableCommentOp)
 def _render_create_table_comment(
-    autogen_context: "AutogenContext", op: "ops.CreateTableCommentOp"
+    autogen_context: AutogenContext, op: ops.CreateTableCommentOp
 ) -> str:
 
     templ = (
@@ -189,7 +189,7 @@ def _render_create_table_comment(
 
 @renderers.dispatch_for(ops.DropTableCommentOp)
 def _render_drop_table_comment(
-    autogen_context: "AutogenContext", op: "ops.DropTableCommentOp"
+    autogen_context: AutogenContext, op: ops.DropTableCommentOp
 ) -> str:
 
     templ = (
@@ -211,9 +211,7 @@ def _render_drop_table_comment(
 
 
 @renderers.dispatch_for(ops.CreateTableOp)
-def _add_table(
-    autogen_context: "AutogenContext", op: "ops.CreateTableOp"
-) -> str:
+def _add_table(autogen_context: AutogenContext, op: ops.CreateTableOp) -> str:
     table = op.to_table()
 
     args = [
@@ -263,9 +261,7 @@ def _add_table(
 
 
 @renderers.dispatch_for(ops.DropTableOp)
-def _drop_table(
-    autogen_context: "AutogenContext", op: "ops.DropTableOp"
-) -> str:
+def _drop_table(autogen_context: AutogenContext, op: ops.DropTableOp) -> str:
     text = "%(prefix)sdrop_table(%(tname)r" % {
         "prefix": _alembic_autogenerate_prefix(autogen_context),
         "tname": _ident(op.table_name),
@@ -277,9 +273,7 @@ def _drop_table(
 
 
 @renderers.dispatch_for(ops.CreateIndexOp)
-def _add_index(
-    autogen_context: "AutogenContext", op: "ops.CreateIndexOp"
-) -> str:
+def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:
     index = op.to_index()
 
     has_batch = autogen_context._has_batch
@@ -324,9 +318,7 @@ def _add_index(
 
 
 @renderers.dispatch_for(ops.DropIndexOp)
-def _drop_index(
-    autogen_context: "AutogenContext", op: "ops.DropIndexOp"
-) -> str:
+def _drop_index(autogen_context: AutogenContext, op: ops.DropIndexOp) -> str:
     index = op.to_index()
 
     has_batch = autogen_context._has_batch
@@ -362,14 +354,14 @@ def _drop_index(
 
 @renderers.dispatch_for(ops.CreateUniqueConstraintOp)
 def _add_unique_constraint(
-    autogen_context: "AutogenContext", op: "ops.CreateUniqueConstraintOp"
+    autogen_context: AutogenContext, op: ops.CreateUniqueConstraintOp
 ) -> List[str]:
     return [_uq_constraint(op.to_constraint(), autogen_context, True)]
 
 
 @renderers.dispatch_for(ops.CreateForeignKeyOp)
 def _add_fk_constraint(
-    autogen_context: "AutogenContext", op: "ops.CreateForeignKeyOp"
+    autogen_context: AutogenContext, op: ops.CreateForeignKeyOp
 ) -> str:
 
     args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
@@ -418,7 +410,7 @@ def _add_check_constraint(constraint, autogen_context):
 
 @renderers.dispatch_for(ops.DropConstraintOp)
 def _drop_constraint(
-    autogen_context: "AutogenContext", op: "ops.DropConstraintOp"
+    autogen_context: AutogenContext, op: ops.DropConstraintOp
 ) -> str:
 
     if autogen_context._has_batch:
@@ -440,9 +432,7 @@ def _drop_constraint(
 
 
 @renderers.dispatch_for(ops.AddColumnOp)
-def _add_column(
-    autogen_context: "AutogenContext", op: "ops.AddColumnOp"
-) -> str:
+def _add_column(autogen_context: AutogenContext, op: ops.AddColumnOp) -> str:
 
     schema, tname, column = op.schema, op.table_name, op.column
     if autogen_context._has_batch:
@@ -462,9 +452,7 @@ def _add_column(
 
 
 @renderers.dispatch_for(ops.DropColumnOp)
-def _drop_column(
-    autogen_context: "AutogenContext", op: "ops.DropColumnOp"
-) -> str:
+def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str:
 
     schema, tname, column_name = op.schema, op.table_name, op.column_name
 
@@ -487,7 +475,7 @@ def _drop_column(
 
 @renderers.dispatch_for(ops.AlterColumnOp)
 def _alter_column(
-    autogen_context: "AutogenContext", op: "ops.AlterColumnOp"
+    autogen_context: AutogenContext, op: ops.AlterColumnOp
 ) -> str:
 
     tname = op.table_name
@@ -556,7 +544,7 @@ class _f_name:
         return "%sf(%r)" % (self.prefix, _ident(self.name))
 
 
-def _ident(name: Optional[Union["quoted_name", str]]) -> Optional[str]:
+def _ident(name: Optional[Union[quoted_name, str]]) -> Optional[str]:
     """produce a __repr__() object for a string identifier that may
     use quoted_name() in SQLAlchemy 0.9 and greater.
 
@@ -574,7 +562,7 @@ def _ident(name: Optional[Union["quoted_name", str]]) -> Optional[str]:
 
 def _render_potential_expr(
     value: Any,
-    autogen_context: "AutogenContext",
+    autogen_context: AutogenContext,
     wrap_in_text: bool = True,
     is_server_default: bool = False,
 ) -> str:
@@ -597,7 +585,7 @@ def _render_potential_expr(
 
 
 def _get_index_rendered_expressions(
-    idx: "Index", autogen_context: "AutogenContext"
+    idx: Index, autogen_context: AutogenContext
 ) -> List[str]:
     return [
         repr(_ident(getattr(exp, "name", None)))
@@ -608,8 +596,8 @@ def _get_index_rendered_expressions(
 
 
 def _uq_constraint(
-    constraint: "UniqueConstraint",
-    autogen_context: "AutogenContext",
+    constraint: UniqueConstraint,
+    autogen_context: AutogenContext,
     alter: bool,
 ) -> str:
     opts: List[Tuple[str, Any]] = []
@@ -654,11 +642,11 @@ def _user_autogenerate_prefix(autogen_context, target):
         return prefix
 
 
-def _sqlalchemy_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
+def _sqlalchemy_autogenerate_prefix(autogen_context: AutogenContext) -> str:
     return autogen_context.opts["sqlalchemy_module_prefix"] or ""
 
 
-def _alembic_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
+def _alembic_autogenerate_prefix(autogen_context: AutogenContext) -> str:
     if autogen_context._has_batch:
         return "batch_op."
     else:
@@ -666,8 +654,8 @@ def _alembic_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
 
 
 def _user_defined_render(
-    type_: str, object_: Any, autogen_context: "AutogenContext"
-) -> Union[str, "Literal[False]"]:
+    type_: str, object_: Any, autogen_context: AutogenContext
+) -> Union[str, Literal[False]]:
     if "render_item" in autogen_context.opts:
         render = autogen_context.opts["render_item"]
         if render:
@@ -677,7 +665,7 @@ def _user_defined_render(
     return False
 
 
-def _render_column(column: "Column", autogen_context: "AutogenContext") -> str:
+def _render_column(column: Column, autogen_context: AutogenContext) -> str:
     rendered = _user_defined_render("column", column, autogen_context)
     if rendered is not False:
         return rendered
@@ -734,7 +722,7 @@ def _render_column(column: "Column", autogen_context: "AutogenContext") -> str:
 
 
 def _should_render_server_default_positionally(
-    server_default: Union["Computed", "DefaultClause"]
+    server_default: Union[Computed, DefaultClause]
 ) -> bool:
     return sqla_compat._server_default_is_computed(
         server_default
@@ -742,10 +730,8 @@ def _should_render_server_default_positionally(
 
 
 def _render_server_default(
-    default: Optional[
-        Union["FetchedValue", str, "TextClause", "ColumnElement"]
-    ],
-    autogen_context: "AutogenContext",
+    default: Optional[Union[FetchedValue, str, TextClause, ColumnElement]],
+    autogen_context: AutogenContext,
     repr_: bool = True,
 ) -> Optional[str]:
     rendered = _user_defined_render("server_default", default, autogen_context)
@@ -771,7 +757,7 @@ def _render_server_default(
 
 
 def _render_computed(
-    computed: "Computed", autogen_context: "AutogenContext"
+    computed: Computed, autogen_context: AutogenContext
 ) -> str:
     text = _render_potential_expr(
         computed.sqltext, autogen_context, wrap_in_text=False
@@ -788,7 +774,7 @@ def _render_computed(
 
 
 def _render_identity(
-    identity: "Identity", autogen_context: "AutogenContext"
+    identity: Identity, autogen_context: AutogenContext
 ) -> str:
     # always=None means something different than always=False
     kwargs = OrderedDict(always=identity.always)
@@ -802,7 +788,7 @@ def _render_identity(
     }
 
 
-def _get_identity_options(identity_options: "Identity") -> OrderedDict:
+def _get_identity_options(identity_options: Identity) -> OrderedDict:
     kwargs = OrderedDict()
     for attr in sqla_compat._identity_options_attrs:
         value = getattr(identity_options, attr, None)
@@ -812,8 +798,8 @@ def _get_identity_options(identity_options: "Identity") -> OrderedDict:
 
 
 def _repr_type(
-    type_: "TypeEngine",
-    autogen_context: "AutogenContext",
+    type_: TypeEngine,
+    autogen_context: AutogenContext,
     _skip_variants: bool = False,
 ) -> str:
     rendered = _user_defined_render("type", type_, autogen_context)
@@ -855,9 +841,7 @@ def _repr_type(
         return "%s%r" % (prefix, type_)
 
 
-def _render_ARRAY_type(
-    type_: "ARRAY", autogen_context: "AutogenContext"
-) -> str:
+def _render_ARRAY_type(type_: ARRAY, autogen_context: AutogenContext) -> str:
     return cast(
         str,
         _render_type_w_subtype(
@@ -867,7 +851,7 @@ def _render_ARRAY_type(
 
 
 def _render_Variant_type(
-    type_: "TypeEngine", autogen_context: "AutogenContext"
+    type_: TypeEngine, autogen_context: AutogenContext
 ) -> str:
     base_type, variant_mapping = sqla_compat._get_variant_mapping(type_)
     base = _repr_type(base_type, autogen_context, _skip_variants=True)
@@ -882,12 +866,12 @@ def _render_Variant_type(
 
 
 def _render_type_w_subtype(
-    type_: "TypeEngine",
-    autogen_context: "AutogenContext",
+    type_: TypeEngine,
+    autogen_context: AutogenContext,
     attrname: str,
     regexp: str,
     prefix: Optional[str] = None,
-) -> Union[Optional[str], "Literal[False]"]:
+) -> Union[Optional[str], Literal[False]]:
     outer_repr = repr(type_)
     inner_type = getattr(type_, attrname, None)
     if inner_type is None:
@@ -919,9 +903,9 @@ _constraint_renderers = util.Dispatcher()
 
 
 def _render_constraint(
-    constraint: "Constraint",
-    autogen_context: "AutogenContext",
-    namespace_metadata: Optional["MetaData"],
+    constraint: Constraint,
+    autogen_context: AutogenContext,
+    namespace_metadata: Optional[MetaData],
 ) -> Optional[str]:
     try:
         renderer = _constraint_renderers.dispatch(constraint)
@@ -934,9 +918,9 @@ def _render_constraint(
 
 @_constraint_renderers.dispatch_for(sa_schema.PrimaryKeyConstraint)
 def _render_primary_key(
-    constraint: "PrimaryKeyConstraint",
-    autogen_context: "AutogenContext",
-    namespace_metadata: Optional["MetaData"],
+    constraint: PrimaryKeyConstraint,
+    autogen_context: AutogenContext,
+    namespace_metadata: Optional[MetaData],
 ) -> Optional[str]:
     rendered = _user_defined_render("primary_key", constraint, autogen_context)
     if rendered is not False:
@@ -960,9 +944,9 @@ def _render_primary_key(
 
 
 def _fk_colspec(
-    fk: "ForeignKey",
+    fk: ForeignKey,
     metadata_schema: Optional[str],
-    namespace_metadata: "MetaData",
+    namespace_metadata: MetaData,
 ) -> str:
     """Implement a 'safe' version of ForeignKey._get_colspec() that
     won't fail if the remote table can't be resolved.
@@ -997,7 +981,7 @@ def _fk_colspec(
 
 
 def _populate_render_fk_opts(
-    constraint: "ForeignKeyConstraint", opts: List[Tuple[str, str]]
+    constraint: ForeignKeyConstraint, opts: List[Tuple[str, str]]
 ) -> None:
 
     if constraint.onupdate:
@@ -1014,9 +998,9 @@ def _populate_render_fk_opts(
 
 @_constraint_renderers.dispatch_for(sa_schema.ForeignKeyConstraint)
 def _render_foreign_key(
-    constraint: "ForeignKeyConstraint",
-    autogen_context: "AutogenContext",
-    namespace_metadata: "MetaData",
+    constraint: ForeignKeyConstraint,
+    autogen_context: AutogenContext,
+    namespace_metadata: MetaData,
 ) -> Optional[str]:
     rendered = _user_defined_render("foreign_key", constraint, autogen_context)
     if rendered is not False:
@@ -1053,9 +1037,9 @@ def _render_foreign_key(
 
 @_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint)
 def _render_unique_constraint(
-    constraint: "UniqueConstraint",
-    autogen_context: "AutogenContext",
-    namespace_metadata: Optional["MetaData"],
+    constraint: UniqueConstraint,
+    autogen_context: AutogenContext,
+    namespace_metadata: Optional[MetaData],
 ) -> str:
     rendered = _user_defined_render("unique", constraint, autogen_context)
     if rendered is not False:
@@ -1066,9 +1050,9 @@ def _render_unique_constraint(
 
 @_constraint_renderers.dispatch_for(sa_schema.CheckConstraint)
 def _render_check_constraint(
-    constraint: "CheckConstraint",
-    autogen_context: "AutogenContext",
-    namespace_metadata: Optional["MetaData"],
+    constraint: CheckConstraint,
+    autogen_context: AutogenContext,
+    namespace_metadata: Optional[MetaData],
 ) -> Optional[str]:
     rendered = _user_defined_render("check", constraint, autogen_context)
     if rendered is not False:
@@ -1106,9 +1090,7 @@ def _render_check_constraint(
 
 
 @renderers.dispatch_for(ops.ExecuteSQLOp)
-def _execute_sql(
-    autogen_context: "AutogenContext", op: "ops.ExecuteSQLOp"
-) -> str:
+def _execute_sql(autogen_context: AutogenContext, op: ops.ExecuteSQLOp) -> str:
     if not isinstance(op.sqltext, str):
         raise NotImplementedError(
             "Autogenerate rendering of SQL Expression language constructs "
index 79f665a052b41cdb3a17a5e6c31b510669cc3a5e..1a29b963eb1f4e6b53f2e8a851faebe70f31436c 100644 (file)
@@ -95,11 +95,11 @@ class Rewriter:
     def rewrites(
         self,
         operator: Union[
-            Type["AddColumnOp"],
-            Type["MigrateOperation"],
-            Type["AlterColumnOp"],
-            Type["CreateTableOp"],
-            Type["ModifyTableOps"],
+            Type[AddColumnOp],
+            Type[MigrateOperation],
+            Type[AlterColumnOp],
+            Type[CreateTableOp],
+            Type[ModifyTableOps],
         ],
     ) -> Callable:
         """Register a function as rewriter for a given type.
@@ -118,10 +118,10 @@ class Rewriter:
 
     def _rewrite(
         self,
-        context: "MigrationContext",
-        revision: "Revision",
-        directive: "MigrateOperation",
-    ) -> Iterator["MigrateOperation"]:
+        context: MigrationContext,
+        revision: Revision,
+        directive: MigrateOperation,
+    ) -> Iterator[MigrateOperation]:
         try:
             _rewriter = self.dispatch.dispatch(directive)
         except ValueError:
@@ -141,9 +141,9 @@ class Rewriter:
 
     def __call__(
         self,
-        context: "MigrationContext",
-        revision: "Revision",
-        directives: List["MigrationScript"],
+        context: MigrationContext,
+        revision: Revision,
+        directives: List[MigrationScript],
     ) -> None:
         self.process_revision_directives(context, revision, directives)
         if self._chained:
@@ -152,9 +152,9 @@ class Rewriter:
     @_traverse.dispatch_for(ops.MigrationScript)
     def _traverse_script(
         self,
-        context: "MigrationContext",
-        revision: "Revision",
-        directive: "MigrationScript",
+        context: MigrationContext,
+        revision: Revision,
+        directive: MigrationScript,
     ) -> None:
         upgrade_ops_list = []
         for upgrade_ops in directive.upgrade_ops_list:
@@ -179,26 +179,26 @@ class Rewriter:
     @_traverse.dispatch_for(ops.OpContainer)
     def _traverse_op_container(
         self,
-        context: "MigrationContext",
-        revision: "Revision",
-        directive: "OpContainer",
+        context: MigrationContext,
+        revision: Revision,
+        directive: OpContainer,
     ) -> None:
         self._traverse_list(context, revision, directive.ops)
 
     @_traverse.dispatch_for(ops.MigrateOperation)
     def _traverse_any_directive(
         self,
-        context: "MigrationContext",
-        revision: "Revision",
-        directive: "MigrateOperation",
+        context: MigrationContext,
+        revision: Revision,
+        directive: MigrateOperation,
     ) -> None:
         pass
 
     def _traverse_for(
         self,
-        context: "MigrationContext",
-        revision: "Revision",
-        directive: "MigrateOperation",
+        context: MigrationContext,
+        revision: Revision,
+        directive: MigrateOperation,
     ) -> Any:
         directives = list(self._rewrite(context, revision, directive))
         for directive in directives:
@@ -208,8 +208,8 @@ class Rewriter:
 
     def _traverse_list(
         self,
-        context: "MigrationContext",
-        revision: "Revision",
+        context: MigrationContext,
+        revision: Revision,
         directives: Any,
     ) -> None:
         dest = []
@@ -220,8 +220,8 @@ class Rewriter:
 
     def process_revision_directives(
         self,
-        context: "MigrationContext",
-        revision: "Revision",
-        directives: List["MigrationScript"],
+        context: MigrationContext,
+        revision: Revision,
+        directives: List[MigrationScript],
     ) -> None:
         self._traverse_list(context, revision, directives)
index 162b3d0c996ba8c73a5271b909cac4d428ff6325..5c33a95eadf49de9823cc237ec004fcde27c991d 100644 (file)
@@ -37,7 +37,7 @@ def list_templates(config):
 
 
 def init(
-    config: "Config",
+    config: Config,
     directory: str,
     template: str = "generic",
     package: bool = False,
@@ -114,7 +114,7 @@ def init(
 
 
 def revision(
-    config: "Config",
+    config: Config,
     message: Optional[str] = None,
     autogenerate: bool = False,
     sql: bool = False,
@@ -125,7 +125,7 @@ def revision(
     rev_id: Optional[str] = None,
     depends_on: Optional[str] = None,
     process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None,
-) -> Union[Optional["Script"], List[Optional["Script"]]]:
+) -> Union[Optional[Script], List[Optional[Script]]]:
     """Create a new revision file.
 
     :param config: a :class:`.Config` object.
@@ -241,12 +241,12 @@ def revision(
 
 
 def merge(
-    config: "Config",
+    config: Config,
     revisions: str,
     message: Optional[str] = None,
     branch_label: Optional[str] = None,
     rev_id: Optional[str] = None,
-) -> Optional["Script"]:
+) -> Optional[Script]:
     """Merge two revisions together.  Creates a new migration file.
 
     :param config: a :class:`.Config` instance
@@ -280,7 +280,7 @@ def merge(
 
 
 def upgrade(
-    config: "Config",
+    config: Config,
     revision: str,
     sql: bool = False,
     tag: Optional[str] = None,
@@ -323,7 +323,7 @@ def upgrade(
 
 
 def downgrade(
-    config: "Config",
+    config: Config,
     revision: str,
     sql: bool = False,
     tag: Optional[str] = None,
@@ -394,7 +394,7 @@ def show(config, rev):
 
 
 def history(
-    config: "Config",
+    config: Config,
     rev_range: Optional[str] = None,
     verbose: bool = False,
     indicate_current: bool = False,
@@ -517,7 +517,7 @@ def branches(config, verbose=False):
             )
 
 
-def current(config: "Config", verbose: bool = False) -> None:
+def current(config: Config, verbose: bool = False) -> None:
     """Display the current revision for a database.
 
     :param config: a :class:`.Config` instance.
@@ -546,7 +546,7 @@ def current(config: "Config", verbose: bool = False) -> None:
 
 
 def stamp(
-    config: "Config",
+    config: Config,
     revision: str,
     sql: bool = False,
     tag: Optional[str] = None,
@@ -615,7 +615,7 @@ def stamp(
         script.run_env()
 
 
-def edit(config: "Config", rev: str) -> None:
+def edit(config: Config, rev: str) -> None:
     """Edit revision script(s) using $EDITOR.
 
     :param config: a :class:`.Config` instance.
@@ -648,7 +648,7 @@ def edit(config: "Config", rev: str) -> None:
             util.open_in_editor(sc.path)
 
 
-def ensure_version(config: "Config", sql: bool = False) -> None:
+def ensure_version(config: Config, sql: bool = False) -> None:
     """Create the alembic version table if it doesn't exist already .
 
     :param config: a :class:`.Config` instance.
index 8464407d59f7ff7ef33b470c9c0235fe60440186..ac27d585bed2731833a43a1fe67a159581f785ed 100644 (file)
@@ -561,7 +561,7 @@ class CommandLine:
             fn(
                 config,
                 *[getattr(options, k, None) for k in positional],
-                **dict((k, getattr(options, k, None)) for k in kwarg),
+                **{k: getattr(options, k, None) for k in kwarg},
             )
         except util.CommandError as e:
             if options.raiseerr:
index c9107867d3c6a92901a4461cdf3e11f06e1a22f6..c3bdaf382be31576eb57c4a07cda2b5912e6b263 100644 (file)
@@ -46,7 +46,7 @@ class AlterTable(DDLElement):
     def __init__(
         self,
         table_name: str,
-        schema: Optional[Union["quoted_name", str]] = None,
+        schema: Optional[Union[quoted_name, str]] = None,
     ) -> None:
         self.table_name = table_name
         self.schema = schema
@@ -56,10 +56,10 @@ class RenameTable(AlterTable):
     def __init__(
         self,
         old_table_name: str,
-        new_table_name: Union["quoted_name", str],
-        schema: Optional[Union["quoted_name", str]] = None,
+        new_table_name: Union[quoted_name, str],
+        schema: Optional[Union[quoted_name, str]] = None,
     ) -> None:
-        super(RenameTable, self).__init__(old_table_name, schema=schema)
+        super().__init__(old_table_name, schema=schema)
         self.new_table_name = new_table_name
 
 
@@ -69,12 +69,12 @@ class AlterColumn(AlterTable):
         name: str,
         column_name: str,
         schema: Optional[str] = None,
-        existing_type: Optional["TypeEngine"] = None,
+        existing_type: Optional[TypeEngine] = None,
         existing_nullable: Optional[bool] = None,
         existing_server_default: Optional[_ServerDefault] = None,
         existing_comment: Optional[str] = None,
     ) -> None:
-        super(AlterColumn, self).__init__(name, schema=schema)
+        super().__init__(name, schema=schema)
         self.column_name = column_name
         self.existing_type = (
             sqltypes.to_instance(existing_type)
@@ -90,15 +90,15 @@ class ColumnNullable(AlterColumn):
     def __init__(
         self, name: str, column_name: str, nullable: bool, **kw
     ) -> None:
-        super(ColumnNullable, self).__init__(name, column_name, **kw)
+        super().__init__(name, column_name, **kw)
         self.nullable = nullable
 
 
 class ColumnType(AlterColumn):
     def __init__(
-        self, name: str, column_name: str, type_: "TypeEngine", **kw
+        self, name: str, column_name: str, type_: TypeEngine, **kw
     ) -> None:
-        super(ColumnType, self).__init__(name, column_name, **kw)
+        super().__init__(name, column_name, **kw)
         self.type_ = sqltypes.to_instance(type_)
 
 
@@ -106,7 +106,7 @@ class ColumnName(AlterColumn):
     def __init__(
         self, name: str, column_name: str, newname: str, **kw
     ) -> None:
-        super(ColumnName, self).__init__(name, column_name, **kw)
+        super().__init__(name, column_name, **kw)
         self.newname = newname
 
 
@@ -118,15 +118,15 @@ class ColumnDefault(AlterColumn):
         default: Optional[_ServerDefault],
         **kw,
     ) -> None:
-        super(ColumnDefault, self).__init__(name, column_name, **kw)
+        super().__init__(name, column_name, **kw)
         self.default = default
 
 
 class ComputedColumnDefault(AlterColumn):
     def __init__(
-        self, name: str, column_name: str, default: Optional["Computed"], **kw
+        self, name: str, column_name: str, default: Optional[Computed], **kw
     ) -> None:
-        super(ComputedColumnDefault, self).__init__(name, column_name, **kw)
+        super().__init__(name, column_name, **kw)
         self.default = default
 
 
@@ -135,11 +135,11 @@ class IdentityColumnDefault(AlterColumn):
         self,
         name: str,
         column_name: str,
-        default: Optional["Identity"],
-        impl: "DefaultImpl",
+        default: Optional[Identity],
+        impl: DefaultImpl,
         **kw,
     ) -> None:
-        super(IdentityColumnDefault, self).__init__(name, column_name, **kw)
+        super().__init__(name, column_name, **kw)
         self.default = default
         self.impl = impl
 
@@ -148,18 +148,18 @@ class AddColumn(AlterTable):
     def __init__(
         self,
         name: str,
-        column: "Column",
-        schema: Optional[Union["quoted_name", str]] = None,
+        column: Column,
+        schema: Optional[Union[quoted_name, str]] = None,
     ) -> None:
-        super(AddColumn, self).__init__(name, schema=schema)
+        super().__init__(name, schema=schema)
         self.column = column
 
 
 class DropColumn(AlterTable):
     def __init__(
-        self, name: str, column: "Column", schema: Optional[str] = None
+        self, name: str, column: Column, schema: Optional[str] = None
     ) -> None:
-        super(DropColumn, self).__init__(name, schema=schema)
+        super().__init__(name, schema=schema)
         self.column = column
 
 
@@ -167,13 +167,13 @@ class ColumnComment(AlterColumn):
     def __init__(
         self, name: str, column_name: str, comment: Optional[str], **kw
     ) -> None:
-        super(ColumnComment, self).__init__(name, column_name, **kw)
+        super().__init__(name, column_name, **kw)
         self.comment = comment
 
 
 @compiles(RenameTable)
 def visit_rename_table(
-    element: "RenameTable", compiler: "DDLCompiler", **kw
+    element: RenameTable, compiler: DDLCompiler, **kw
 ) -> str:
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -182,9 +182,7 @@ def visit_rename_table(
 
 
 @compiles(AddColumn)
-def visit_add_column(
-    element: "AddColumn", compiler: "DDLCompiler", **kw
-) -> str:
+def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         add_column(compiler, element.column, **kw),
@@ -192,9 +190,7 @@ def visit_add_column(
 
 
 @compiles(DropColumn)
-def visit_drop_column(
-    element: "DropColumn", compiler: "DDLCompiler", **kw
-) -> str:
+def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         drop_column(compiler, element.column.name, **kw),
@@ -203,7 +199,7 @@ def visit_drop_column(
 
 @compiles(ColumnNullable)
 def visit_column_nullable(
-    element: "ColumnNullable", compiler: "DDLCompiler", **kw
+    element: ColumnNullable, compiler: DDLCompiler, **kw
 ) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -213,9 +209,7 @@ def visit_column_nullable(
 
 
 @compiles(ColumnType)
-def visit_column_type(
-    element: "ColumnType", compiler: "DDLCompiler", **kw
-) -> str:
+def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -224,9 +218,7 @@ def visit_column_type(
 
 
 @compiles(ColumnName)
-def visit_column_name(
-    element: "ColumnName", compiler: "DDLCompiler", **kw
-) -> str:
+def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
     return "%s RENAME %s TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
@@ -236,7 +228,7 @@ def visit_column_name(
 
 @compiles(ColumnDefault)
 def visit_column_default(
-    element: "ColumnDefault", compiler: "DDLCompiler", **kw
+    element: ColumnDefault, compiler: DDLCompiler, **kw
 ) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -249,7 +241,7 @@ def visit_column_default(
 
 @compiles(ComputedColumnDefault)
 def visit_computed_column(
-    element: "ComputedColumnDefault", compiler: "DDLCompiler", **kw
+    element: ComputedColumnDefault, compiler: DDLCompiler, **kw
 ):
     raise exc.CompileError(
         'Adding or removing a "computed" construct, e.g. GENERATED '
@@ -259,7 +251,7 @@ def visit_computed_column(
 
 @compiles(IdentityColumnDefault)
 def visit_identity_column(
-    element: "IdentityColumnDefault", compiler: "DDLCompiler", **kw
+    element: IdentityColumnDefault, compiler: DDLCompiler, **kw
 ):
     raise exc.CompileError(
         'Adding, removing or modifying an "identity" construct, '
@@ -269,8 +261,8 @@ def visit_identity_column(
 
 
 def quote_dotted(
-    name: Union["quoted_name", str], quote: functools.partial
-) -> Union["quoted_name", str]:
+    name: Union[quoted_name, str], quote: functools.partial
+) -> Union[quoted_name, str]:
     """quote the elements of a dotted name"""
 
     if isinstance(name, quoted_name):
@@ -280,10 +272,10 @@ def quote_dotted(
 
 
 def format_table_name(
-    compiler: "Compiled",
-    name: Union["quoted_name", str],
-    schema: Optional[Union["quoted_name", str]],
-) -> Union["quoted_name", str]:
+    compiler: Compiled,
+    name: Union[quoted_name, str],
+    schema: Optional[Union[quoted_name, str]],
+) -> Union[quoted_name, str]:
     quote = functools.partial(compiler.preparer.quote)
     if schema:
         return quote_dotted(schema, quote) + "." + quote(name)
@@ -292,13 +284,13 @@ def format_table_name(
 
 
 def format_column_name(
-    compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
-) -> Union["quoted_name", str]:
+    compiler: DDLCompiler, name: Optional[Union[quoted_name, str]]
+) -> Union[quoted_name, str]:
     return compiler.preparer.quote(name)  # type: ignore[arg-type]
 
 
 def format_server_default(
-    compiler: "DDLCompiler",
+    compiler: DDLCompiler,
     default: Optional[_ServerDefault],
 ) -> str:
     return compiler.get_column_default_string(
@@ -306,27 +298,27 @@ def format_server_default(
     )
 
 
-def format_type(compiler: "DDLCompiler", type_: "TypeEngine") -> str:
+def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
     return compiler.dialect.type_compiler.process(type_)
 
 
 def alter_table(
-    compiler: "DDLCompiler",
+    compiler: DDLCompiler,
     name: str,
     schema: Optional[str],
 ) -> str:
     return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
 
 
-def drop_column(compiler: "DDLCompiler", name: str, **kw) -> str:
+def drop_column(compiler: DDLCompiler, name: str, **kw) -> str:
     return "DROP COLUMN %s" % format_column_name(compiler, name)
 
 
-def alter_column(compiler: "DDLCompiler", name: str) -> str:
+def alter_column(compiler: DDLCompiler, name: str) -> str:
     return "ALTER COLUMN %s" % format_column_name(compiler, name)
 
 
-def add_column(compiler: "DDLCompiler", column: "Column", **kw) -> str:
+def add_column(compiler: DDLCompiler, column: Column, **kw) -> str:
     text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
 
     const = " ".join(
index 79d5245e5d30a91d494a4fba8b2f758931c9df5b..728d1dae394fab9fcecfd6069973fc58dd1e8e7a 100644 (file)
@@ -52,7 +52,7 @@ class ImplMeta(type):
     def __init__(
         cls,
         classname: str,
-        bases: Tuple[Type["DefaultImpl"]],
+        bases: Tuple[Type[DefaultImpl]],
         dict_: Dict[str, Any],
     ):
         newtype = type.__init__(cls, classname, bases, dict_)
@@ -61,7 +61,7 @@ class ImplMeta(type):
         return newtype
 
 
-_impls: Dict[str, Type["DefaultImpl"]] = {}
+_impls: Dict[str, Type[DefaultImpl]] = {}
 
 Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
 
@@ -91,11 +91,11 @@ class DefaultImpl(metaclass=ImplMeta):
 
     def __init__(
         self,
-        dialect: "Dialect",
-        connection: Optional["Connection"],
+        dialect: Dialect,
+        connection: Optional[Connection],
         as_sql: bool,
         transactional_ddl: Optional[bool],
-        output_buffer: Optional["TextIO"],
+        output_buffer: Optional[TextIO],
         context_opts: Dict[str, Any],
     ) -> None:
         self.dialect = dialect
@@ -116,7 +116,7 @@ class DefaultImpl(metaclass=ImplMeta):
                 )
 
     @classmethod
-    def get_by_dialect(cls, dialect: "Dialect") -> Type["DefaultImpl"]:
+    def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]:
         return _impls[dialect.name]
 
     def static_output(self, text: str) -> None:
@@ -125,7 +125,7 @@ class DefaultImpl(metaclass=ImplMeta):
         self.output_buffer.flush()
 
     def requires_recreate_in_batch(
-        self, batch_op: "BatchOperationsImpl"
+        self, batch_op: BatchOperationsImpl
     ) -> bool:
         """Return True if the given :class:`.BatchOperationsImpl`
         would need the table to be recreated and copied in order to
@@ -138,7 +138,7 @@ class DefaultImpl(metaclass=ImplMeta):
         return False
 
     def prep_table_for_batch(
-        self, batch_impl: "ApplyBatchImpl", table: "Table"
+        self, batch_impl: ApplyBatchImpl, table: Table
     ) -> None:
         """perform any operations needed on a table before a new
         one is created to replace it in batch mode.
@@ -149,16 +149,16 @@ class DefaultImpl(metaclass=ImplMeta):
         """
 
     @property
-    def bind(self) -> Optional["Connection"]:
+    def bind(self) -> Optional[Connection]:
         return self.connection
 
     def _exec(
         self,
-        construct: Union["ClauseElement", str],
+        construct: Union[ClauseElement, str],
         execution_options: Optional[dict] = None,
         multiparams: Sequence[dict] = (),
         params: Dict[str, int] = util.immutabledict(),
-    ) -> Optional["CursorResult"]:
+    ) -> Optional[CursorResult]:
         if isinstance(construct, str):
             construct = text(construct)
         if self.as_sql:
@@ -196,7 +196,7 @@ class DefaultImpl(metaclass=ImplMeta):
 
     def execute(
         self,
-        sql: Union["ClauseElement", str],
+        sql: Union[ClauseElement, str],
         execution_options: None = None,
     ) -> None:
         self._exec(sql, execution_options)
@@ -206,15 +206,15 @@ class DefaultImpl(metaclass=ImplMeta):
         table_name: str,
         column_name: str,
         nullable: Optional[bool] = None,
-        server_default: Union["_ServerDefault", "Literal[False]"] = False,
+        server_default: Union[_ServerDefault, Literal[False]] = False,
         name: Optional[str] = None,
-        type_: Optional["TypeEngine"] = None,
+        type_: Optional[TypeEngine] = None,
         schema: Optional[str] = None,
         autoincrement: Optional[bool] = None,
-        comment: Optional[Union[str, "Literal[False]"]] = False,
+        comment: Optional[Union[str, Literal[False]]] = False,
         existing_comment: Optional[str] = None,
-        existing_type: Optional["TypeEngine"] = None,
-        existing_server_default: Optional["_ServerDefault"] = None,
+        existing_type: Optional[TypeEngine] = None,
+        existing_server_default: Optional[_ServerDefault] = None,
         existing_nullable: Optional[bool] = None,
         existing_autoincrement: Optional[bool] = None,
         **kw: Any,
@@ -316,15 +316,15 @@ class DefaultImpl(metaclass=ImplMeta):
     def add_column(
         self,
         table_name: str,
-        column: "Column",
-        schema: Optional[Union[str, "quoted_name"]] = None,
+        column: Column,
+        schema: Optional[Union[str, quoted_name]] = None,
     ) -> None:
         self._exec(base.AddColumn(table_name, column, schema=schema))
 
     def drop_column(
         self,
         table_name: str,
-        column: "Column",
+        column: Column,
         schema: Optional[str] = None,
         **kw,
     ) -> None:
@@ -334,20 +334,20 @@ class DefaultImpl(metaclass=ImplMeta):
         if const._create_rule is None or const._create_rule(self):
             self._exec(schema.AddConstraint(const))
 
-    def drop_constraint(self, const: "Constraint") -> None:
+    def drop_constraint(self, const: Constraint) -> None:
         self._exec(schema.DropConstraint(const))
 
     def rename_table(
         self,
         old_table_name: str,
-        new_table_name: Union[str, "quoted_name"],
-        schema: Optional[Union[str, "quoted_name"]] = None,
+        new_table_name: Union[str, quoted_name],
+        schema: Optional[Union[str, quoted_name]] = None,
     ) -> None:
         self._exec(
             base.RenameTable(old_table_name, new_table_name, schema=schema)
         )
 
-    def create_table(self, table: "Table") -> None:
+    def create_table(self, table: Table) -> None:
         table.dispatch.before_create(
             table, self.connection, checkfirst=False, _ddl_runner=self
         )
@@ -370,7 +370,7 @@ class DefaultImpl(metaclass=ImplMeta):
             if comment and with_comment:
                 self.create_column_comment(column)
 
-    def drop_table(self, table: "Table") -> None:
+    def drop_table(self, table: Table) -> None:
         table.dispatch.before_drop(
             table, self.connection, checkfirst=False, _ddl_runner=self
         )
@@ -379,24 +379,24 @@ class DefaultImpl(metaclass=ImplMeta):
             table, self.connection, checkfirst=False, _ddl_runner=self
         )
 
-    def create_index(self, index: "Index") -> None:
+    def create_index(self, index: Index) -> None:
         self._exec(schema.CreateIndex(index))
 
-    def create_table_comment(self, table: "Table") -> None:
+    def create_table_comment(self, table: Table) -> None:
         self._exec(schema.SetTableComment(table))
 
-    def drop_table_comment(self, table: "Table") -> None:
+    def drop_table_comment(self, table: Table) -> None:
         self._exec(schema.DropTableComment(table))
 
-    def create_column_comment(self, column: "ColumnElement") -> None:
+    def create_column_comment(self, column: ColumnElement) -> None:
         self._exec(schema.SetColumnComment(column))
 
-    def drop_index(self, index: "Index") -> None:
+    def drop_index(self, index: Index) -> None:
         self._exec(schema.DropIndex(index))
 
     def bulk_insert(
         self,
-        table: Union["TableClause", "Table"],
+        table: Union[TableClause, Table],
         rows: List[dict],
         multiinsert: bool = True,
     ) -> None:
@@ -408,19 +408,16 @@ class DefaultImpl(metaclass=ImplMeta):
             for row in rows:
                 self._exec(
                     sqla_compat._insert_inline(table).values(
-                        **dict(
-                            (
-                                k,
-                                sqla_compat._literal_bindparam(
-                                    k, v, type_=table.c[k].type
-                                )
-                                if not isinstance(
-                                    v, sqla_compat._literal_bindparam
-                                )
-                                else v,
+                        **{
+                            k: sqla_compat._literal_bindparam(
+                                k, v, type_=table.c[k].type
                             )
+                            if not isinstance(
+                                v, sqla_compat._literal_bindparam
+                            )
+                            else v
                             for k, v in row.items()
-                        )
+                        }
                     )
                 )
         else:
@@ -435,7 +432,7 @@ class DefaultImpl(metaclass=ImplMeta):
                             sqla_compat._insert_inline(table).values(**row)
                         )
 
-    def _tokenize_column_type(self, column: "Column") -> Params:
+    def _tokenize_column_type(self, column: Column) -> Params:
         definition = self.dialect.type_compiler.process(column.type).lower()
 
         # tokenize the SQLAlchemy-generated version of a type, so that
@@ -474,7 +471,7 @@ class DefaultImpl(metaclass=ImplMeta):
         return params
 
     def _column_types_match(
-        self, inspector_params: "Params", metadata_params: "Params"
+        self, inspector_params: Params, metadata_params: Params
     ) -> bool:
         if inspector_params.token0 == metadata_params.token0:
             return True
@@ -496,7 +493,7 @@ class DefaultImpl(metaclass=ImplMeta):
         return False
 
     def _column_args_match(
-        self, inspected_params: "Params", meta_params: "Params"
+        self, inspected_params: Params, meta_params: Params
     ) -> bool:
         """We want to compare column parameters. However, we only want
         to compare parameters that are set. If they both have `collation`,
@@ -529,7 +526,7 @@ class DefaultImpl(metaclass=ImplMeta):
         return True
 
     def compare_type(
-        self, inspector_column: "Column", metadata_column: "Column"
+        self, inspector_column: Column, metadata_column: Column
     ) -> bool:
         """Returns True if there ARE differences between the types of the two
         columns. Takes impl.type_synonyms into account between retrospected
@@ -555,10 +552,10 @@ class DefaultImpl(metaclass=ImplMeta):
 
     def correct_for_autogen_constraints(
         self,
-        conn_uniques: Set["UniqueConstraint"],
-        conn_indexes: Set["Index"],
-        metadata_unique_constraints: Set["UniqueConstraint"],
-        metadata_indexes: Set["Index"],
+        conn_uniques: Set[UniqueConstraint],
+        conn_indexes: Set[Index],
+        metadata_unique_constraints: Set[UniqueConstraint],
+        metadata_indexes: Set[Index],
     ) -> None:
         pass
 
@@ -569,7 +566,7 @@ class DefaultImpl(metaclass=ImplMeta):
             )
 
     def render_ddl_sql_expr(
-        self, expr: "ClauseElement", is_server_default: bool = False, **kw: Any
+        self, expr: ClauseElement, is_server_default: bool = False, **kw: Any
     ) -> str:
         """Render a SQL expression that is typically a server default,
         index expression, etc.
@@ -587,15 +584,13 @@ class DefaultImpl(metaclass=ImplMeta):
             )
         )
 
-    def _compat_autogen_column_reflect(
-        self, inspector: "Inspector"
-    ) -> Callable:
+    def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable:
         return self.autogen_column_reflect
 
     def correct_for_autogen_foreignkeys(
         self,
-        conn_fks: Set["ForeignKeyConstraint"],
-        metadata_fks: Set["ForeignKeyConstraint"],
+        conn_fks: Set[ForeignKeyConstraint],
+        metadata_fks: Set[ForeignKeyConstraint],
     ) -> None:
         pass
 
@@ -637,8 +632,8 @@ class DefaultImpl(metaclass=ImplMeta):
         self.static_output("COMMIT" + self.command_terminator)
 
     def render_type(
-        self, type_obj: "TypeEngine", autogen_context: "AutogenContext"
-    ) -> Union[str, "Literal[False]"]:
+        self, type_obj: TypeEngine, autogen_context: AutogenContext
+    ) -> Union[str, Literal[False]]:
         return False
 
     def _compare_identity_default(self, metadata_identity, inspector_identity):
index 28f0678e4167fca780a360fdcd6d343f4db2585a..6a208ec6852bb22047413f81132c3d4a8d3e27ef 100644 (file)
@@ -62,13 +62,13 @@ class MSSQLImpl(DefaultImpl):
     )
 
     def __init__(self, *arg, **kw) -> None:
-        super(MSSQLImpl, self).__init__(*arg, **kw)
+        super().__init__(*arg, **kw)
         self.batch_separator = self.context_opts.get(
             "mssql_batch_separator", self.batch_separator
         )
 
-    def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
-        result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
+    def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
+        result = super()._exec(construct, *args, **kw)
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
         return result
@@ -77,7 +77,7 @@ class MSSQLImpl(DefaultImpl):
         self.static_output("BEGIN TRANSACTION" + self.command_terminator)
 
     def emit_commit(self) -> None:
-        super(MSSQLImpl, self).emit_commit()
+        super().emit_commit()
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
 
@@ -87,13 +87,13 @@ class MSSQLImpl(DefaultImpl):
         column_name: str,
         nullable: Optional[bool] = None,
         server_default: Optional[
-            Union["_ServerDefault", "Literal[False]"]
+            Union[_ServerDefault, Literal[False]]
         ] = False,
         name: Optional[str] = None,
-        type_: Optional["TypeEngine"] = None,
+        type_: Optional[TypeEngine] = None,
         schema: Optional[str] = None,
-        existing_type: Optional["TypeEngine"] = None,
-        existing_server_default: Optional["_ServerDefault"] = None,
+        existing_type: Optional[TypeEngine] = None,
+        existing_server_default: Optional[_ServerDefault] = None,
         existing_nullable: Optional[bool] = None,
         **kw: Any,
     ) -> None:
@@ -136,7 +136,7 @@ class MSSQLImpl(DefaultImpl):
             kw["server_default"] = server_default
             kw["existing_server_default"] = existing_server_default
 
-        super(MSSQLImpl, self).alter_column(
+        super().alter_column(
             table_name,
             column_name,
             nullable=nullable,
@@ -158,7 +158,7 @@ class MSSQLImpl(DefaultImpl):
                     )
                 )
             if server_default is not None:
-                super(MSSQLImpl, self).alter_column(
+                super().alter_column(
                     table_name,
                     column_name,
                     schema=schema,
@@ -166,11 +166,11 @@ class MSSQLImpl(DefaultImpl):
                 )
 
         if name is not None:
-            super(MSSQLImpl, self).alter_column(
+            super().alter_column(
                 table_name, column_name, schema=schema, name=name
             )
 
-    def create_index(self, index: "Index") -> None:
+    def create_index(self, index: Index) -> None:
         # this likely defaults to None if not present, so get()
         # should normally not return the default value.  being
         # defensive in any case
@@ -182,25 +182,25 @@ class MSSQLImpl(DefaultImpl):
         self._exec(CreateIndex(index))
 
     def bulk_insert(  # type:ignore[override]
-        self, table: Union["TableClause", "Table"], rows: List[dict], **kw: Any
+        self, table: Union[TableClause, Table], rows: List[dict], **kw: Any
     ) -> None:
         if self.as_sql:
             self._exec(
                 "SET IDENTITY_INSERT %s ON"
                 % self.dialect.identifier_preparer.format_table(table)
             )
-            super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
+            super().bulk_insert(table, rows, **kw)
             self._exec(
                 "SET IDENTITY_INSERT %s OFF"
                 % self.dialect.identifier_preparer.format_table(table)
             )
         else:
-            super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
+            super().bulk_insert(table, rows, **kw)
 
     def drop_column(
         self,
         table_name: str,
-        column: "Column",
+        column: Column,
         schema: Optional[str] = None,
         **kw,
     ) -> None:
@@ -221,9 +221,7 @@ class MSSQLImpl(DefaultImpl):
         drop_fks = kw.pop("mssql_drop_foreign_key", False)
         if drop_fks:
             self._exec(_ExecDropFKConstraint(table_name, column, schema))
-        super(MSSQLImpl, self).drop_column(
-            table_name, column, schema=schema, **kw
-        )
+        super().drop_column(table_name, column, schema=schema, **kw)
 
     def compare_server_default(
         self,
@@ -244,9 +242,9 @@ class MSSQLImpl(DefaultImpl):
         )
 
     def _compare_identity_default(self, metadata_identity, inspector_identity):
-        diff, ignored, is_alter = super(
-            MSSQLImpl, self
-        )._compare_identity_default(metadata_identity, inspector_identity)
+        diff, ignored, is_alter = super()._compare_identity_default(
+            metadata_identity, inspector_identity
+        )
 
         if (
             metadata_identity is None
@@ -268,7 +266,7 @@ class _ExecDropConstraint(Executable, ClauseElement):
     def __init__(
         self,
         tname: str,
-        colname: Union["Column", str],
+        colname: Union[Column, str],
         type_: str,
         schema: Optional[str],
     ) -> None:
@@ -282,7 +280,7 @@ class _ExecDropFKConstraint(Executable, ClauseElement):
     inherit_cache = False
 
     def __init__(
-        self, tname: str, colname: "Column", schema: Optional[str]
+        self, tname: str, colname: Column, schema: Optional[str]
     ) -> None:
         self.tname = tname
         self.colname = colname
@@ -291,7 +289,7 @@ class _ExecDropFKConstraint(Executable, ClauseElement):
 
 @compiles(_ExecDropConstraint, "mssql")
 def _exec_drop_col_constraint(
-    element: "_ExecDropConstraint", compiler: "MSSQLCompiler", **kw
+    element: _ExecDropConstraint, compiler: MSSQLCompiler, **kw
 ) -> str:
     schema, tname, colname, type_ = (
         element.schema,
@@ -317,7 +315,7 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
 
 @compiles(_ExecDropFKConstraint, "mssql")
 def _exec_drop_col_fk_constraint(
-    element: "_ExecDropFKConstraint", compiler: "MSSQLCompiler", **kw
+    element: _ExecDropFKConstraint, compiler: MSSQLCompiler, **kw
 ) -> str:
     schema, tname, colname = element.schema, element.tname, element.colname
 
@@ -336,22 +334,20 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
 
 
 @compiles(AddColumn, "mssql")
-def visit_add_column(
-    element: "AddColumn", compiler: "MSDDLCompiler", **kw
-) -> str:
+def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         mssql_add_column(compiler, element.column, **kw),
     )
 
 
-def mssql_add_column(compiler: "MSDDLCompiler", column: "Column", **kw) -> str:
+def mssql_add_column(compiler: MSDDLCompiler, column: Column, **kw) -> str:
     return "ADD %s" % compiler.get_column_specification(column, **kw)
 
 
 @compiles(ColumnNullable, "mssql")
 def visit_column_nullable(
-    element: "ColumnNullable", compiler: "MSDDLCompiler", **kw
+    element: ColumnNullable, compiler: MSDDLCompiler, **kw
 ) -> str:
     return "%s %s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -363,7 +359,7 @@ def visit_column_nullable(
 
 @compiles(ColumnDefault, "mssql")
 def visit_column_default(
-    element: "ColumnDefault", compiler: "MSDDLCompiler", **kw
+    element: ColumnDefault, compiler: MSDDLCompiler, **kw
 ) -> str:
     # TODO: there can also be a named constraint
     # with ADD CONSTRAINT here
@@ -376,7 +372,7 @@ def visit_column_default(
 
 @compiles(ColumnName, "mssql")
 def visit_rename_column(
-    element: "ColumnName", compiler: "MSDDLCompiler", **kw
+    element: ColumnName, compiler: MSDDLCompiler, **kw
 ) -> str:
     return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
         format_table_name(compiler, element.table_name, element.schema),
@@ -387,7 +383,7 @@ def visit_rename_column(
 
 @compiles(ColumnType, "mssql")
 def visit_column_type(
-    element: "ColumnType", compiler: "MSDDLCompiler", **kw
+    element: ColumnType, compiler: MSDDLCompiler, **kw
 ) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -398,7 +394,7 @@ def visit_column_type(
 
 @compiles(RenameTable, "mssql")
 def visit_rename_table(
-    element: "RenameTable", compiler: "MSDDLCompiler", **kw
+    element: RenameTable, compiler: MSDDLCompiler, **kw
 ) -> str:
     return "EXEC sp_rename '%s', %s" % (
         format_table_name(compiler, element.table_name, element.schema),
index 0c03fbe1121489d459937f6319ec4e364465c714..a452760227227789a7d80312d154a145fc4626b4 100644 (file)
@@ -51,16 +51,16 @@ class MySQLImpl(DefaultImpl):
         table_name: str,
         column_name: str,
         nullable: Optional[bool] = None,
-        server_default: Union["_ServerDefault", "Literal[False]"] = False,
+        server_default: Union[_ServerDefault, Literal[False]] = False,
         name: Optional[str] = None,
-        type_: Optional["TypeEngine"] = None,
+        type_: Optional[TypeEngine] = None,
         schema: Optional[str] = None,
-        existing_type: Optional["TypeEngine"] = None,
-        existing_server_default: Optional["_ServerDefault"] = None,
+        existing_type: Optional[TypeEngine] = None,
+        existing_server_default: Optional[_ServerDefault] = None,
         existing_nullable: Optional[bool] = None,
         autoincrement: Optional[bool] = None,
         existing_autoincrement: Optional[bool] = None,
-        comment: Optional[Union[str, "Literal[False]"]] = False,
+        comment: Optional[Union[str, Literal[False]]] = False,
         existing_comment: Optional[str] = None,
         **kw: Any,
     ) -> None:
@@ -71,7 +71,7 @@ class MySQLImpl(DefaultImpl):
         ):
             # modifying computed or identity columns is not supported
             # the default will raise
-            super(MySQLImpl, self).alter_column(
+            super().alter_column(
                 table_name,
                 column_name,
                 nullable=nullable,
@@ -147,17 +147,17 @@ class MySQLImpl(DefaultImpl):
 
     def drop_constraint(
         self,
-        const: "Constraint",
+        const: Constraint,
     ) -> None:
         if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
             return
 
-        super(MySQLImpl, self).drop_constraint(const)
+        super().drop_constraint(const)
 
     def _is_mysql_allowed_functional_default(
         self,
-        type_: Optional["TypeEngine"],
-        server_default: Union["_ServerDefault", "Literal[False]"],
+        type_: Optional[TypeEngine],
+        server_default: Union[_ServerDefault, Literal[False]],
     ) -> bool:
         return (
             type_ is not None
@@ -263,12 +263,12 @@ class MySQLImpl(DefaultImpl):
                 metadata_indexes.remove(idx)
 
     def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
-        conn_fk_by_sig = dict(
-            (compare._fk_constraint_sig(fk).sig, fk) for fk in conn_fks
-        )
-        metadata_fk_by_sig = dict(
-            (compare._fk_constraint_sig(fk).sig, fk) for fk in metadata_fks
-        )
+        conn_fk_by_sig = {
+            compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
+        }
+        metadata_fk_by_sig = {
+            compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
+        }
 
         for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
             mdfk = metadata_fk_by_sig[sig]
@@ -299,7 +299,7 @@ class MySQLAlterDefault(AlterColumn):
         self,
         name: str,
         column_name: str,
-        default: "_ServerDefault",
+        default: _ServerDefault,
         schema: Optional[str] = None,
     ) -> None:
         super(AlterColumn, self).__init__(name, schema=schema)
@@ -314,11 +314,11 @@ class MySQLChangeColumn(AlterColumn):
         column_name: str,
         schema: Optional[str] = None,
         newname: Optional[str] = None,
-        type_: Optional["TypeEngine"] = None,
+        type_: Optional[TypeEngine] = None,
         nullable: Optional[bool] = None,
-        default: Optional[Union["_ServerDefault", "Literal[False]"]] = False,
+        default: Optional[Union[_ServerDefault, Literal[False]]] = False,
         autoincrement: Optional[bool] = None,
-        comment: Optional[Union[str, "Literal[False]"]] = False,
+        comment: Optional[Union[str, Literal[False]]] = False,
     ) -> None:
         super(AlterColumn, self).__init__(name, schema=schema)
         self.column_name = column_name
@@ -352,7 +352,7 @@ def _mysql_doesnt_support_individual(element, compiler, **kw):
 
 @compiles(MySQLAlterDefault, "mysql", "mariadb")
 def _mysql_alter_default(
-    element: "MySQLAlterDefault", compiler: "MySQLDDLCompiler", **kw
+    element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw
 ) -> str:
     return "%s ALTER COLUMN %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -365,7 +365,7 @@ def _mysql_alter_default(
 
 @compiles(MySQLModifyColumn, "mysql", "mariadb")
 def _mysql_modify_column(
-    element: "MySQLModifyColumn", compiler: "MySQLDDLCompiler", **kw
+    element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw
 ) -> str:
     return "%s MODIFY %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -383,7 +383,7 @@ def _mysql_modify_column(
 
 @compiles(MySQLChangeColumn, "mysql", "mariadb")
 def _mysql_change_column(
-    element: "MySQLChangeColumn", compiler: "MySQLDDLCompiler", **kw
+    element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw
 ) -> str:
     return "%s CHANGE %s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -401,12 +401,12 @@ def _mysql_change_column(
 
 
 def _mysql_colspec(
-    compiler: "MySQLDDLCompiler",
+    compiler: MySQLDDLCompiler,
     nullable: Optional[bool],
-    server_default: Optional[Union["_ServerDefault", "Literal[False]"]],
-    type_: "TypeEngine",
+    server_default: Optional[Union[_ServerDefault, Literal[False]]],
+    type_: TypeEngine,
     autoincrement: Optional[bool],
-    comment: Optional[Union[str, "Literal[False]"]],
+    comment: Optional[Union[str, Literal[False]]],
 ) -> str:
     spec = "%s %s" % (
         compiler.dialect.type_compiler.process(type_),
@@ -426,7 +426,7 @@ def _mysql_colspec(
 
 @compiles(schema.DropConstraint, "mysql", "mariadb")
 def _mysql_drop_constraint(
-    element: "DropConstraint", compiler: "MySQLDDLCompiler", **kw
+    element: DropConstraint, compiler: MySQLDDLCompiler, **kw
 ) -> str:
     """Redefine SQLAlchemy's drop constraint to
     raise errors for invalid constraint type."""
index accd1fcfb2481a44e1a51ba40641a16528b1f136..920b70ae06dfa30e9d1856d7b5a32c98c6520388 100644 (file)
@@ -41,13 +41,13 @@ class OracleImpl(DefaultImpl):
     identity_attrs_ignore = ()
 
     def __init__(self, *arg, **kw) -> None:
-        super(OracleImpl, self).__init__(*arg, **kw)
+        super().__init__(*arg, **kw)
         self.batch_separator = self.context_opts.get(
             "oracle_batch_separator", self.batch_separator
         )
 
-    def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
-        result = super(OracleImpl, self)._exec(construct, *args, **kw)
+    def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
+        result = super()._exec(construct, *args, **kw)
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
         return result
@@ -61,7 +61,7 @@ class OracleImpl(DefaultImpl):
 
 @compiles(AddColumn, "oracle")
 def visit_add_column(
-    element: "AddColumn", compiler: "OracleDDLCompiler", **kw
+    element: AddColumn, compiler: OracleDDLCompiler, **kw
 ) -> str:
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -71,7 +71,7 @@ def visit_add_column(
 
 @compiles(ColumnNullable, "oracle")
 def visit_column_nullable(
-    element: "ColumnNullable", compiler: "OracleDDLCompiler", **kw
+    element: ColumnNullable, compiler: OracleDDLCompiler, **kw
 ) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -82,7 +82,7 @@ def visit_column_nullable(
 
 @compiles(ColumnType, "oracle")
 def visit_column_type(
-    element: "ColumnType", compiler: "OracleDDLCompiler", **kw
+    element: ColumnType, compiler: OracleDDLCompiler, **kw
 ) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -93,7 +93,7 @@ def visit_column_type(
 
 @compiles(ColumnName, "oracle")
 def visit_column_name(
-    element: "ColumnName", compiler: "OracleDDLCompiler", **kw
+    element: ColumnName, compiler: OracleDDLCompiler, **kw
 ) -> str:
     return "%s RENAME COLUMN %s TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -104,7 +104,7 @@ def visit_column_name(
 
 @compiles(ColumnDefault, "oracle")
 def visit_column_default(
-    element: "ColumnDefault", compiler: "OracleDDLCompiler", **kw
+    element: ColumnDefault, compiler: OracleDDLCompiler, **kw
 ) -> str:
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -117,7 +117,7 @@ def visit_column_default(
 
 @compiles(ColumnComment, "oracle")
 def visit_column_comment(
-    element: "ColumnComment", compiler: "OracleDDLCompiler", **kw
+    element: ColumnComment, compiler: OracleDDLCompiler, **kw
 ) -> str:
     ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
 
@@ -135,7 +135,7 @@ def visit_column_comment(
 
 @compiles(RenameTable, "oracle")
 def visit_rename_table(
-    element: "RenameTable", compiler: "OracleDDLCompiler", **kw
+    element: RenameTable, compiler: OracleDDLCompiler, **kw
 ) -> str:
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -143,17 +143,17 @@ def visit_rename_table(
     )
 
 
-def alter_column(compiler: "OracleDDLCompiler", name: str) -> str:
+def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
     return "MODIFY %s" % format_column_name(compiler, name)
 
 
-def add_column(compiler: "OracleDDLCompiler", column: "Column", **kw) -> str:
+def add_column(compiler: OracleDDLCompiler, column: Column, **kw) -> str:
     return "ADD %s" % compiler.get_column_specification(column, **kw)
 
 
 @compiles(IdentityColumnDefault, "oracle")
 def visit_identity_column(
-    element: "IdentityColumnDefault", compiler: "OracleDDLCompiler", **kw
+    element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw
 ):
     text = "%s %s " % (
         alter_table(compiler, element.table_name, element.schema),
index 5d93803ace143500111c1716e7ab7e38ef361ede..29efe4c91264dbabb31da423530bd9c9a951285e 100644 (file)
@@ -136,13 +136,13 @@ class PostgresqlImpl(DefaultImpl):
         table_name: str,
         column_name: str,
         nullable: Optional[bool] = None,
-        server_default: Union["_ServerDefault", "Literal[False]"] = False,
+        server_default: Union[_ServerDefault, Literal[False]] = False,
         name: Optional[str] = None,
-        type_: Optional["TypeEngine"] = None,
+        type_: Optional[TypeEngine] = None,
         schema: Optional[str] = None,
         autoincrement: Optional[bool] = None,
-        existing_type: Optional["TypeEngine"] = None,
-        existing_server_default: Optional["_ServerDefault"] = None,
+        existing_type: Optional[TypeEngine] = None,
+        existing_server_default: Optional[_ServerDefault] = None,
         existing_nullable: Optional[bool] = None,
         existing_autoincrement: Optional[bool] = None,
         **kw: Any,
@@ -169,7 +169,7 @@ class PostgresqlImpl(DefaultImpl):
                 )
             )
 
-        super(PostgresqlImpl, self).alter_column(
+        super().alter_column(
             table_name,
             column_name,
             nullable=nullable,
@@ -230,13 +230,13 @@ class PostgresqlImpl(DefaultImpl):
         metadata_indexes,
     ):
 
-        conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
+        conn_indexes_by_name = {c.name: c for c in conn_indexes}
 
-        doubled_constraints = set(
+        doubled_constraints = {
             index
             for index in conn_indexes
             if index.info.get("duplicates_constraint")
-        )
+        }
 
         for ix in doubled_constraints:
             conn_indexes.remove(ix)
@@ -260,8 +260,8 @@ class PostgresqlImpl(DefaultImpl):
                     metadata_indexes.discard(idx)
 
     def render_type(
-        self, type_: "TypeEngine", autogen_context: "AutogenContext"
-    ) -> Union[str, "Literal[False]"]:
+        self, type_: TypeEngine, autogen_context: AutogenContext
+    ) -> Union[str, Literal[False]]:
         mod = type(type_).__module__
         if not mod.startswith("sqlalchemy.dialects.postgresql"):
             return False
@@ -273,7 +273,7 @@ class PostgresqlImpl(DefaultImpl):
         return False
 
     def _render_HSTORE_type(
-        self, type_: "HSTORE", autogen_context: "AutogenContext"
+        self, type_: HSTORE, autogen_context: AutogenContext
     ) -> str:
         return cast(
             str,
@@ -283,7 +283,7 @@ class PostgresqlImpl(DefaultImpl):
         )
 
     def _render_ARRAY_type(
-        self, type_: "ARRAY", autogen_context: "AutogenContext"
+        self, type_: ARRAY, autogen_context: AutogenContext
     ) -> str:
         return cast(
             str,
@@ -293,7 +293,7 @@ class PostgresqlImpl(DefaultImpl):
         )
 
     def _render_JSON_type(
-        self, type_: "JSON", autogen_context: "AutogenContext"
+        self, type_: JSON, autogen_context: AutogenContext
     ) -> str:
         return cast(
             str,
@@ -303,7 +303,7 @@ class PostgresqlImpl(DefaultImpl):
         )
 
     def _render_JSONB_type(
-        self, type_: "JSONB", autogen_context: "AutogenContext"
+        self, type_: JSONB, autogen_context: AutogenContext
     ) -> str:
         return cast(
             str,
@@ -315,17 +315,17 @@ class PostgresqlImpl(DefaultImpl):
 
 class PostgresqlColumnType(AlterColumn):
     def __init__(
-        self, name: str, column_name: str, type_: "TypeEngine", **kw
+        self, name: str, column_name: str, type_: TypeEngine, **kw
     ) -> None:
         using = kw.pop("using", None)
-        super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
+        super().__init__(name, column_name, **kw)
         self.type_ = sqltypes.to_instance(type_)
         self.using = using
 
 
 @compiles(RenameTable, "postgresql")
 def visit_rename_table(
-    element: RenameTable, compiler: "PGDDLCompiler", **kw
+    element: RenameTable, compiler: PGDDLCompiler, **kw
 ) -> str:
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -335,7 +335,7 @@ def visit_rename_table(
 
 @compiles(PostgresqlColumnType, "postgresql")
 def visit_column_type(
-    element: PostgresqlColumnType, compiler: "PGDDLCompiler", **kw
+    element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw
 ) -> str:
     return "%s %s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -347,7 +347,7 @@ def visit_column_type(
 
 @compiles(ColumnComment, "postgresql")
 def visit_column_comment(
-    element: "ColumnComment", compiler: "PGDDLCompiler", **kw
+    element: ColumnComment, compiler: PGDDLCompiler, **kw
 ) -> str:
     ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
     comment = (
@@ -369,7 +369,7 @@ def visit_column_comment(
 
 @compiles(IdentityColumnDefault, "postgresql")
 def visit_identity_column(
-    element: "IdentityColumnDefault", compiler: "PGDDLCompiler", **kw
+    element: IdentityColumnDefault, compiler: PGDDLCompiler, **kw
 ):
     text = "%s %s " % (
         alter_table(compiler, element.table_name, element.schema),
@@ -415,14 +415,14 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
     def __init__(
         self,
         constraint_name: Optional[str],
-        table_name: Union[str, "quoted_name"],
+        table_name: Union[str, quoted_name],
         elements: Union[
             Sequence[Tuple[str, str]],
-            Sequence[Tuple["ColumnClause", str]],
+            Sequence[Tuple[ColumnClause, str]],
         ],
-        where: Optional[Union["BinaryExpression", str]] = None,
+        where: Optional[Union[BinaryExpression, str]] = None,
         schema: Optional[str] = None,
-        _orig_constraint: Optional["ExcludeConstraint"] = None,
+        _orig_constraint: Optional[ExcludeConstraint] = None,
         **kw,
     ) -> None:
         self.constraint_name = constraint_name
@@ -435,8 +435,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
 
     @classmethod
     def from_constraint(  # type:ignore[override]
-        cls, constraint: "ExcludeConstraint"
-    ) -> "CreateExcludeConstraintOp":
+        cls, constraint: ExcludeConstraint
+    ) -> CreateExcludeConstraintOp:
         constraint_table = sqla_compat._table_for_constraint(constraint)
 
         return cls(
@@ -455,8 +455,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
         )
 
     def to_constraint(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "ExcludeConstraint":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> ExcludeConstraint:
         if self._orig_constraint is not None:
             return self._orig_constraint
         schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -479,12 +479,12 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
     @classmethod
     def create_exclude_constraint(
         cls,
-        operations: "Operations",
+        operations: Operations,
         constraint_name: str,
         table_name: str,
         *elements: Any,
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue an alter to create an EXCLUDE constraint using the
         current migration context.
 
@@ -546,16 +546,16 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
 
 @render.renderers.dispatch_for(CreateExcludeConstraintOp)
 def _add_exclude_constraint(
-    autogen_context: "AutogenContext", op: "CreateExcludeConstraintOp"
+    autogen_context: AutogenContext, op: CreateExcludeConstraintOp
 ) -> str:
     return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
 
 
 @render._constraint_renderers.dispatch_for(ExcludeConstraint)
 def _render_inline_exclude_constraint(
-    constraint: "ExcludeConstraint",
-    autogen_context: "AutogenContext",
-    namespace_metadata: "MetaData",
+    constraint: ExcludeConstraint,
+    autogen_context: AutogenContext,
+    namespace_metadata: MetaData,
 ) -> str:
     rendered = render._user_defined_render(
         "exclude", constraint, autogen_context
@@ -566,7 +566,7 @@ def _render_inline_exclude_constraint(
     return _exclude_constraint(constraint, autogen_context, False)
 
 
-def _postgresql_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
+def _postgresql_autogenerate_prefix(autogen_context: AutogenContext) -> str:
 
     imports = autogen_context.imports
     if imports is not None:
@@ -575,8 +575,8 @@ def _postgresql_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
 
 
 def _exclude_constraint(
-    constraint: "ExcludeConstraint",
-    autogen_context: "AutogenContext",
+    constraint: ExcludeConstraint,
+    autogen_context: AutogenContext,
     alter: bool,
 ) -> str:
     opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = []
@@ -645,7 +645,7 @@ def _exclude_constraint(
 
 
 def _render_potential_column(
-    value: Union["ColumnClause", "Column"], autogen_context: "AutogenContext"
+    value: Union[ColumnClause, Column], autogen_context: AutogenContext
 ) -> str:
     if isinstance(value, ColumnClause):
         template = "%(prefix)scolumn(%(name)r)"
index f986c32c3a50fcb6de1a03e2e23cdeb71baba3d6..51233326c6aebb6781184290dffcac2b60874d06 100644 (file)
@@ -41,7 +41,7 @@ class SQLiteImpl(DefaultImpl):
     """
 
     def requires_recreate_in_batch(
-        self, batch_op: "BatchOperationsImpl"
+        self, batch_op: BatchOperationsImpl
     ) -> bool:
         """Return True if the given :class:`.BatchOperationsImpl`
         would need the table to be recreated and copied in order to
@@ -68,7 +68,7 @@ class SQLiteImpl(DefaultImpl):
         else:
             return False
 
-    def add_constraint(self, const: "Constraint"):
+    def add_constraint(self, const: Constraint):
         # attempt to distinguish between an
         # auto-gen constraint and an explicit one
         if const._create_rule is None:  # type:ignore[attr-defined]
@@ -85,7 +85,7 @@ class SQLiteImpl(DefaultImpl):
                 "SQLite migrations using a copy-and-move strategy."
             )
 
-    def drop_constraint(self, const: "Constraint"):
+    def drop_constraint(self, const: Constraint):
         if const._create_rule is None:  # type:ignore[attr-defined]
             raise NotImplementedError(
                 "No support for ALTER of constraints in SQLite dialect. "
@@ -95,8 +95,8 @@ class SQLiteImpl(DefaultImpl):
 
     def compare_server_default(
         self,
-        inspector_column: "Column",
-        metadata_column: "Column",
+        inspector_column: Column,
+        metadata_column: Column,
         rendered_metadata_default: Optional[str],
         rendered_inspector_default: Optional[str],
     ) -> bool:
@@ -140,8 +140,8 @@ class SQLiteImpl(DefaultImpl):
 
     def autogen_column_reflect(
         self,
-        inspector: "Inspector",
-        table: "Table",
+        inspector: Inspector,
+        table: Table,
         column_info: Dict[str, Any],
     ) -> None:
         # SQLite expression defaults require parenthesis when sent
@@ -152,11 +152,11 @@ class SQLiteImpl(DefaultImpl):
             column_info["default"] = "(%s)" % (column_info["default"],)
 
     def render_ddl_sql_expr(
-        self, expr: "ClauseElement", is_server_default: bool = False, **kw
+        self, expr: ClauseElement, is_server_default: bool = False, **kw
     ) -> str:
         # SQLite expression defaults require parenthesis when sent
         # as DDL
-        str_expr = super(SQLiteImpl, self).render_ddl_sql_expr(
+        str_expr = super().render_ddl_sql_expr(
             expr, is_server_default=is_server_default, **kw
         )
 
@@ -169,9 +169,9 @@ class SQLiteImpl(DefaultImpl):
 
     def cast_for_batch_migrate(
         self,
-        existing: "Column",
-        existing_transfer: Dict[str, Union["TypeEngine", "Cast"]],
-        new_type: "TypeEngine",
+        existing: Column,
+        existing_transfer: Dict[str, Union[TypeEngine, Cast]],
+        new_type: TypeEngine,
     ) -> None:
         if (
             existing.type._type_affinity  # type:ignore[attr-defined]
@@ -185,7 +185,7 @@ class SQLiteImpl(DefaultImpl):
 
 @compiles(RenameTable, "sqlite")
 def visit_rename_table(
-    element: "RenameTable", compiler: "DDLCompiler", **kw
+    element: RenameTable, compiler: DDLCompiler, **kw
 ) -> str:
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
index 2178998a638ea4390219a7d4f37821827cdf393b..04b66b55dceffcb4ebcce0c59d78b5d80b3320e1 100644 (file)
@@ -75,7 +75,7 @@ class Operations(util.ModuleClsProxy):
 
     """
 
-    impl: Union["DefaultImpl", "BatchOperationsImpl"]
+    impl: Union[DefaultImpl, BatchOperationsImpl]
     _to_impl = util.Dispatcher()
 
     def __init__(
@@ -222,13 +222,13 @@ class Operations(util.ModuleClsProxy):
         schema: Optional[str] = None,
         recreate: Literal["auto", "always", "never"] = "auto",
         partial_reordering: Optional[tuple] = None,
-        copy_from: Optional["Table"] = None,
+        copy_from: Optional[Table] = None,
         table_args: Tuple[Any, ...] = (),
         table_kwargs: Mapping[str, Any] = util.immutabledict(),
         reflect_args: Tuple[Any, ...] = (),
         reflect_kwargs: Mapping[str, Any] = util.immutabledict(),
         naming_convention: Optional[Dict[str, str]] = None,
-    ) -> Iterator["BatchOperations"]:
+    ) -> Iterator[BatchOperations]:
         """Invoke a series of per-table migrations in batch.
 
         Batch mode allows a series of operations specific to a table
@@ -514,7 +514,7 @@ class BatchOperations(Operations):
 
     """
 
-    impl: "BatchOperationsImpl"
+    impl: BatchOperationsImpl
 
     def _noop(self, operation):
         raise NotImplementedError(
index f1459e2bd82c5fbc370873e75b690d518d135b48..0c773c68ccc8de3e9df914215869318cb5cc19c8 100644 (file)
@@ -86,11 +86,11 @@ class BatchOperationsImpl:
         self.batch = []
 
     @property
-    def dialect(self) -> "Dialect":
+    def dialect(self) -> Dialect:
         return self.operations.impl.dialect
 
     @property
-    def impl(self) -> "DefaultImpl":
+    def impl(self) -> DefaultImpl:
         return self.operations.impl
 
     def _should_recreate(self) -> bool:
@@ -174,19 +174,19 @@ class BatchOperationsImpl:
     def drop_column(self, *arg, **kw) -> None:
         self.batch.append(("drop_column", arg, kw))
 
-    def add_constraint(self, const: "Constraint") -> None:
+    def add_constraint(self, const: Constraint) -> None:
         self.batch.append(("add_constraint", (const,), {}))
 
-    def drop_constraint(self, const: "Constraint") -> None:
+    def drop_constraint(self, const: Constraint) -> None:
         self.batch.append(("drop_constraint", (const,), {}))
 
     def rename_table(self, *arg, **kw):
         self.batch.append(("rename_table", arg, kw))
 
-    def create_index(self, idx: "Index") -> None:
+    def create_index(self, idx: Index) -> None:
         self.batch.append(("create_index", (idx,), {}))
 
-    def drop_index(self, idx: "Index") -> None:
+    def drop_index(self, idx: Index) -> None:
         self.batch.append(("drop_index", (idx,), {}))
 
     def create_table_comment(self, table):
@@ -208,8 +208,8 @@ class BatchOperationsImpl:
 class ApplyBatchImpl:
     def __init__(
         self,
-        impl: "DefaultImpl",
-        table: "Table",
+        impl: DefaultImpl,
+        table: Table,
         table_args: tuple,
         table_kwargs: Dict[str, Any],
         reflected: bool,
@@ -236,12 +236,12 @@ class ApplyBatchImpl:
         self._grab_table_elements()
 
     @classmethod
-    def _calc_temp_name(cls, tablename: Union["quoted_name", str]) -> str:
+    def _calc_temp_name(cls, tablename: Union[quoted_name, str]) -> str:
         return ("_alembic_tmp_%s" % tablename)[0:50]
 
     def _grab_table_elements(self) -> None:
         schema = self.table.schema
-        self.columns: Dict[str, "Column"] = OrderedDict()
+        self.columns: Dict[str, Column] = OrderedDict()
         for c in self.table.c:
             c_copy = _copy(c, schema=schema)
             c_copy.unique = c_copy.index = False
@@ -250,11 +250,11 @@ class ApplyBatchImpl:
             if isinstance(c.type, SchemaEventTarget):
                 assert c_copy.type is not c.type
             self.columns[c.name] = c_copy
-        self.named_constraints: Dict[str, "Constraint"] = {}
+        self.named_constraints: Dict[str, Constraint] = {}
         self.unnamed_constraints = []
         self.col_named_constraints = {}
-        self.indexes: Dict[str, "Index"] = {}
-        self.new_indexes: Dict[str, "Index"] = {}
+        self.indexes: Dict[str, Index] = {}
+        self.new_indexes: Dict[str, Index] = {}
 
         for const in self.table.constraints:
             if _is_type_bound(const):
@@ -336,14 +336,12 @@ class ApplyBatchImpl:
             list(self.named_constraints.values()) + self.unnamed_constraints
         ):
 
-            const_columns = set(
-                [c.key for c in _columns_for_constraint(const)]
-            )
+            const_columns = {c.key for c in _columns_for_constraint(const)}
 
             if not const_columns.issubset(self.column_transfers):
                 continue
 
-            const_copy: "Constraint"
+            const_copy: Constraint
             if isinstance(const, ForeignKeyConstraint):
                 if _fk_is_self_referential(const):
                     # for self-referential constraint, refer to the
@@ -368,7 +366,7 @@ class ApplyBatchImpl:
                 self._setup_referent(m, const)
             new_table.append_constraint(const_copy)
 
-    def _gather_indexes_from_both_tables(self) -> List["Index"]:
+    def _gather_indexes_from_both_tables(self) -> List[Index]:
         assert self.new_table is not None
         idx: List[Index] = []
 
@@ -402,7 +400,7 @@ class ApplyBatchImpl:
         return idx
 
     def _setup_referent(
-        self, metadata: "MetaData", constraint: "ForeignKeyConstraint"
+        self, metadata: MetaData, constraint: ForeignKeyConstraint
     ) -> None:
         spec = constraint.elements[
             0
@@ -440,7 +438,7 @@ class ApplyBatchImpl:
                     schema=referent_schema,
                 )
 
-    def _create(self, op_impl: "DefaultImpl") -> None:
+    def _create(self, op_impl: DefaultImpl) -> None:
         self._transfer_elements_to_new_table()
 
         op_impl.prep_table_for_batch(self, self.table)
@@ -484,11 +482,11 @@ class ApplyBatchImpl:
         table_name: str,
         column_name: str,
         nullable: Optional[bool] = None,
-        server_default: Optional[Union["Function", str, bool]] = False,
+        server_default: Optional[Union[Function, str, bool]] = False,
         name: Optional[str] = None,
-        type_: Optional["TypeEngine"] = None,
+        type_: Optional[TypeEngine] = None,
         autoincrement: None = None,
-        comment: Union[str, "Literal[False]"] = False,
+        comment: Union[str, Literal[False]] = False,
         **kw,
     ) -> None:
         existing = self.columns[column_name]
@@ -587,9 +585,9 @@ class ApplyBatchImpl:
                             insert_after = index_cols[idx]
                     else:
                         # insert before a column that is also new
-                        insert_after = dict(
-                            (b, a) for a, b in self.add_col_ordering
-                        )[insert_before]
+                        insert_after = {
+                            b: a for a, b in self.add_col_ordering
+                        }[insert_before]
 
         if insert_before:
             self.add_col_ordering += ((colname, insert_before),)
@@ -607,7 +605,7 @@ class ApplyBatchImpl:
     def add_column(
         self,
         table_name: str,
-        column: "Column",
+        column: Column,
         insert_before: Optional[str] = None,
         insert_after: Optional[str] = None,
         **kw,
@@ -621,7 +619,7 @@ class ApplyBatchImpl:
         self.column_transfers[column.name] = {}
 
     def drop_column(
-        self, table_name: str, column: Union["ColumnClause", "Column"], **kw
+        self, table_name: str, column: Union[ColumnClause, Column], **kw
     ) -> None:
         if column.name in self.table.primary_key.columns:
             _remove_column_from_collection(
@@ -663,7 +661,7 @@ class ApplyBatchImpl:
 
         """
 
-    def add_constraint(self, const: "Constraint") -> None:
+    def add_constraint(self, const: Constraint) -> None:
         if not const.name:
             raise ValueError("Constraint must have a name")
         if isinstance(const, sql_schema.PrimaryKeyConstraint):
@@ -672,7 +670,7 @@ class ApplyBatchImpl:
 
         self.named_constraints[const.name] = const
 
-    def drop_constraint(self, const: "Constraint") -> None:
+    def drop_constraint(self, const: Constraint) -> None:
         if not const.name:
             raise ValueError("Constraint must have a name")
         try:
@@ -698,10 +696,10 @@ class ApplyBatchImpl:
                 for col in const.columns:
                     self.columns[col.name].primary_key = False
 
-    def create_index(self, idx: "Index") -> None:
+    def create_index(self, idx: Index) -> None:
         self.new_indexes[idx.name] = idx  # type: ignore[index]
 
-    def drop_index(self, idx: "Index") -> None:
+    def drop_index(self, idx: Index) -> None:
         try:
             del self.indexes[idx.name]  # type: ignore[arg-type]
         except KeyError:
index 85ffe149bb319e968270a94a16b48579efbd6071..a93596da727542a8314f2a8bb4508a396a674e52 100644 (file)
@@ -78,9 +78,9 @@ class MigrateOperation:
         """
         return {}
 
-    _mutations: FrozenSet["Rewriter"] = frozenset()
+    _mutations: FrozenSet[Rewriter] = frozenset()
 
-    def reverse(self) -> "MigrateOperation":
+    def reverse(self) -> MigrateOperation:
         raise NotImplementedError
 
     def to_diff_tuple(self) -> Tuple[Any, ...]:
@@ -105,21 +105,21 @@ class AddConstraintOp(MigrateOperation):
         return go
 
     @classmethod
-    def from_constraint(cls, constraint: "Constraint") -> "AddConstraintOp":
+    def from_constraint(cls, constraint: Constraint) -> AddConstraintOp:
         return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
             constraint
         )
 
     @abstractmethod
     def to_constraint(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "Constraint":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> Constraint:
         pass
 
-    def reverse(self) -> "DropConstraintOp":
+    def reverse(self) -> DropConstraintOp:
         return DropConstraintOp.from_constraint(self.to_constraint())
 
-    def to_diff_tuple(self) -> Tuple[str, "Constraint"]:
+    def to_diff_tuple(self) -> Tuple[str, Constraint]:
         return ("add_constraint", self.to_constraint())
 
 
@@ -134,7 +134,7 @@ class DropConstraintOp(MigrateOperation):
         table_name: str,
         type_: Optional[str] = None,
         schema: Optional[str] = None,
-        _reverse: Optional["AddConstraintOp"] = None,
+        _reverse: Optional[AddConstraintOp] = None,
     ) -> None:
         self.constraint_name = constraint_name
         self.table_name = table_name
@@ -142,12 +142,12 @@ class DropConstraintOp(MigrateOperation):
         self.schema = schema
         self._reverse = _reverse
 
-    def reverse(self) -> "AddConstraintOp":
+    def reverse(self) -> AddConstraintOp:
         return AddConstraintOp.from_constraint(self.to_constraint())
 
     def to_diff_tuple(
         self,
-    ) -> Tuple[str, "SchemaItem"]:
+    ) -> Tuple[str, SchemaItem]:
         if self.constraint_type == "foreignkey":
             return ("remove_fk", self.to_constraint())
         else:
@@ -156,8 +156,8 @@ class DropConstraintOp(MigrateOperation):
     @classmethod
     def from_constraint(
         cls,
-        constraint: "Constraint",
-    ) -> "DropConstraintOp":
+        constraint: Constraint,
+    ) -> DropConstraintOp:
         types = {
             "unique_constraint": "unique",
             "foreign_key_constraint": "foreignkey",
@@ -178,7 +178,7 @@ class DropConstraintOp(MigrateOperation):
 
     def to_constraint(
         self,
-    ) -> "Constraint":
+    ) -> Constraint:
 
         if self._reverse is not None:
             constraint = self._reverse.to_constraint()
@@ -197,12 +197,12 @@ class DropConstraintOp(MigrateOperation):
     @classmethod
     def drop_constraint(
         cls,
-        operations: "Operations",
+        operations: Operations,
         constraint_name: str,
         table_name: str,
         type_: Optional[str] = None,
         schema: Optional[str] = None,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         r"""Drop a constraint of the given name, typically via DROP CONSTRAINT.
 
         :param constraint_name: name of the constraint.
@@ -222,7 +222,7 @@ class DropConstraintOp(MigrateOperation):
     @classmethod
     def batch_drop_constraint(
         cls,
-        operations: "BatchOperations",
+        operations: BatchOperations,
         constraint_name: str,
         type_: Optional[str] = None,
     ) -> None:
@@ -271,7 +271,7 @@ class CreatePrimaryKeyOp(AddConstraintOp):
         self.kw = kw
 
     @classmethod
-    def from_constraint(cls, constraint: "Constraint") -> "CreatePrimaryKeyOp":
+    def from_constraint(cls, constraint: Constraint) -> CreatePrimaryKeyOp:
         constraint_table = sqla_compat._table_for_constraint(constraint)
         pk_constraint = cast("PrimaryKeyConstraint", constraint)
 
@@ -284,8 +284,8 @@ class CreatePrimaryKeyOp(AddConstraintOp):
         )
 
     def to_constraint(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "PrimaryKeyConstraint":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> PrimaryKeyConstraint:
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         return schema_obj.primary_key_constraint(
@@ -299,12 +299,12 @@ class CreatePrimaryKeyOp(AddConstraintOp):
     @classmethod
     def create_primary_key(
         cls,
-        operations: "Operations",
+        operations: Operations,
         constraint_name: Optional[str],
         table_name: str,
         columns: List[str],
         schema: Optional[str] = None,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue a "create primary key" instruction using the current
         migration context.
 
@@ -347,7 +347,7 @@ class CreatePrimaryKeyOp(AddConstraintOp):
     @classmethod
     def batch_create_primary_key(
         cls,
-        operations: "BatchOperations",
+        operations: BatchOperations,
         constraint_name: str,
         columns: List[str],
     ) -> None:
@@ -397,8 +397,8 @@ class CreateUniqueConstraintOp(AddConstraintOp):
 
     @classmethod
     def from_constraint(
-        cls, constraint: "Constraint"
-    ) -> "CreateUniqueConstraintOp":
+        cls, constraint: Constraint
+    ) -> CreateUniqueConstraintOp:
 
         constraint_table = sqla_compat._table_for_constraint(constraint)
 
@@ -419,8 +419,8 @@ class CreateUniqueConstraintOp(AddConstraintOp):
         )
 
     def to_constraint(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "UniqueConstraint":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> UniqueConstraint:
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.unique_constraint(
             self.constraint_name,
@@ -433,7 +433,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
     @classmethod
     def create_unique_constraint(
         cls,
-        operations: "Operations",
+        operations: Operations,
         constraint_name: Optional[str],
         table_name: str,
         columns: Sequence[str],
@@ -484,7 +484,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
     @classmethod
     def batch_create_unique_constraint(
         cls,
-        operations: "BatchOperations",
+        operations: BatchOperations,
         constraint_name: str,
         columns: Sequence[str],
         **kw: Any,
@@ -531,11 +531,11 @@ class CreateForeignKeyOp(AddConstraintOp):
         self.remote_cols = remote_cols
         self.kw = kw
 
-    def to_diff_tuple(self) -> Tuple[str, "ForeignKeyConstraint"]:
+    def to_diff_tuple(self) -> Tuple[str, ForeignKeyConstraint]:
         return ("add_fk", self.to_constraint())
 
     @classmethod
-    def from_constraint(cls, constraint: "Constraint") -> "CreateForeignKeyOp":
+    def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp:
 
         fk_constraint = cast("ForeignKeyConstraint", constraint)
         kw: dict = {}
@@ -576,8 +576,8 @@ class CreateForeignKeyOp(AddConstraintOp):
         )
 
     def to_constraint(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "ForeignKeyConstraint":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> ForeignKeyConstraint:
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.foreign_key_constraint(
             self.constraint_name,
@@ -591,7 +591,7 @@ class CreateForeignKeyOp(AddConstraintOp):
     @classmethod
     def create_foreign_key(
         cls,
-        operations: "Operations",
+        operations: Operations,
         constraint_name: Optional[str],
         source_table: str,
         referent_table: str,
@@ -605,7 +605,7 @@ class CreateForeignKeyOp(AddConstraintOp):
         source_schema: Optional[str] = None,
         referent_schema: Optional[str] = None,
         **dialect_kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue a "create foreign key" instruction using the
         current migration context.
 
@@ -671,7 +671,7 @@ class CreateForeignKeyOp(AddConstraintOp):
     @classmethod
     def batch_create_foreign_key(
         cls,
-        operations: "BatchOperations",
+        operations: BatchOperations,
         constraint_name: str,
         referent_table: str,
         local_cols: List[str],
@@ -736,7 +736,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
         self,
         constraint_name: Optional[str],
         table_name: str,
-        condition: Union[str, "TextClause", "ColumnElement[Any]"],
+        condition: Union[str, TextClause, ColumnElement[Any]],
         schema: Optional[str] = None,
         **kw: Any,
     ) -> None:
@@ -748,8 +748,8 @@ class CreateCheckConstraintOp(AddConstraintOp):
 
     @classmethod
     def from_constraint(
-        cls, constraint: "Constraint"
-    ) -> "CreateCheckConstraintOp":
+        cls, constraint: Constraint
+    ) -> CreateCheckConstraintOp:
         constraint_table = sqla_compat._table_for_constraint(constraint)
 
         ck_constraint = cast("CheckConstraint", constraint)
@@ -763,8 +763,8 @@ class CreateCheckConstraintOp(AddConstraintOp):
         )
 
     def to_constraint(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "CheckConstraint":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> CheckConstraint:
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.check_constraint(
             self.constraint_name,
@@ -777,13 +777,13 @@ class CreateCheckConstraintOp(AddConstraintOp):
     @classmethod
     def create_check_constraint(
         cls,
-        operations: "Operations",
+        operations: Operations,
         constraint_name: Optional[str],
         table_name: str,
-        condition: Union[str, "BinaryExpression"],
+        condition: Union[str, BinaryExpression],
         schema: Optional[str] = None,
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue a "create check constraint" instruction using the
         current migration context.
 
@@ -830,11 +830,11 @@ class CreateCheckConstraintOp(AddConstraintOp):
     @classmethod
     def batch_create_check_constraint(
         cls,
-        operations: "BatchOperations",
+        operations: BatchOperations,
         constraint_name: str,
-        condition: "TextClause",
+        condition: TextClause,
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue a "create check constraint" instruction using the
         current batch migration context.
 
@@ -865,7 +865,7 @@ class CreateIndexOp(MigrateOperation):
         self,
         index_name: str,
         table_name: str,
-        columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
+        columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
         schema: Optional[str] = None,
         unique: bool = False,
         **kw: Any,
@@ -877,14 +877,14 @@ class CreateIndexOp(MigrateOperation):
         self.unique = unique
         self.kw = kw
 
-    def reverse(self) -> "DropIndexOp":
+    def reverse(self) -> DropIndexOp:
         return DropIndexOp.from_index(self.to_index())
 
-    def to_diff_tuple(self) -> Tuple[str, "Index"]:
+    def to_diff_tuple(self) -> Tuple[str, Index]:
         return ("add_index", self.to_index())
 
     @classmethod
-    def from_index(cls, index: "Index") -> "CreateIndexOp":
+    def from_index(cls, index: Index) -> CreateIndexOp:
         assert index.table is not None
         return cls(
             index.name,  # type: ignore[arg-type]
@@ -896,8 +896,8 @@ class CreateIndexOp(MigrateOperation):
         )
 
     def to_index(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "Index":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> Index:
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         idx = schema_obj.index(
@@ -916,11 +916,11 @@ class CreateIndexOp(MigrateOperation):
         operations: Operations,
         index_name: str,
         table_name: str,
-        columns: Sequence[Union[str, "TextClause", "Function"]],
+        columns: Sequence[Union[str, TextClause, Function]],
         schema: Optional[str] = None,
         unique: bool = False,
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         r"""Issue a "create index" instruction using the current
         migration context.
 
@@ -970,11 +970,11 @@ class CreateIndexOp(MigrateOperation):
     @classmethod
     def batch_create_index(
         cls,
-        operations: "BatchOperations",
+        operations: BatchOperations,
         index_name: str,
         columns: List[str],
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue a "create index" instruction using the
         current batch migration context.
 
@@ -1001,10 +1001,10 @@ class DropIndexOp(MigrateOperation):
 
     def __init__(
         self,
-        index_name: Union["quoted_name", str, "conv"],
+        index_name: Union[quoted_name, str, conv],
         table_name: Optional[str] = None,
         schema: Optional[str] = None,
-        _reverse: Optional["CreateIndexOp"] = None,
+        _reverse: Optional[CreateIndexOp] = None,
         **kw: Any,
     ) -> None:
         self.index_name = index_name
@@ -1013,14 +1013,14 @@ class DropIndexOp(MigrateOperation):
         self._reverse = _reverse
         self.kw = kw
 
-    def to_diff_tuple(self) -> Tuple[str, "Index"]:
+    def to_diff_tuple(self) -> Tuple[str, Index]:
         return ("remove_index", self.to_index())
 
-    def reverse(self) -> "CreateIndexOp":
+    def reverse(self) -> CreateIndexOp:
         return CreateIndexOp.from_index(self.to_index())
 
     @classmethod
-    def from_index(cls, index: "Index") -> "DropIndexOp":
+    def from_index(cls, index: Index) -> DropIndexOp:
         assert index.table is not None
         return cls(
             index.name,  # type: ignore[arg-type]
@@ -1031,8 +1031,8 @@ class DropIndexOp(MigrateOperation):
         )
 
     def to_index(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "Index":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> Index:
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         # need a dummy column name here since SQLAlchemy
@@ -1048,12 +1048,12 @@ class DropIndexOp(MigrateOperation):
     @classmethod
     def drop_index(
         cls,
-        operations: "Operations",
+        operations: Operations,
         index_name: str,
         table_name: Optional[str] = None,
         schema: Optional[str] = None,
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         r"""Issue a "drop index" instruction using the current
         migration context.
 
@@ -1081,7 +1081,7 @@ class DropIndexOp(MigrateOperation):
     @classmethod
     def batch_drop_index(
         cls, operations: BatchOperations, index_name: str, **kw: Any
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue a "drop index" instruction using the
         current batch migration context.
 
@@ -1107,9 +1107,9 @@ class CreateTableOp(MigrateOperation):
     def __init__(
         self,
         table_name: str,
-        columns: Sequence["SchemaItem"],
+        columns: Sequence[SchemaItem],
         schema: Optional[str] = None,
-        _namespace_metadata: Optional["MetaData"] = None,
+        _namespace_metadata: Optional[MetaData] = None,
         _constraints_included: bool = False,
         **kw: Any,
     ) -> None:
@@ -1123,18 +1123,18 @@ class CreateTableOp(MigrateOperation):
         self._namespace_metadata = _namespace_metadata
         self._constraints_included = _constraints_included
 
-    def reverse(self) -> "DropTableOp":
+    def reverse(self) -> DropTableOp:
         return DropTableOp.from_table(
             self.to_table(), _namespace_metadata=self._namespace_metadata
         )
 
-    def to_diff_tuple(self) -> Tuple[str, "Table"]:
+    def to_diff_tuple(self) -> Tuple[str, Table]:
         return ("add_table", self.to_table())
 
     @classmethod
     def from_table(
-        cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None
-    ) -> "CreateTableOp":
+        cls, table: Table, _namespace_metadata: Optional[MetaData] = None
+    ) -> CreateTableOp:
         if _namespace_metadata is None:
             _namespace_metadata = table.metadata
 
@@ -1157,8 +1157,8 @@ class CreateTableOp(MigrateOperation):
         )
 
     def to_table(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "Table":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> Table:
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         return schema_obj.table(
@@ -1175,11 +1175,11 @@ class CreateTableOp(MigrateOperation):
     @classmethod
     def create_table(
         cls,
-        operations: "Operations",
+        operations: Operations,
         table_name: str,
-        *columns: "SchemaItem",
+        *columns: SchemaItem,
         **kw: Any,
-    ) -> "Optional[Table]":
+    ) -> Optional[Table]:
         r"""Issue a "create table" instruction using the current migration
         context.
 
@@ -1269,7 +1269,7 @@ class DropTableOp(MigrateOperation):
         table_name: str,
         schema: Optional[str] = None,
         table_kw: Optional[MutableMapping[Any, Any]] = None,
-        _reverse: Optional["CreateTableOp"] = None,
+        _reverse: Optional[CreateTableOp] = None,
     ) -> None:
         self.table_name = table_name
         self.schema = schema
@@ -1279,16 +1279,16 @@ class DropTableOp(MigrateOperation):
         self.prefixes = self.table_kw.pop("prefixes", None)
         self._reverse = _reverse
 
-    def to_diff_tuple(self) -> Tuple[str, "Table"]:
+    def to_diff_tuple(self) -> Tuple[str, Table]:
         return ("remove_table", self.to_table())
 
-    def reverse(self) -> "CreateTableOp":
+    def reverse(self) -> CreateTableOp:
         return CreateTableOp.from_table(self.to_table())
 
     @classmethod
     def from_table(
-        cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None
-    ) -> "DropTableOp":
+        cls, table: Table, _namespace_metadata: Optional[MetaData] = None
+    ) -> DropTableOp:
         return cls(
             table.name,
             schema=table.schema,
@@ -1304,8 +1304,8 @@ class DropTableOp(MigrateOperation):
         )
 
     def to_table(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "Table":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> Table:
         if self._reverse:
             cols_and_constraints = self._reverse.columns
         else:
@@ -1329,7 +1329,7 @@ class DropTableOp(MigrateOperation):
     @classmethod
     def drop_table(
         cls,
-        operations: "Operations",
+        operations: Operations,
         table_name: str,
         schema: Optional[str] = None,
         **kw: Any,
@@ -1377,17 +1377,17 @@ class RenameTableOp(AlterTableOp):
         new_table_name: str,
         schema: Optional[str] = None,
     ) -> None:
-        super(RenameTableOp, self).__init__(old_table_name, schema=schema)
+        super().__init__(old_table_name, schema=schema)
         self.new_table_name = new_table_name
 
     @classmethod
     def rename_table(
         cls,
-        operations: "Operations",
+        operations: Operations,
         old_table_name: str,
         new_table_name: str,
         schema: Optional[str] = None,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Emit an ALTER TABLE to rename a table.
 
         :param old_table_name: old name.
@@ -1424,12 +1424,12 @@ class CreateTableCommentOp(AlterTableOp):
     @classmethod
     def create_table_comment(
         cls,
-        operations: "Operations",
+        operations: Operations,
         table_name: str,
         comment: Optional[str],
         existing_comment: None = None,
         schema: Optional[str] = None,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Emit a COMMENT ON operation to set the comment for a table.
 
         .. versionadded:: 1.0.6
@@ -1534,11 +1534,11 @@ class DropTableCommentOp(AlterTableOp):
     @classmethod
     def drop_table_comment(
         cls,
-        operations: "Operations",
+        operations: Operations,
         table_name: str,
         existing_comment: Optional[str] = None,
         schema: Optional[str] = None,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue a "drop table comment" operation to
         remove an existing comment set on a table.
 
@@ -1609,13 +1609,13 @@ class AlterColumnOp(AlterTableOp):
         existing_nullable: Optional[bool] = None,
         existing_comment: Optional[str] = None,
         modify_nullable: Optional[bool] = None,
-        modify_comment: Optional[Union[str, "Literal[False]"]] = False,
+        modify_comment: Optional[Union[str, Literal[False]]] = False,
         modify_server_default: Any = False,
         modify_name: Optional[str] = None,
         modify_type: Optional[Any] = None,
         **kw: Any,
     ) -> None:
-        super(AlterColumnOp, self).__init__(table_name, schema=schema)
+        super().__init__(table_name, schema=schema)
         self.column_name = column_name
         self.existing_type = existing_type
         self.existing_server_default = existing_server_default
@@ -1723,7 +1723,7 @@ class AlterColumnOp(AlterTableOp):
         else:
             return False
 
-    def reverse(self) -> "AlterColumnOp":
+    def reverse(self) -> AlterColumnOp:
 
         kw = self.kw.copy()
         kw["existing_type"] = self.existing_type
@@ -1740,11 +1740,11 @@ class AlterColumnOp(AlterTableOp):
             kw["modify_comment"] = self.modify_comment
 
         # TODO: make this a little simpler
-        all_keys = set(
+        all_keys = {
             m.group(1)
             for m in [re.match(r"^(?:existing_|modify_)(.+)$", k) for k in kw]
             if m
-        )
+        }
 
         for k in all_keys:
             if "modify_%s" % k in kw:
@@ -1763,21 +1763,19 @@ class AlterColumnOp(AlterTableOp):
         table_name: str,
         column_name: str,
         nullable: Optional[bool] = None,
-        comment: Optional[Union[str, "Literal[False]"]] = False,
+        comment: Optional[Union[str, Literal[False]]] = False,
         server_default: Any = False,
         new_column_name: Optional[str] = None,
-        type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
-        existing_type: Optional[
-            Union["TypeEngine", Type["TypeEngine"]]
-        ] = None,
+        type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
+        existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
         existing_server_default: Optional[
-            Union[str, bool, "Identity", "Computed"]
+            Union[str, bool, Identity, Computed]
         ] = False,
         existing_nullable: Optional[bool] = None,
         existing_comment: Optional[str] = None,
         schema: Optional[str] = None,
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         r"""Issue an "alter column" instruction using the
         current migration context.
 
@@ -1891,20 +1889,18 @@ class AlterColumnOp(AlterTableOp):
         operations: BatchOperations,
         column_name: str,
         nullable: Optional[bool] = None,
-        comment: Union[str, "Literal[False]"] = False,
-        server_default: Union["Function", bool] = False,
+        comment: Union[str, Literal[False]] = False,
+        server_default: Union[Function, bool] = False,
         new_column_name: Optional[str] = None,
-        type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
-        existing_type: Optional[
-            Union["TypeEngine", Type["TypeEngine"]]
-        ] = None,
+        type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
+        existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
         existing_server_default: bool = False,
         existing_nullable: None = None,
         existing_comment: None = None,
         insert_before: None = None,
         insert_after: None = None,
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue an "alter column" instruction using the current
         batch migration context.
 
@@ -1958,29 +1954,29 @@ class AddColumnOp(AlterTableOp):
     def __init__(
         self,
         table_name: str,
-        column: "Column",
+        column: Column,
         schema: Optional[str] = None,
         **kw: Any,
     ) -> None:
-        super(AddColumnOp, self).__init__(table_name, schema=schema)
+        super().__init__(table_name, schema=schema)
         self.column = column
         self.kw = kw
 
-    def reverse(self) -> "DropColumnOp":
+    def reverse(self) -> DropColumnOp:
         return DropColumnOp.from_column_and_tablename(
             self.schema, self.table_name, self.column
         )
 
     def to_diff_tuple(
         self,
-    ) -> Tuple[str, Optional[str], str, "Column"]:
+    ) -> Tuple[str, Optional[str], str, Column]:
         return ("add_column", self.schema, self.table_name, self.column)
 
-    def to_column(self) -> "Column":
+    def to_column(self) -> Column:
         return self.column
 
     @classmethod
-    def from_column(cls, col: "Column") -> "AddColumnOp":
+    def from_column(cls, col: Column) -> AddColumnOp:
         return cls(col.table.name, col, schema=col.table.schema)
 
     @classmethod
@@ -1988,18 +1984,18 @@ class AddColumnOp(AlterTableOp):
         cls,
         schema: Optional[str],
         tname: str,
-        col: "Column",
-    ) -> "AddColumnOp":
+        col: Column,
+    ) -> AddColumnOp:
         return cls(tname, col, schema=schema)
 
     @classmethod
     def add_column(
         cls,
-        operations: "Operations",
+        operations: Operations,
         table_name: str,
-        column: "Column",
+        column: Column,
         schema: Optional[str] = None,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue an "add column" instruction using the current
         migration context.
 
@@ -2055,11 +2051,11 @@ class AddColumnOp(AlterTableOp):
     @classmethod
     def batch_add_column(
         cls,
-        operations: "BatchOperations",
-        column: "Column",
+        operations: BatchOperations,
+        column: Column,
         insert_before: Optional[str] = None,
         insert_after: Optional[str] = None,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue an "add column" instruction using the current
         batch migration context.
 
@@ -2094,17 +2090,17 @@ class DropColumnOp(AlterTableOp):
         table_name: str,
         column_name: str,
         schema: Optional[str] = None,
-        _reverse: Optional["AddColumnOp"] = None,
+        _reverse: Optional[AddColumnOp] = None,
         **kw: Any,
     ) -> None:
-        super(DropColumnOp, self).__init__(table_name, schema=schema)
+        super().__init__(table_name, schema=schema)
         self.column_name = column_name
         self.kw = kw
         self._reverse = _reverse
 
     def to_diff_tuple(
         self,
-    ) -> Tuple[str, Optional[str], str, "Column"]:
+    ) -> Tuple[str, Optional[str], str, Column]:
         return (
             "remove_column",
             self.schema,
@@ -2112,7 +2108,7 @@ class DropColumnOp(AlterTableOp):
             self.to_column(),
         )
 
-    def reverse(self) -> "AddColumnOp":
+    def reverse(self) -> AddColumnOp:
         if self._reverse is None:
             raise ValueError(
                 "operation is not reversible; "
@@ -2128,8 +2124,8 @@ class DropColumnOp(AlterTableOp):
         cls,
         schema: Optional[str],
         tname: str,
-        col: "Column",
-    ) -> "DropColumnOp":
+        col: Column,
+    ) -> DropColumnOp:
         return cls(
             tname,
             col.name,
@@ -2138,8 +2134,8 @@ class DropColumnOp(AlterTableOp):
         )
 
     def to_column(
-        self, migration_context: Optional["MigrationContext"] = None
-    ) -> "Column":
+        self, migration_context: Optional[MigrationContext] = None
+    ) -> Column:
         if self._reverse is not None:
             return self._reverse.column
         schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -2148,12 +2144,12 @@ class DropColumnOp(AlterTableOp):
     @classmethod
     def drop_column(
         cls,
-        operations: "Operations",
+        operations: Operations,
         table_name: str,
         column_name: str,
         schema: Optional[str] = None,
         **kw: Any,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         """Issue a "drop column" instruction using the current
         migration context.
 
@@ -2196,8 +2192,8 @@ class DropColumnOp(AlterTableOp):
 
     @classmethod
     def batch_drop_column(
-        cls, operations: "BatchOperations", column_name: str, **kw: Any
-    ) -> Optional["Table"]:
+        cls, operations: BatchOperations, column_name: str, **kw: Any
+    ) -> Optional[Table]:
         """Issue a "drop column" instruction using the current
         batch migration context.
 
@@ -2221,7 +2217,7 @@ class BulkInsertOp(MigrateOperation):
 
     def __init__(
         self,
-        table: Union["Table", "TableClause"],
+        table: Union[Table, TableClause],
         rows: List[dict],
         multiinsert: bool = True,
     ) -> None:
@@ -2233,7 +2229,7 @@ class BulkInsertOp(MigrateOperation):
     def bulk_insert(
         cls,
         operations: Operations,
-        table: Union["Table", "TableClause"],
+        table: Union[Table, TableClause],
         rows: List[dict],
         multiinsert: bool = True,
     ) -> None:
@@ -2322,7 +2318,7 @@ class ExecuteSQLOp(MigrateOperation):
 
     def __init__(
         self,
-        sqltext: Union["Update", str, "Insert", "TextClause"],
+        sqltext: Union[Update, str, Insert, TextClause],
         execution_options: None = None,
     ) -> None:
         self.sqltext = sqltext
@@ -2332,9 +2328,9 @@ class ExecuteSQLOp(MigrateOperation):
     def execute(
         cls,
         operations: Operations,
-        sqltext: Union[str, "TextClause", "Update"],
+        sqltext: Union[str, TextClause, Update],
         execution_options: None = None,
-    ) -> Optional["Table"]:
+    ) -> Optional[Table]:
         r"""Execute the given SQL using the current migration context.
 
         The given SQL can be a plain string, e.g.::
@@ -2434,12 +2430,11 @@ class OpContainer(MigrateOperation):
 
     @classmethod
     def _ops_as_diffs(
-        cls, migrations: "OpContainer"
+        cls, migrations: OpContainer
     ) -> Iterator[Tuple[Any, ...]]:
         for op in migrations.ops:
             if hasattr(op, "ops"):
-                for sub_op in cls._ops_as_diffs(cast("OpContainer", op)):
-                    yield sub_op
+                yield from cls._ops_as_diffs(cast("OpContainer", op))
             else:
                 yield op.to_diff_tuple()
 
@@ -2453,11 +2448,11 @@ class ModifyTableOps(OpContainer):
         ops: Sequence[MigrateOperation],
         schema: Optional[str] = None,
     ) -> None:
-        super(ModifyTableOps, self).__init__(ops)
+        super().__init__(ops)
         self.table_name = table_name
         self.schema = schema
 
-    def reverse(self) -> "ModifyTableOps":
+    def reverse(self) -> ModifyTableOps:
         return ModifyTableOps(
             self.table_name,
             ops=list(reversed([op.reverse() for op in self.ops])),
@@ -2480,16 +2475,16 @@ class UpgradeOps(OpContainer):
         ops: Sequence[MigrateOperation] = (),
         upgrade_token: str = "upgrades",
     ) -> None:
-        super(UpgradeOps, self).__init__(ops=ops)
+        super().__init__(ops=ops)
         self.upgrade_token = upgrade_token
 
-    def reverse_into(self, downgrade_ops: "DowngradeOps") -> "DowngradeOps":
+    def reverse_into(self, downgrade_ops: DowngradeOps) -> DowngradeOps:
         downgrade_ops.ops[:] = list(  # type:ignore[index]
             reversed([op.reverse() for op in self.ops])
         )
         return downgrade_ops
 
-    def reverse(self) -> "DowngradeOps":
+    def reverse(self) -> DowngradeOps:
         return self.reverse_into(DowngradeOps(ops=[]))
 
 
@@ -2508,7 +2503,7 @@ class DowngradeOps(OpContainer):
         ops: Sequence[MigrateOperation] = (),
         downgrade_token: str = "downgrades",
     ) -> None:
-        super(DowngradeOps, self).__init__(ops=ops)
+        super().__init__(ops=ops)
         self.downgrade_token = downgrade_token
 
     def reverse(self):
@@ -2546,8 +2541,8 @@ class MigrationScript(MigrateOperation):
     def __init__(
         self,
         rev_id: Optional[str],
-        upgrade_ops: "UpgradeOps",
-        downgrade_ops: "DowngradeOps",
+        upgrade_ops: UpgradeOps,
+        downgrade_ops: DowngradeOps,
         message: Optional[str] = None,
         imports: Set[str] = set(),
         head: Optional[str] = None,
@@ -2618,7 +2613,7 @@ class MigrationScript(MigrateOperation):
             assert isinstance(elem, DowngradeOps)
 
     @property
-    def upgrade_ops_list(self) -> List["UpgradeOps"]:
+    def upgrade_ops_list(self) -> List[UpgradeOps]:
         """A list of :class:`.UpgradeOps` instances.
 
         This is used in place of the :attr:`.MigrationScript.upgrade_ops`
@@ -2629,7 +2624,7 @@ class MigrationScript(MigrateOperation):
         return self._upgrade_ops
 
     @property
-    def downgrade_ops_list(self) -> List["DowngradeOps"]:
+    def downgrade_ops_list(self) -> List[DowngradeOps]:
         """A list of :class:`.DowngradeOps` instances.
 
         This is used in place of the :attr:`.MigrationScript.downgrade_ops`
index 6c6f9714f6924c8e904f4fc0d5e9748cc347d75e..dfda8bbeaad0d2910ef3bb1dc53229269693cef6 100644 (file)
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
 
 class SchemaObjects:
     def __init__(
-        self, migration_context: Optional["MigrationContext"] = None
+        self, migration_context: Optional[MigrationContext] = None
     ) -> None:
         self.migration_context = migration_context
 
@@ -47,7 +47,7 @@ class SchemaObjects:
         cols: Sequence[str],
         schema: Optional[str] = None,
         **dialect_kw,
-    ) -> "PrimaryKeyConstraint":
+    ) -> PrimaryKeyConstraint:
         m = self.metadata()
         columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
         t = sa_schema.Table(table_name, m, *columns, schema=schema)
@@ -71,7 +71,7 @@ class SchemaObjects:
         initially: Optional[str] = None,
         match: Optional[str] = None,
         **dialect_kw,
-    ) -> "ForeignKeyConstraint":
+    ) -> ForeignKeyConstraint:
         m = self.metadata()
         if source == referent and source_schema == referent_schema:
             t1_cols = local_cols + remote_cols
@@ -120,7 +120,7 @@ class SchemaObjects:
         local_cols: Sequence[str],
         schema: Optional[str] = None,
         **kw,
-    ) -> "UniqueConstraint":
+    ) -> UniqueConstraint:
         t = sa_schema.Table(
             source,
             self.metadata(),
@@ -138,10 +138,10 @@ class SchemaObjects:
         self,
         name: Optional[str],
         source: str,
-        condition: Union[str, "TextClause", "ColumnElement[Any]"],
+        condition: Union[str, TextClause, ColumnElement[Any]],
         schema: Optional[str] = None,
         **kw,
-    ) -> Union["CheckConstraint"]:
+    ) -> Union[CheckConstraint]:
         t = sa_schema.Table(
             source,
             self.metadata(),
@@ -182,7 +182,7 @@ class SchemaObjects:
             t.append_constraint(const)
             return const
 
-    def metadata(self) -> "MetaData":
+    def metadata(self) -> MetaData:
         kw = {}
         if (
             self.migration_context is not None
@@ -193,7 +193,7 @@ class SchemaObjects:
                 kw["naming_convention"] = mt.naming_convention
         return sa_schema.MetaData(**kw)
 
-    def table(self, name: str, *columns, **kw) -> "Table":
+    def table(self, name: str, *columns, **kw) -> Table:
         m = self.metadata()
 
         cols = [
@@ -230,17 +230,17 @@ class SchemaObjects:
             self._ensure_table_for_fk(m, f)
         return t
 
-    def column(self, name: str, type_: "TypeEngine", **kw) -> "Column":
+    def column(self, name: str, type_: TypeEngine, **kw) -> Column:
         return sa_schema.Column(name, type_, **kw)
 
     def index(
         self,
         name: str,
         tablename: Optional[str],
-        columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
+        columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
         schema: Optional[str] = None,
         **kw,
-    ) -> "Index":
+    ) -> Index:
         t = sa_schema.Table(
             tablename or "no_table",
             self.metadata(),
@@ -264,9 +264,7 @@ class SchemaObjects:
             sname = None
         return (sname, tname)
 
-    def _ensure_table_for_fk(
-        self, metadata: "MetaData", fk: "ForeignKey"
-    ) -> None:
+    def _ensure_table_for_fk(self, metadata: MetaData, fk: ForeignKey) -> None:
         """create a placeholder Table object for the referent of a
         ForeignKey.
 
index 6dbbcc31c323244696f75ad34b5da351b25f6fb4..44dcd72db1e3fcdfdbedf2b9c52a06295ab39938 100644 (file)
@@ -99,14 +99,14 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     """
 
-    _migration_context: Optional["MigrationContext"] = None
+    _migration_context: Optional[MigrationContext] = None
 
-    config: "Config" = None  # type:ignore[assignment]
+    config: Config = None  # type:ignore[assignment]
     """An instance of :class:`.Config` representing the
     configuration file contents as well as other variables
     set programmatically within it."""
 
-    script: "ScriptDirectory" = None  # type:ignore[assignment]
+    script: ScriptDirectory = None  # type:ignore[assignment]
     """An instance of :class:`.ScriptDirectory` which provides
     programmatic access to version files within the ``versions/``
     directory.
index 677d0c74d0a2b68f5756074392f2aec83892e54e..95eb82a450a662ee1c2fd7839d9c95cc589151ea 100644 (file)
@@ -50,11 +50,11 @@ log = logging.getLogger(__name__)
 
 
 class _ProxyTransaction:
-    def __init__(self, migration_context: "MigrationContext") -> None:
+    def __init__(self, migration_context: MigrationContext) -> None:
         self.migration_context = migration_context
 
     @property
-    def _proxied_transaction(self) -> Optional["Transaction"]:
+    def _proxied_transaction(self) -> Optional[Transaction]:
         return self.migration_context._transaction
 
     def rollback(self) -> None:
@@ -69,7 +69,7 @@ class _ProxyTransaction:
         t.commit()
         self.migration_context._transaction = None
 
-    def __enter__(self) -> "_ProxyTransaction":
+    def __enter__(self) -> _ProxyTransaction:
         return self
 
     def __exit__(self, type_: None, value: None, traceback: None) -> None:
@@ -127,22 +127,22 @@ class MigrationContext:
 
     def __init__(
         self,
-        dialect: "Dialect",
-        connection: Optional["Connection"],
+        dialect: Dialect,
+        connection: Optional[Connection],
         opts: Dict[str, Any],
-        environment_context: Optional["EnvironmentContext"] = None,
+        environment_context: Optional[EnvironmentContext] = None,
     ) -> None:
         self.environment_context = environment_context
         self.opts = opts
         self.dialect = dialect
-        self.script: Optional["ScriptDirectory"] = opts.get("script")
+        self.script: Optional[ScriptDirectory] = opts.get("script")
         as_sql: bool = opts.get("as_sql", False)
         transactional_ddl = opts.get("transactional_ddl")
         self._transaction_per_migration = opts.get(
             "transaction_per_migration", False
         )
         self.on_version_apply_callbacks = opts.get("on_version_apply", ())
-        self._transaction: Optional["Transaction"] = None
+        self._transaction: Optional[Transaction] = None
 
         if as_sql:
             self.connection = cast(
@@ -215,14 +215,14 @@ class MigrationContext:
     @classmethod
     def configure(
         cls,
-        connection: Optional["Connection"] = None,
+        connection: Optional[Connection] = None,
         url: Optional[str] = None,
         dialect_name: Optional[str] = None,
-        dialect: Optional["Dialect"] = None,
-        environment_context: Optional["EnvironmentContext"] = None,
+        dialect: Optional[Dialect] = None,
+        environment_context: Optional[EnvironmentContext] = None,
         dialect_opts: Optional[Dict[str, str]] = None,
         opts: Optional[Any] = None,
-    ) -> "MigrationContext":
+    ) -> MigrationContext:
         """Create a new :class:`.MigrationContext`.
 
         This is a factory method usually called
@@ -366,7 +366,7 @@ class MigrationContext:
 
     def begin_transaction(
         self, _per_migration: bool = False
-    ) -> Union["_ProxyTransaction", ContextManager]:
+    ) -> Union[_ProxyTransaction, ContextManager]:
         """Begin a logical transaction for migration operations.
 
         This method is used within an ``env.py`` script to demarcate where
@@ -552,9 +552,7 @@ class MigrationContext:
             self.connection, self.version_table, self.version_table_schema
         )
 
-    def stamp(
-        self, script_directory: "ScriptDirectory", revision: str
-    ) -> None:
+    def stamp(self, script_directory: ScriptDirectory, revision: str) -> None:
         """Stamp the version table with a specific revision.
 
         This method calculates those branches to which the given revision
@@ -653,7 +651,7 @@ class MigrationContext:
 
     def execute(
         self,
-        sql: Union["ClauseElement", str],
+        sql: Union[ClauseElement, str],
         execution_options: Optional[dict] = None,
     ) -> None:
         """Execute a SQL construct or string statement.
@@ -667,15 +665,15 @@ class MigrationContext:
         self.impl._exec(sql, execution_options)
 
     def _stdout_connection(
-        self, connection: Optional["Connection"]
-    ) -> "MockConnection":
+        self, connection: Optional[Connection]
+    ) -> MockConnection:
         def dump(construct, *multiparams, **params):
             self.impl._exec(construct)
 
         return MockEngineStrategy.MockConnection(self.dialect, dump)
 
     @property
-    def bind(self) -> Optional["Connection"]:
+    def bind(self) -> Optional[Connection]:
         """Return the current "bind".
 
         In online mode, this is an instance of
@@ -696,7 +694,7 @@ class MigrationContext:
         return self.connection
 
     @property
-    def config(self) -> Optional["Config"]:
+    def config(self) -> Optional[Config]:
         """Return the :class:`.Config` used by the current environment,
         if any."""
 
@@ -706,7 +704,7 @@ class MigrationContext:
             return None
 
     def _compare_type(
-        self, inspector_column: "Column", metadata_column: "Column"
+        self, inspector_column: Column, metadata_column: Column
     ) -> bool:
         if self._user_compare_type is False:
             return False
@@ -726,8 +724,8 @@ class MigrationContext:
 
     def _compare_server_default(
         self,
-        inspector_column: "Column",
-        metadata_column: "Column",
+        inspector_column: Column,
+        metadata_column: Column,
         rendered_metadata_default: Optional[str],
         rendered_column_default: Optional[str],
     ) -> bool:
@@ -756,7 +754,7 @@ class MigrationContext:
 
 
 class HeadMaintainer:
-    def __init__(self, context: "MigrationContext", heads: Any) -> None:
+    def __init__(self, context: MigrationContext, heads: Any) -> None:
         self.context = context
         self.heads = set(heads)
 
@@ -820,7 +818,7 @@ class HeadMaintainer:
                 % (from_, to_, self.context.version_table, ret.rowcount)
             )
 
-    def update_to_step(self, step: Union["RevisionStep", "StampStep"]) -> None:
+    def update_to_step(self, step: Union[RevisionStep, StampStep]) -> None:
         if step.should_delete_branch(self.heads):
             vers = step.delete_version_num
             log.debug("branch delete %s", vers)
@@ -916,12 +914,12 @@ class MigrationInfo:
     from dependencies.
     """
 
-    revision_map: "RevisionMap"
+    revision_map: RevisionMap
     """The revision map inside of which this operation occurs."""
 
     def __init__(
         self,
-        revision_map: "RevisionMap",
+        revision_map: RevisionMap,
         is_upgrade: bool,
         is_stamp: bool,
         up_revisions: Union[str, Tuple[str, ...]],
@@ -1010,14 +1008,14 @@ class MigrationStep:
 
     @classmethod
     def upgrade_from_script(
-        cls, revision_map: "RevisionMap", script: "Script"
-    ) -> "RevisionStep":
+        cls, revision_map: RevisionMap, script: Script
+    ) -> RevisionStep:
         return RevisionStep(revision_map, script, True)
 
     @classmethod
     def downgrade_from_script(
-        cls, revision_map: "RevisionMap", script: "Script"
-    ) -> "RevisionStep":
+        cls, revision_map: RevisionMap, script: Script
+    ) -> RevisionStep:
         return RevisionStep(revision_map, script, False)
 
     @property
@@ -1046,7 +1044,7 @@ class MigrationStep:
 
 class RevisionStep(MigrationStep):
     def __init__(
-        self, revision_map: "RevisionMap", revision: "Script", is_upgrade: bool
+        self, revision_map: RevisionMap, revision: Script, is_upgrade: bool
     ) -> None:
         self.revision_map = revision_map
         self.revision = revision
@@ -1142,12 +1140,12 @@ class RevisionStep(MigrationStep):
         other_heads = set(heads).difference(self.from_revisions)
 
         if other_heads:
-            ancestors = set(
+            ancestors = {
                 r.revision
                 for r in self.revision_map._get_ancestor_nodes(
                     self.revision_map.get_revisions(other_heads), check=False
                 )
-            )
+            }
             from_revisions = list(
                 set(self.from_revisions).difference(ancestors)
             )
@@ -1164,12 +1162,12 @@ class RevisionStep(MigrationStep):
     def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]:
         other_heads = set(heads).difference([self.revision.revision])
         if other_heads:
-            ancestors = set(
+            ancestors = {
                 r.revision
                 for r in self.revision_map._get_ancestor_nodes(
                     self.revision_map.get_revisions(other_heads), check=False
                 )
-            )
+            }
             return tuple(set(self.to_revisions).difference(ancestors))
         else:
             return self.to_revisions
@@ -1253,7 +1251,7 @@ class RevisionStep(MigrationStep):
         return self.revision.revision
 
     @property
-    def info(self) -> "MigrationInfo":
+    def info(self) -> MigrationInfo:
         return MigrationInfo(
             revision_map=self.revision_map,
             up_revisions=self.revision.revision,
@@ -1270,7 +1268,7 @@ class StampStep(MigrationStep):
         to_: Optional[Union[str, Collection[str]]],
         is_upgrade: bool,
         branch_move: bool,
-        revision_map: Optional["RevisionMap"] = None,
+        revision_map: Optional[RevisionMap] = None,
     ) -> None:
         self.from_: Tuple[str, ...] = util.to_tuple(from_, default=())
         self.to_: Tuple[str, ...] = util.to_tuple(to_, default=())
@@ -1368,7 +1366,7 @@ class StampStep(MigrationStep):
         return len(self.to_) > 1
 
     @property
-    def info(self) -> "MigrationInfo":
+    def info(self) -> MigrationInfo:
         up, down = (
             (self.to_, self.from_)
             if self.is_upgrade
index cae0a2bcedbc6435ebadfb46e32e2294906a56f2..3c09cef7d165dcca4c8fdeabc4f42a718f2880ea 100644 (file)
@@ -463,7 +463,7 @@ class ScriptDirectory:
 
     def _stamp_revs(
         self, revision: _RevIdType, heads: _RevIdType
-    ) -> List["StampStep"]:
+    ) -> List[StampStep]:
         with self._catch_revision_errors(
             multiple_heads="Multiple heads are present; please specify a "
             "single target revision"
@@ -592,7 +592,7 @@ class ScriptDirectory:
         if not os.path.exists(path):
             util.status("Creating directory %s" % path, os.makedirs, path)
 
-    def _generate_create_date(self) -> "datetime.datetime":
+    def _generate_create_date(self) -> datetime.datetime:
         if self.timezone is not None:
             if tz is None:
                 raise util.CommandError(
@@ -769,7 +769,7 @@ class ScriptDirectory:
         path: str,
         rev_id: str,
         message: Optional[str],
-        create_date: "datetime.datetime",
+        create_date: datetime.datetime,
     ) -> str:
         epoch = int(create_date.timestamp())
         slug = "_".join(_slug_re.findall(message or "")).lower()
@@ -804,7 +804,7 @@ class Script(revision.Revision):
     def __init__(self, module: ModuleType, rev_id: str, path: str):
         self.module = module
         self.path = path
-        super(Script, self).__init__(
+        super().__init__(
             rev_id,
             module.down_revision,  # type: ignore[attr-defined]
             branch_labels=util.to_tuple(
@@ -964,7 +964,7 @@ class Script(revision.Revision):
             # in the immediate path
             paths = os.listdir(path)
 
-            names = set(fname.split(".")[0] for fname in paths)
+            names = {fname.split(".")[0] for fname in paths}
 
             # look for __pycache__
             if os.path.exists(os.path.join(path, "__pycache__")):
index 6e25891d47c0a3580669c1f8691688d9b58cebf5..39152969f067bb92acc174bbf758140db296ef89 100644 (file)
@@ -51,7 +51,7 @@ class RangeNotAncestorError(RevisionError):
     ) -> None:
         self.lower = lower
         self.upper = upper
-        super(RangeNotAncestorError, self).__init__(
+        super().__init__(
             "Revision %s is not an ancestor of revision %s"
             % (lower or "base", upper or "base")
         )
@@ -61,7 +61,7 @@ class MultipleHeads(RevisionError):
     def __init__(self, heads: Sequence[str], argument: Optional[str]) -> None:
         self.heads = heads
         self.argument = argument
-        super(MultipleHeads, self).__init__(
+        super().__init__(
             "Multiple heads are present for given argument '%s'; "
             "%s" % (argument, ", ".join(heads))
         )
@@ -69,7 +69,7 @@ class MultipleHeads(RevisionError):
 
 class ResolutionError(RevisionError):
     def __init__(self, message: str, argument: str) -> None:
-        super(ResolutionError, self).__init__(message)
+        super().__init__(message)
         self.argument = argument
 
 
@@ -78,7 +78,7 @@ class CycleDetected(RevisionError):
 
     def __init__(self, revisions: Sequence[str]) -> None:
         self.revisions = revisions
-        super(CycleDetected, self).__init__(
+        super().__init__(
             "%s is detected in revisions (%s)"
             % (self.kind, ", ".join(revisions))
         )
@@ -88,21 +88,21 @@ class DependencyCycleDetected(CycleDetected):
     kind = "Dependency cycle"
 
     def __init__(self, revisions: Sequence[str]) -> None:
-        super(DependencyCycleDetected, self).__init__(revisions)
+        super().__init__(revisions)
 
 
 class LoopDetected(CycleDetected):
     kind = "Self-loop"
 
     def __init__(self, revision: str) -> None:
-        super(LoopDetected, self).__init__([revision])
+        super().__init__([revision])
 
 
 class DependencyLoopDetected(DependencyCycleDetected, LoopDetected):
     kind = "Dependency self-loop"
 
     def __init__(self, revision: Sequence[str]) -> None:
-        super(DependencyLoopDetected, self).__init__(revision)
+        super().__init__(revision)
 
 
 class RevisionMap:
@@ -114,7 +114,7 @@ class RevisionMap:
 
     """
 
-    def __init__(self, generator: Callable[[], Iterable["Revision"]]) -> None:
+    def __init__(self, generator: Callable[[], Iterable[Revision]]) -> None:
         """Construct a new :class:`.RevisionMap`.
 
         :param generator: a zero-arg callable that will generate an iterable
@@ -180,10 +180,10 @@ class RevisionMap:
         # general)
         map_: _InterimRevisionMapType = sqlautil.OrderedDict()
 
-        heads: Set["Revision"] = sqlautil.OrderedSet()
-        _real_heads: Set["Revision"] = sqlautil.OrderedSet()
-        bases: Tuple["Revision", ...] = ()
-        _real_bases: Tuple["Revision", ...] = ()
+        heads: Set[Revision] = sqlautil.OrderedSet()
+        _real_heads: Set[Revision] = sqlautil.OrderedSet()
+        bases: Tuple[Revision, ...] = ()
+        _real_bases: Tuple[Revision, ...] = ()
 
         has_branch_labels = set()
         all_revisions = set()
@@ -249,10 +249,10 @@ class RevisionMap:
     def _detect_cycles(
         self,
         rev_map: _InterimRevisionMapType,
-        heads: Set["Revision"],
-        bases: Tuple["Revision", ...],
-        _real_heads: Set["Revision"],
-        _real_bases: Tuple["Revision", ...],
+        heads: Set[Revision],
+        bases: Tuple[Revision, ...],
+        _real_heads: Set[Revision],
+        _real_bases: Tuple[Revision, ...],
     ) -> None:
         if not rev_map:
             return
@@ -299,7 +299,7 @@ class RevisionMap:
             raise DependencyCycleDetected(sorted(deleted_revs))
 
     def _map_branch_labels(
-        self, revisions: Collection["Revision"], map_: _RevisionMapType
+        self, revisions: Collection[Revision], map_: _RevisionMapType
     ) -> None:
         for revision in revisions:
             if revision.branch_labels:
@@ -320,7 +320,7 @@ class RevisionMap:
                     map_[branch_label] = revision
 
     def _add_branches(
-        self, revisions: Collection["Revision"], map_: _RevisionMapType
+        self, revisions: Collection[Revision], map_: _RevisionMapType
     ) -> None:
         for revision in revisions:
             if revision.branch_labels:
@@ -344,7 +344,7 @@ class RevisionMap:
                         break
 
     def _add_depends_on(
-        self, revisions: Collection["Revision"], map_: _RevisionMapType
+        self, revisions: Collection[Revision], map_: _RevisionMapType
     ) -> None:
         """Resolve the 'dependencies' for each revision in a collection
         in terms of actual revision ids, as opposed to branch labels or other
@@ -367,7 +367,7 @@ class RevisionMap:
                 revision._resolved_dependencies = ()
 
     def _normalize_depends_on(
-        self, revisions: Collection["Revision"], map_: _RevisionMapType
+        self, revisions: Collection[Revision], map_: _RevisionMapType
     ) -> None:
         """Create a collection of "dependencies" that omits dependencies
         that are already ancestor nodes for each revision in a given
@@ -406,9 +406,7 @@ class RevisionMap:
             else:
                 revision._normalized_resolved_dependencies = ()
 
-    def add_revision(
-        self, revision: "Revision", _replace: bool = False
-    ) -> None:
+    def add_revision(self, revision: Revision, _replace: bool = False) -> None:
         """add a single revision to an existing map.
 
         This method is for single-revision use cases, it's not
@@ -602,7 +600,7 @@ class RevisionMap:
         else:
             branch_rev = None
 
-        revision: Union[Optional[Revision], "Literal[False]"]
+        revision: Union[Optional[Revision], Literal[False]]
         try:
             revision = self._revision_map[resolved_id]
         except KeyError:
index 13d29ff9528160b32f08468c20329eb5ce2c1e13..3d42f1cb4b885947755dbfeb62899e759a392dbf 100644 (file)
@@ -1,4 +1,3 @@
-#!coding: utf-8
 import importlib.machinery
 import os
 import shutil
index 26427507eddeeb9209082fd33809f891a8d508c8..ef1c3bbaf6aa8fd7c7743e88cd7232bd3aeb1ec5 100644 (file)
@@ -1,4 +1,3 @@
-# coding: utf-8
 from __future__ import annotations
 
 import configparser
index f97dd753204cb1d06c30b33caa70948aad449146..e09fbfe58d0b525e695b5b787f3a1bf6cd61bddd 100644 (file)
@@ -208,8 +208,7 @@ class AutogenTest(_ComparesFKs):
     def _flatten_diffs(self, diffs):
         for d in diffs:
             if isinstance(d, list):
-                for fd in self._flatten_diffs(d):
-                    yield fd
+                yield from self._flatten_diffs(d)
             else:
                 yield d
 
index d809dfe2d614ba423d8a3b891b3c6f5bc77f2d02..86d45a0dd558e2696ae34d277418eb43928cb8a7 100644 (file)
@@ -5,7 +5,6 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from __future__ import absolute_import
 
 import warnings
 
index ff2687ce8abdd3aca0297bd2f59c88c78235bc3e..8203358e94f19fc1277260c7295fe9a73cdd285f 100644 (file)
@@ -30,7 +30,7 @@ _T = TypeVar("_T")
 
 class _ModuleClsMeta(type):
     def __setattr__(cls, key: str, value: Callable) -> None:
-        super(_ModuleClsMeta, cls).__setattr__(key, value)
+        super().__setattr__(key, value)
         cls._update_module_proxies(key)  # type: ignore
 
 
@@ -270,7 +270,7 @@ class Dispatcher:
         else:
             return fn_or_list  # type: ignore
 
-    def branch(self) -> "Dispatcher":
+    def branch(self) -> Dispatcher:
         """Return a copy of this dispatcher that is independently
         writable."""
 
index 54dc04fd2aa11b35027257632882c461c1f75e5f..7d9d090a774d4cc090d93c44415fbeb4d8e8523c 100644 (file)
@@ -30,7 +30,7 @@ try:
     _h, TERMWIDTH, _hp, _wp = struct.unpack("HHHH", ioctl)
     if TERMWIDTH <= 0:  # can occur if running in emacs pseudo-tty
         TERMWIDTH = None
-except (ImportError, IOError):
+except (ImportError, OSError):
     TERMWIDTH = None
 
 
@@ -42,7 +42,7 @@ def write_outstream(stream: TextIO, *text) -> None:
         t = t.decode(encoding)
         try:
             stream.write(t)
-        except IOError:
+        except OSError:
             # suppress "broken pipe" errors.
             # no known way to handle this on Python 3 however
             # as the exception is "ignored" (noisily) in TextIOWrapper.
@@ -92,7 +92,7 @@ def msg(msg: str, newline: bool = True, flush: bool = False) -> None:
         sys.stdout.flush()
 
 
-def format_as_comma(value: Optional[Union[str, "Iterable[str]"]]) -> str:
+def format_as_comma(value: Optional[Union[str, Iterable[str]]]) -> str:
     if value is None:
         return ""
     elif isinstance(value, str):
index 8046c9c46e45cc5591cf63cf57c0ada2b702fd93..23255be3bcbfc8d9627461ac9ec0f59ee44de9ca 100644 (file)
@@ -108,7 +108,7 @@ AUTOINCREMENT_DEFAULT = "auto"
 
 @contextlib.contextmanager
 def _ensure_scope_for_ddl(
-    connection: Optional["Connection"],
+    connection: Optional[Connection],
 ) -> Iterator[None]:
     try:
         in_transaction = connection.in_transaction  # type: ignore[union-attr]
@@ -137,8 +137,8 @@ def url_render_as_string(url, hide_password=True):
 
 
 def _safe_begin_connection_transaction(
-    connection: "Connection",
-) -> "Transaction":
+    connection: Connection,
+) -> Transaction:
     transaction = _get_connection_transaction(connection)
     if transaction:
         return transaction
@@ -147,7 +147,7 @@ def _safe_begin_connection_transaction(
 
 
 def _safe_commit_connection_transaction(
-    connection: "Connection",
+    connection: Connection,
 ) -> None:
     transaction = _get_connection_transaction(connection)
     if transaction:
@@ -155,14 +155,14 @@ def _safe_commit_connection_transaction(
 
 
 def _safe_rollback_connection_transaction(
-    connection: "Connection",
+    connection: Connection,
 ) -> None:
     transaction = _get_connection_transaction(connection)
     if transaction:
         transaction.rollback()
 
 
-def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool:
+def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
     try:
         in_transaction = connection.in_transaction  # type: ignore
     except AttributeError:
@@ -184,8 +184,8 @@ def _copy(schema_item: _CE, **kw) -> _CE:
 
 
 def _get_connection_transaction(
-    connection: "Connection",
-) -> Optional["Transaction"]:
+    connection: Connection,
+) -> Optional[Transaction]:
     if sqla_14:
         return connection.get_transaction()
     else:
@@ -201,7 +201,7 @@ def _create_url(*arg, **kw) -> url.URL:
 
 
 def _connectable_has_table(
-    connectable: "Connection", tablename: str, schemaname: Union[str, None]
+    connectable: Connection, tablename: str, schemaname: Union[str, None]
 ) -> bool:
     if sqla_14:
         return inspect(connectable).has_table(tablename, schemaname)
@@ -244,7 +244,7 @@ def _server_default_is_identity(*server_default) -> bool:
         return any(isinstance(sd, Identity) for sd in server_default)
 
 
-def _table_for_constraint(constraint: "Constraint") -> "Table":
+def _table_for_constraint(constraint: Constraint) -> Table:
     if isinstance(constraint, ForeignKeyConstraint):
         table = constraint.parent
         assert table is not None
@@ -263,7 +263,7 @@ def _columns_for_constraint(constraint):
 
 
 def _reflect_table(
-    inspector: "Inspector", table: "Table", include_cols: None
+    inspector: Inspector, table: Table, include_cols: None
 ) -> None:
     if sqla_14:
         return inspector.reflect_table(table, None)
@@ -326,7 +326,7 @@ def _fk_spec(constraint):
     )
 
 
-def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool:
+def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
     spec = constraint.elements[0]._get_colspec()  # type: ignore[attr-defined]
     tokens = spec.split(".")
     tokens.pop(-1)  # colname
@@ -335,7 +335,7 @@ def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool:
     return tablekey == constraint.parent.key
 
 
-def _is_type_bound(constraint: "Constraint") -> bool:
+def _is_type_bound(constraint: Constraint) -> bool:
     # this deals with SQLAlchemy #3260, don't copy CHECK constraints
     # that will be generated by the type.
     # new feature added for #3260
@@ -351,7 +351,7 @@ def _find_columns(clause):
 
 
 def _remove_column_from_collection(
-    collection: "ColumnCollection", column: Union["Column", "ColumnClause"]
+    collection: ColumnCollection, column: Union[Column, ColumnClause]
 ) -> None:
     """remove a column from a ColumnCollection."""
 
@@ -369,8 +369,8 @@ def _remove_column_from_collection(
 
 
 def _textual_index_column(
-    table: "Table", text_: Union[str, "TextClause", "ColumnElement"]
-) -> Union["ColumnElement", "Column"]:
+    table: Table, text_: Union[str, TextClause, ColumnElement]
+) -> Union[ColumnElement, Column]:
     """a workaround for the Index construct's severe lack of flexibility"""
     if isinstance(text_, str):
         c = Column(text_, sqltypes.NULLTYPE)
@@ -384,7 +384,7 @@ def _textual_index_column(
         raise ValueError("String or text() construct expected")
 
 
-def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
+def _copy_expression(expression: _CE, target_table: Table) -> _CE:
     def replace(col):
         if (
             isinstance(col, Column)
@@ -423,7 +423,7 @@ class _textual_index_element(sql.ColumnElement):
 
     __visit_name__ = "_textual_idx_element"
 
-    def __init__(self, table: "Table", text: "TextClause") -> None:
+    def __init__(self, table: Table, text: TextClause) -> None:
         self.table = table
         self.text = text
         self.key = text.text
@@ -436,7 +436,7 @@ class _textual_index_element(sql.ColumnElement):
 
 @compiles(_textual_index_element)
 def _render_textual_index_column(
-    element: _textual_index_element, compiler: "SQLCompiler", **kw
+    element: _textual_index_element, compiler: SQLCompiler, **kw
 ) -> str:
     return compiler.process(element.text, **kw)
 
@@ -447,7 +447,7 @@ class _literal_bindparam(BindParameter):
 
 @compiles(_literal_bindparam)
 def _render_literal_bindparam(
-    element: _literal_bindparam, compiler: "SQLCompiler", **kw
+    element: _literal_bindparam, compiler: SQLCompiler, **kw
 ) -> str:
     return compiler.render_literal_bindparam(element, **kw)
 
@@ -460,7 +460,7 @@ def _get_index_column_names(idx):
     return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
 
 
-def _column_kwargs(col: "Column") -> Mapping:
+def _column_kwargs(col: Column) -> Mapping:
     if sqla_13:
         return col.kwargs
     else:
@@ -468,7 +468,7 @@ def _column_kwargs(col: "Column") -> Mapping:
 
 
 def _get_constraint_final_name(
-    constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"]
+    constraint: Union[Index, Constraint], dialect: Optional[Dialect]
 ) -> Optional[str]:
     if constraint.name is None:
         return None
@@ -508,7 +508,7 @@ def _get_constraint_final_name(
 
 
 def _constraint_is_named(
-    constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"]
+    constraint: Union[Constraint, Index], dialect: Optional[Dialect]
 ) -> bool:
     if sqla_14:
         if constraint.name is None:
@@ -522,7 +522,7 @@ def _constraint_is_named(
         return constraint.name is not None
 
 
-def _is_mariadb(mysql_dialect: "Dialect") -> bool:
+def _is_mariadb(mysql_dialect: Dialect) -> bool:
     if sqla_14:
         return mysql_dialect.is_mariadb  # type: ignore[attr-defined]
     else:
@@ -536,7 +536,7 @@ def _mariadb_normalized_version_info(mysql_dialect):
     return mysql_dialect._mariadb_normalized_version_info
 
 
-def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
+def _insert_inline(table: Union[TableClause, Table]) -> Insert:
     if sqla_14:
         return table.insert().inline()
     else:
@@ -554,5 +554,5 @@ else:
             "postgresql://", strategy="mock", executor=executor
         )
 
-    def _select(*columns, **kw) -> "Select":  # type: ignore[no-redef]
+    def _select(*columns, **kw) -> Select:  # type: ignore[no-redef]
         return sql.select(list(columns), **kw)  # type: ignore[call-overload]
index acd3603b2afbbcd9b6b4092978c436572049c21c..99e5486ffdda1b604aed2b508cdf4163f912658b 100644 (file)
@@ -243,12 +243,10 @@ nullable=True))
         autogenerate._render_migration_diffs(self.context, template_args)
         eq_(
             set(template_args["imports"].split("\n")),
-            set(
-                [
-                    "from foobar import bat",
-                    "from mypackage import my_special_import",
-                ]
-            ),
+            {
+                "from foobar import bat",
+                "from mypackage import my_special_import",
+            },
         )
 
 
index ead1a7cd9dfffedda02f6b381b8cc75a969fa7d3..86b2460ca23773befb3c0c20df2f643b305a7d50 100644 (file)
@@ -289,7 +289,7 @@ class AutogenDefaultSchemaIsNoneTest(AutogenFixtureTest, TestBase):
     __only_on__ = "sqlite"
 
     def setUp(self):
-        super(AutogenDefaultSchemaIsNoneTest, self).setUp()
+        super().setUp()
 
         # in SQLAlchemy 1.4, SQLite dialect is setting this name
         # to "main" as is the actual default schema name for SQLite.
@@ -512,13 +512,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         )
 
         alter_cols = (
-            set(
-                [
-                    d[2]
-                    for d in self._flatten_diffs(diffs)
-                    if d[0].startswith("modify")
-                ]
-            )
+            {
+                d[2]
+                for d in self._flatten_diffs(diffs)
+                if d[0].startswith("modify")
+            }
             .union(
                 d[3].name
                 for d in self._flatten_diffs(diffs)
@@ -530,7 +528,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
                 if d[0] == "add_table"
             )
         )
-        eq_(alter_cols, set(["user_id", "order", "user"]))
+        eq_(alter_cols, {"user_id", "order", "user"})
 
     def test_include_name(self):
         all_names = set()
@@ -582,13 +580,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         )
 
         alter_cols = (
-            set(
-                [
-                    d[2]
-                    for d in self._flatten_diffs(diffs)
-                    if d[0].startswith("modify")
-                ]
-            )
+            {
+                d[2]
+                for d in self._flatten_diffs(diffs)
+                if d[0].startswith("modify")
+            }
             .union(
                 d[3].name
                 for d in self._flatten_diffs(diffs)
index fb710991645c780941706605a63d7e8d19b1b027..68a6bd6f2a87bbc2a24d6ea65f15fbcf75a163bb 100644 (file)
@@ -552,7 +552,7 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
 
         diffs = self._fixture(m1, m2)
 
-        diffs = set(
+        diffs = {
             (
                 cmd,
                 isinstance(obj, (UniqueConstraint, Index))
@@ -560,23 +560,21 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
                 else False,
             )
             for cmd, obj in diffs
-        )
+        }
 
         if self.reports_unnamed_constraints:
             if self.reports_unique_constraints_as_indexes:
                 eq_(
                     diffs,
-                    set([("remove_index", True), ("add_constraint", False)]),
+                    {("remove_index", True), ("add_constraint", False)},
                 )
             else:
                 eq_(
                     diffs,
-                    set(
-                        [
-                            ("remove_constraint", True),
-                            ("add_constraint", False),
-                        ]
-                    ),
+                    {
+                        ("remove_constraint", True),
+                        ("add_constraint", False),
+                    },
                 )
 
     def test_remove_named_unique_index(self):
@@ -594,8 +592,8 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         diffs = self._fixture(m1, m2)
 
         if self.reports_unique_constraints:
-            diffs = set((cmd, obj.name) for cmd, obj in diffs)
-            eq_(diffs, set([("remove_index", "xidx")]))
+            diffs = {(cmd, obj.name) for cmd, obj in diffs}
+            eq_(diffs, {("remove_index", "xidx")})
         else:
             eq_(diffs, [])
 
@@ -614,11 +612,11 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         diffs = self._fixture(m1, m2)
 
         if self.reports_unique_constraints:
-            diffs = set((cmd, obj.name) for cmd, obj in diffs)
+            diffs = {(cmd, obj.name) for cmd, obj in diffs}
             if self.reports_unique_constraints_as_indexes:
-                eq_(diffs, set([("remove_index", "xidx")]))
+                eq_(diffs, {("remove_index", "xidx")})
             else:
-                eq_(diffs, set([("remove_constraint", "xidx")]))
+                eq_(diffs, {("remove_constraint", "xidx")})
         else:
             eq_(diffs, [])
 
@@ -668,9 +666,9 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
 
         eq_(diffs[0][0], "add_table")
         eq_(len(diffs), 2)
-        assert UniqueConstraint not in set(
+        assert UniqueConstraint not in {
             type(c) for c in diffs[0][1].constraints
-        )
+        }
 
         eq_(diffs[1][0], "add_index")
         d_table = diffs[0][1]
@@ -1071,9 +1069,7 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase):
         eq_(diffs[1][0], "remove_index")
         eq_(diffs[2][0], "remove_table")
 
-        eq_(
-            set([diffs[0][1].name, diffs[1][1].name]), set(["xy_idx", "y_idx"])
-        )
+        eq_({diffs[0][1].name, diffs[1][1].name}, {"xy_idx", "y_idx"})
 
     def test_add_ix_on_table_create(self):
         m1 = MetaData()
@@ -1083,9 +1079,9 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase):
 
         eq_(diffs[0][0], "add_table")
         eq_(len(diffs), 2)
-        assert UniqueConstraint not in set(
+        assert UniqueConstraint not in {
             type(c) for c in diffs[0][1].constraints
-        )
+        }
         eq_(diffs[1][0], "add_index")
         eq_(diffs[1][1].unique, False)
 
index 67093284d75b0f18a2815466a441ee58bebda83d..0a2fc87662541058cf0849b3b7a31beee79bbec2 100644 (file)
@@ -1296,7 +1296,7 @@ class AutogenRenderTest(TestBase):
         )
         eq_(
             self.autogen_context.imports,
-            set(["from mypackage import MySpecialType"]),
+            {"from mypackage import MySpecialType"},
         )
 
     def test_render_modify_type(self):
@@ -1833,7 +1833,7 @@ class AutogenRenderTest(TestBase):
         )
         eq_(
             self.autogen_context.imports,
-            set(["from sqlalchemy.dialects import mysql"]),
+            {"from sqlalchemy.dialects import mysql"},
         )
 
     def test_render_server_default_text(self):
index 2d29f6c6c87d88883ba0bcb45bd354aef21bccd6..e0289aa40cac7bb4774f5194d365ef80349f1a92 100644 (file)
@@ -1553,11 +1553,11 @@ class BatchRoundTripTest(TestBase):
 
         insp = inspect(self.conn)
         eq_(
-            set(
+            {
                 (ix["name"], tuple(ix["column_names"]))
                 for ix in insp.get_indexes("t_w_ix")
-            ),
-            set([("ix_data", ("data",)), ("ix_thing", ("thing",))]),
+            },
+            {("ix_data", ("data",)), ("ix_thing", ("thing",))},
         )
 
     def test_fk_points_to_me_auto(self):
@@ -2268,39 +2268,37 @@ class BatchRoundTripMySQLTest(BatchRoundTripTest):
 
     @exclusions.fails()
     def test_drop_pk_col_readd_pk_col(self):
-        super(BatchRoundTripMySQLTest, self).test_drop_pk_col_readd_pk_col()
+        super().test_drop_pk_col_readd_pk_col()
 
     @exclusions.fails()
     def test_drop_pk_col_readd_col_also_pk_const(self):
-        super(
-            BatchRoundTripMySQLTest, self
-        ).test_drop_pk_col_readd_col_also_pk_const()
+        super().test_drop_pk_col_readd_col_also_pk_const()
 
     @exclusions.fails()
     def test_rename_column_pk(self):
-        super(BatchRoundTripMySQLTest, self).test_rename_column_pk()
+        super().test_rename_column_pk()
 
     @exclusions.fails()
     def test_rename_column(self):
-        super(BatchRoundTripMySQLTest, self).test_rename_column()
+        super().test_rename_column()
 
     @exclusions.fails()
     def test_change_type(self):
-        super(BatchRoundTripMySQLTest, self).test_change_type()
+        super().test_change_type()
 
     def test_create_drop_index(self):
-        super(BatchRoundTripMySQLTest, self).test_create_drop_index()
+        super().test_create_drop_index()
 
     # fails on mariadb 10.2, succeeds on 10.3
     @exclusions.fails_if(config.requirements.mysql_check_col_name_change)
     def test_rename_column_boolean(self):
-        super(BatchRoundTripMySQLTest, self).test_rename_column_boolean()
+        super().test_rename_column_boolean()
 
     def test_change_type_boolean_to_int(self):
-        super(BatchRoundTripMySQLTest, self).test_change_type_boolean_to_int()
+        super().test_change_type_boolean_to_int()
 
     def test_change_type_int_to_boolean(self):
-        super(BatchRoundTripMySQLTest, self).test_change_type_int_to_boolean()
+        super().test_change_type_int_to_boolean()
 
 
 class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
@@ -2327,34 +2325,26 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
 
     @exclusions.fails()
     def test_drop_pk_col_readd_pk_col(self):
-        super(
-            BatchRoundTripPostgresqlTest, self
-        ).test_drop_pk_col_readd_pk_col()
+        super().test_drop_pk_col_readd_pk_col()
 
     @exclusions.fails()
     def test_drop_pk_col_readd_col_also_pk_const(self):
-        super(
-            BatchRoundTripPostgresqlTest, self
-        ).test_drop_pk_col_readd_col_also_pk_const()
+        super().test_drop_pk_col_readd_col_also_pk_const()
 
     @exclusions.fails()
     def test_change_type(self):
-        super(BatchRoundTripPostgresqlTest, self).test_change_type()
+        super().test_change_type()
 
     def test_create_drop_index(self):
-        super(BatchRoundTripPostgresqlTest, self).test_create_drop_index()
+        super().test_create_drop_index()
 
     @exclusions.fails()
     def test_change_type_int_to_boolean(self):
-        super(
-            BatchRoundTripPostgresqlTest, self
-        ).test_change_type_int_to_boolean()
+        super().test_change_type_int_to_boolean()
 
     @exclusions.fails()
     def test_change_type_boolean_to_int(self):
-        super(
-            BatchRoundTripPostgresqlTest, self
-        ).test_change_type_boolean_to_int()
+        super().test_change_type_boolean_to_int()
 
     def test_add_col_table_has_native_boolean(self):
         self._native_boolean_fixture()
index 0c0ce378f3a12c99a2d4063d27a102cf503371c6..e136c4e706b39dc45df8ad7e00cb2d71bafd8361 100644 (file)
@@ -224,15 +224,13 @@ class CurrentTest(_BufMixin, TestBase):
 
         yield
 
-        lines = set(
-            [
-                re.match(r"(^.\w)", elem).group(1)
-                for elem in re.split(
-                    "\n", buf.getvalue().decode("ascii", "replace").strip()
-                )
-                if elem
-            ]
-        )
+        lines = {
+            re.match(r"(^.\w)", elem).group(1)
+            for elem in re.split(
+                "\n", buf.getvalue().decode("ascii", "replace").strip()
+            )
+            if elem
+        }
 
         eq_(lines, set(revs))
 
index 7957a1b74ccbc6188d648537c8c4be33e809b019..9f3929a78abc39e7fe662cdec1b4bcf71597c5a1 100644 (file)
@@ -1,4 +1,3 @@
-#!coding: utf-8
 import os
 import tempfile
 
index d6c3a65d77d5b04ddb12756881feeca9e7bfeeed..d9c14ca45cf19cbc6ed47ef700c750ae5c60bce8 100644 (file)
@@ -1,4 +1,3 @@
-#!coding: utf-8
 import os
 import sys
 
index 9ddc12f07166046f28f77908da651bba020e7f37..de66517e5902aebdeebc1710dd717af22573d33d 100644 (file)
@@ -65,7 +65,7 @@ class EXT_ARRAY(sqla_types.TypeEngine):
         if isinstance(item_type, type):
             item_type = item_type()
         self.item_type = item_type
-        super(EXT_ARRAY, self).__init__()
+        super().__init__()
 
 
 class FOOBARTYPE(sqla_types.TypeEngine):
@@ -94,12 +94,10 @@ class ExternalDialectRenderTest(TestBase):
 
         eq_(
             self.autogen_context.imports,
-            set(
-                [
-                    "from tests.test_external_dialect "
-                    "import custom_dialect_types"
-                ]
-            ),
+            {
+                "from tests.test_external_dialect "
+                "import custom_dialect_types"
+            },
         )
 
     def test_external_nested_render_sqla_type(self):
@@ -121,12 +119,10 @@ class ExternalDialectRenderTest(TestBase):
 
         eq_(
             self.autogen_context.imports,
-            set(
-                [
-                    "from tests.test_external_dialect "
-                    "import custom_dialect_types"
-                ]
-            ),
+            {
+                "from tests.test_external_dialect "
+                "import custom_dialect_types"
+            },
         )
 
     def test_external_nested_render_external_type(self):
@@ -141,10 +137,8 @@ class ExternalDialectRenderTest(TestBase):
 
         eq_(
             self.autogen_context.imports,
-            set(
-                [
-                    "from tests.test_external_dialect "
-                    "import custom_dialect_types"
-                ]
-            ),
+            {
+                "from tests.test_external_dialect "
+                "import custom_dialect_types"
+            },
         )
index b9be5cb381798e3a8bcba5abe7a8f7cfba19b008..6a67e0be867bf9b93ba20639bab54ff326b4fc1d 100644 (file)
@@ -838,9 +838,7 @@ class PostgresqlDetectSerialTest(TestBase):
         insp = inspect(config.db)
 
         uo = ops.UpgradeOps(ops=[])
-        _compare_tables(
-            set([(None, "t")]), set([]), insp, uo, self.autogen_context
-        )
+        _compare_tables({(None, "t")}, set(), insp, uo, self.autogen_context)
         diffs = uo.as_diffs()
         tab = diffs[0][1]
 
@@ -857,8 +855,8 @@ class PostgresqlDetectSerialTest(TestBase):
         Table("t", m2, Column("x", BigInteger()))
         self.autogen_context.metadata = m2
         _compare_tables(
-            set([(None, "t")]),
-            set([(None, "t")]),
+            {(None, "t")},
+            {(None, "t")},
             insp,
             uo,
             self.autogen_context,
index d478ae1b3a5413c36ff12f08edf9ac69ef91786d..fa84d7e3789f29f77b22ba74ef4527ba1d9f9959 100644 (file)
@@ -1,5 +1,3 @@
-# coding: utf-8
-
 from contextlib import contextmanager
 import os
 import re
@@ -369,7 +367,7 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
         alembic.mock_event_listener = None
         self._env_file_fixture()
         with mock.patch("alembic.mock_event_listener", mock.Mock()) as mymock:
-            super(CallbackEnvironmentTest, self).test_steps()
+            super().test_steps()
         calls = mymock.call_args_list
         assert calls
         for call in calls:
@@ -682,7 +680,7 @@ def downgrade():
             bytes_io=True, output_encoding="utf-8"
         ) as buf:
             command.upgrade(self.cfg, self.a, sql=True)
-        assert "« S’il vous plaît…".encode("utf-8") in buf.getvalue()
+        assert "« S’il vous plaît…".encode() in buf.getvalue()
 
 
 class VersionNameTemplateTest(TestBase):
index 2cf9052ada43df8ae182b3d86d0039ca7a6457da..bedf545d9d2f9b41c865ccbf709164a8e60cf8ef 100644 (file)
@@ -1,6 +1,7 @@
 import datetime
 import os
 import re
+from unittest.mock import patch
 
 from dateutil import tz
 import sqlalchemy as sa
@@ -36,10 +37,6 @@ from alembic.testing.env import write_script
 from alembic.testing.fixtures import TestBase
 from alembic.util import CommandError
 
-try:
-    from unittest.mock import patch
-except ImportError:
-    from mock import patch  # noqa
 env, abc, def_ = None, None, None
 
 
@@ -62,7 +59,7 @@ class GeneralOrderedTests(TestBase):
         self._test_008_long_name_configurable()
 
     def _test_001_environment(self):
-        assert_set = set(["env.py", "script.py.mako", "README"])
+        assert_set = {"env.py", "script.py.mako", "README"}
         eq_(assert_set.intersection(os.listdir(env.dir)), assert_set)
 
     def _test_002_rev_ids(self):
@@ -101,7 +98,7 @@ class GeneralOrderedTests(TestBase):
         )
         eq_(script.revision, def_)
         eq_(script.down_revision, abc)
-        eq_(env.get_revision(abc).nextrev, set([def_]))
+        eq_(env.get_revision(abc).nextrev, {def_})
         assert script.module.down_revision == abc
         assert callable(script.module.upgrade)
         assert callable(script.module.downgrade)
@@ -115,7 +112,7 @@ class GeneralOrderedTests(TestBase):
         env = staging_env(create=False)
         abc_rev = env.get_revision(abc)
         def_rev = env.get_revision(def_)
-        eq_(abc_rev.nextrev, set([def_]))
+        eq_(abc_rev.nextrev, {def_})
         eq_(abc_rev.revision, abc)
         eq_(def_rev.down_revision, abc)
         eq_(env.get_heads(), [def_])
@@ -319,7 +316,7 @@ class RevisionCommandTest(TestBase):
         rev = script.get_revision(rev.revision)
         eq_(rev.down_revision, self.b)
         assert "some message" in rev.doc
-        eq_(set(script.get_heads()), set([rev.revision, self.c]))
+        eq_(set(script.get_heads()), {rev.revision, self.c})
 
     def test_create_script_missing_splice(self):
         assert_raises_message(
index 92413ac0c1b494fd2c94443f61ece960af44b064..f7ad4f08a72d4bc6828433803e74ff8197297561 100644 (file)
@@ -75,14 +75,14 @@ class RevisionPathTest(MigrationTest):
             self.e.revision,
             self.c.revision,
             [self.up_(self.d), self.up_(self.e)],
-            set([self.e.revision]),
+            {self.e.revision},
         )
 
         self._assert_upgrade(
             self.c.revision,
             None,
             [self.up_(self.a), self.up_(self.b), self.up_(self.c)],
-            set([self.c.revision]),
+            {self.c.revision},
         )
 
     def test_relative_upgrade_path(self):
@@ -90,32 +90,32 @@ class RevisionPathTest(MigrationTest):
             "+2",
             self.a.revision,
             [self.up_(self.b), self.up_(self.c)],
-            set([self.c.revision]),
+            {self.c.revision},
         )
 
         self._assert_upgrade(
-            "+1", self.a.revision, [self.up_(self.b)], set([self.b.revision])
+            "+1", self.a.revision, [self.up_(self.b)], {self.b.revision}
         )
 
         self._assert_upgrade(
             "+3",
             self.b.revision,
             [self.up_(self.c), self.up_(self.d), self.up_(self.e)],
-            set([self.e.revision]),
+            {self.e.revision},
         )
 
         self._assert_upgrade(
             "%s+2" % self.b.revision,
             self.a.revision,
             [self.up_(self.b), self.up_(self.c), self.up_(self.d)],
-            set([self.d.revision]),
+            {self.d.revision},
         )
 
         self._assert_upgrade(
             "%s-2" % self.d.revision,
             self.a.revision,
             [self.up_(self.b)],
-            set([self.b.revision]),
+            {self.b.revision},
         )
 
     def test_invalid_relative_upgrade_path(self):
@@ -142,7 +142,7 @@ class RevisionPathTest(MigrationTest):
             self.c.revision,
             self.e.revision,
             [self.down_(self.e), self.down_(self.d)],
-            set([self.c.revision]),
+            {self.c.revision},
         )
 
         self._assert_downgrade(
@@ -155,28 +155,28 @@ class RevisionPathTest(MigrationTest):
     def test_relative_downgrade_path(self):
 
         self._assert_downgrade(
-            "-1", self.c.revision, [self.down_(self.c)], set([self.b.revision])
+            "-1", self.c.revision, [self.down_(self.c)], {self.b.revision}
         )
 
         self._assert_downgrade(
             "-3",
             self.e.revision,
             [self.down_(self.e), self.down_(self.d), self.down_(self.c)],
-            set([self.b.revision]),
+            {self.b.revision},
         )
 
         self._assert_downgrade(
             "%s+2" % self.a.revision,
             self.d.revision,
             [self.down_(self.d)],
-            set([self.c.revision]),
+            {self.c.revision},
         )
 
         self._assert_downgrade(
             "%s-2" % self.c.revision,
             self.d.revision,
             [self.down_(self.d), self.down_(self.c), self.down_(self.b)],
-            set([self.a.revision]),
+            {self.a.revision},
         )
 
     def test_invalid_relative_downgrade_path(self):
@@ -287,7 +287,7 @@ class BranchedPathTest(MigrationTest):
             self.d1.revision,
             self.b.revision,
             [self.up_(self.c1), self.up_(self.d1)],
-            set([self.d1.revision]),
+            {self.d1.revision},
         )
 
     def test_upgrade_multiple_branch(self):
@@ -303,7 +303,7 @@ class BranchedPathTest(MigrationTest):
                 self.up_(self.c1),
                 self.up_(self.d1),
             ],
-            set([self.d1.revision, self.d2.revision]),
+            {self.d1.revision, self.d2.revision},
         )
 
     def test_downgrade_multiple_branch(self):
@@ -317,7 +317,7 @@ class BranchedPathTest(MigrationTest):
                 self.down_(self.c2),
                 self.down_(self.b),
             ],
-            set([self.a.revision]),
+            {self.a.revision},
         )
 
     def test_relative_upgrade(self):
@@ -326,7 +326,7 @@ class BranchedPathTest(MigrationTest):
             "c2branch@head-1",
             self.b.revision,
             [self.up_(self.c2)],
-            set([self.c2.revision]),
+            {self.c2.revision},
         )
 
     def test_relative_downgrade_baseplus2(self):
@@ -340,7 +340,7 @@ class BranchedPathTest(MigrationTest):
                 self.down_(self.d2),
                 self.down_(self.c2),
             ],
-            set([self.b.revision]),
+            {self.b.revision},
         )
 
     def test_relative_downgrade_branchplus2(self):
@@ -353,7 +353,7 @@ class BranchedPathTest(MigrationTest):
             "c2branch@base+2",
             [self.d2.revision, self.d1.revision],
             [self.down_(self.d2), self.down_(self.c2)],
-            set([self.d1.revision]),
+            {self.d1.revision},
         )
 
     def test_relative_downgrade_branchplus3(self):
@@ -362,13 +362,13 @@ class BranchedPathTest(MigrationTest):
             self.c2.revision,
             [self.d2.revision, self.d1.revision],
             [self.down_(self.d2)],
-            set([self.d1.revision, self.c2.revision]),
+            {self.d1.revision, self.c2.revision},
         )
         self._assert_downgrade(
             "c2branch@base+3",
             [self.d2.revision, self.d1.revision],
             [self.down_(self.d2)],
-            set([self.d1.revision, self.c2.revision]),
+            {self.d1.revision, self.c2.revision},
         )
 
     # Old downgrade -1 behaviour depends on order of branch upgrades.
@@ -381,7 +381,7 @@ class BranchedPathTest(MigrationTest):
                 "-1",
                 [self.d2.revision, self.d1.revision],
                 [self.down_(self.d2)],
-                set([self.d1.revision, self.c2.revision]),
+                {self.d1.revision, self.c2.revision},
             )
 
     def test_downgrade_once_order_right_unbalanced(self):
@@ -390,7 +390,7 @@ class BranchedPathTest(MigrationTest):
                 "-1",
                 [self.c2.revision, self.d1.revision],
                 [self.down_(self.c2)],
-                set([self.d1.revision]),
+                {self.d1.revision},
             )
 
     def test_downgrade_once_order_left(self):
@@ -399,7 +399,7 @@ class BranchedPathTest(MigrationTest):
                 "-1",
                 [self.d1.revision, self.d2.revision],
                 [self.down_(self.d1)],
-                set([self.d2.revision, self.c1.revision]),
+                {self.d2.revision, self.c1.revision},
             )
 
     def test_downgrade_once_order_left_unbalanced(self):
@@ -408,7 +408,7 @@ class BranchedPathTest(MigrationTest):
                 "-1",
                 [self.c1.revision, self.d2.revision],
                 [self.down_(self.c1)],
-                set([self.d2.revision]),
+                {self.d2.revision},
             )
 
     def test_downgrade_once_order_left_unbalanced_labelled(self):
@@ -416,73 +416,73 @@ class BranchedPathTest(MigrationTest):
             "c1branch@-1",
             [self.d1.revision, self.d2.revision],
             [self.down_(self.d1)],
-            set([self.c1.revision, self.d2.revision]),
+            {self.c1.revision, self.d2.revision},
         )
 
     # Captures https://github.com/sqlalchemy/alembic/issues/765
 
     def test_downgrade_relative_order_right(self):
         self._assert_downgrade(
-            "{}-1".format(self.d2.revision),
+            f"{self.d2.revision}-1",
             [self.d2.revision, self.c1.revision],
             [self.down_(self.d2)],
-            set([self.c1.revision, self.c2.revision]),
+            {self.c1.revision, self.c2.revision},
         )
 
     def test_downgrade_relative_order_left(self):
         self._assert_downgrade(
-            "{}-1".format(self.d2.revision),
+            f"{self.d2.revision}-1",
             [self.c1.revision, self.d2.revision],
             [self.down_(self.d2)],
-            set([self.c1.revision, self.c2.revision]),
+            {self.c1.revision, self.c2.revision},
         )
 
     def test_downgrade_single_branch_c1branch(self):
         """Use branch label to specify the branch to downgrade."""
         self._assert_downgrade(
-            "c1branch@{}".format(self.b.revision),
+            f"c1branch@{self.b.revision}",
             (self.c1.revision, self.d2.revision),
             [
                 self.down_(self.c1),
             ],
-            set([self.d2.revision]),
+            {self.d2.revision},
         )
 
     def test_downgrade_single_branch_c1branch_from_d1_head(self):
         """Use branch label to specify the branch (where the branch label is
         not on the head revision)."""
         self._assert_downgrade(
-            "c2branch@{}".format(self.b.revision),
+            f"c2branch@{self.b.revision}",
             (self.c1.revision, self.d2.revision),
             [
                 self.down_(self.d2),
                 self.down_(self.c2),
             ],
-            set([self.c1.revision]),
+            {self.c1.revision},
         )
 
     def test_downgrade_single_branch_c2(self):
         """Use a revision on the branch (not head) to specify the branch."""
         self._assert_downgrade(
-            "{}@{}".format(self.c2.revision, self.b.revision),
+            f"{self.c2.revision}@{self.b.revision}",
             (self.d1.revision, self.d2.revision),
             [
                 self.down_(self.d2),
                 self.down_(self.c2),
             ],
-            set([self.d1.revision]),
+            {self.d1.revision},
         )
 
     def test_downgrade_single_branch_d1(self):
         """Use the head revision to specify the branch."""
         self._assert_downgrade(
-            "{}@{}".format(self.d1.revision, self.b.revision),
+            f"{self.d1.revision}@{self.b.revision}",
             (self.d1.revision, self.d2.revision),
             [
                 self.down_(self.d1),
                 self.down_(self.c1),
             ],
-            set([self.d2.revision]),
+            {self.d2.revision},
         )
 
     def test_downgrade_relative_to_branch_head(self):
@@ -490,7 +490,7 @@ class BranchedPathTest(MigrationTest):
             "c1branch@head-1",
             (self.d1.revision, self.d2.revision),
             [self.down_(self.d1)],
-            set([self.c1.revision, self.d2.revision]),
+            {self.c1.revision, self.d2.revision},
         )
 
     def test_upgrade_other_branch_from_mergepoint(self):
@@ -500,7 +500,7 @@ class BranchedPathTest(MigrationTest):
             "c2branch@+1",
             (self.c1.revision),
             [self.up_(self.c2)],
-            set([self.c1.revision, self.c2.revision]),
+            {self.c1.revision, self.c2.revision},
         )
 
     def test_upgrade_one_branch_of_heads(self):
@@ -511,7 +511,7 @@ class BranchedPathTest(MigrationTest):
             "c2branch@+1",
             (self.c1.revision, self.c2.revision),
             [self.up_(self.d2)],
-            set([self.c1.revision, self.d2.revision]),
+            {self.c1.revision, self.d2.revision},
         )
 
     def test_ambiguous_upgrade(self):
@@ -525,13 +525,11 @@ class BranchedPathTest(MigrationTest):
 
     def test_upgrade_from_base(self):
         self._assert_upgrade(
-            "base+1", [], [self.up_(self.a)], set([self.a.revision])
+            "base+1", [], [self.up_(self.a)], {self.a.revision}
         )
 
     def test_upgrade_from_base_implicit(self):
-        self._assert_upgrade(
-            "+1", [], [self.up_(self.a)], set([self.a.revision])
-        )
+        self._assert_upgrade("+1", [], [self.up_(self.a)], {self.a.revision})
 
     def test_downgrade_minus1_to_base(self):
         self._assert_downgrade(
@@ -553,13 +551,13 @@ class BranchedPathTest(MigrationTest):
             self.c2.revision,
             [self.d1.revision, self.c2.revision],
             [],
-            set([self.d1.revision, self.c2.revision]),
+            {self.d1.revision, self.c2.revision},
         )
         self._assert_downgrade(
             self.d1.revision,
             [self.d1.revision, self.c2.revision],
             [],
-            set([self.d1.revision, self.c2.revision]),
+            {self.d1.revision, self.c2.revision},
         )
 
 
@@ -614,7 +612,7 @@ class BranchFromMergepointTest(MigrationTest):
             self.d1.revision,
             (self.d2.revision, self.b1.revision),
             [self.up_(self.c1), self.up_(self.d1)],
-            set([self.d2.revision, self.d1.revision]),
+            {self.d2.revision, self.d1.revision},
         )
 
     def test_mergepoint_to_only_one_side_downgrade(self):
@@ -623,7 +621,7 @@ class BranchFromMergepointTest(MigrationTest):
             self.b1.revision,
             (self.d2.revision, self.d1.revision),
             [self.down_(self.d1), self.down_(self.c1)],
-            set([self.d2.revision, self.b1.revision]),
+            {self.d2.revision, self.b1.revision},
         )
 
 
@@ -698,7 +696,7 @@ class BranchFrom3WayMergepointTest(MigrationTest):
             self.d1.revision,
             (self.d3.revision, self.d2.revision, self.b1.revision),
             [self.up_(self.c1), self.up_(self.d1)],
-            set([self.d3.revision, self.d2.revision, self.d1.revision]),
+            {self.d3.revision, self.d2.revision, self.d1.revision},
         )
 
     def test_mergepoint_to_only_one_side_downgrade(self):
@@ -706,7 +704,7 @@ class BranchFrom3WayMergepointTest(MigrationTest):
             self.b1.revision,
             (self.d3.revision, self.d2.revision, self.d1.revision),
             [self.down_(self.d1), self.down_(self.c1)],
-            set([self.d3.revision, self.d2.revision, self.b1.revision]),
+            {self.d3.revision, self.d2.revision, self.b1.revision},
         )
 
     def test_mergepoint_to_two_sides_upgrade(self):
@@ -716,7 +714,7 @@ class BranchFrom3WayMergepointTest(MigrationTest):
             (self.d3.revision, self.b2.revision, self.b1.revision),
             [self.up_(self.c2), self.up_(self.c1), self.up_(self.d1)],
             # this will merge b2 and b1 into d1
-            set([self.d3.revision, self.d1.revision]),
+            {self.d3.revision, self.d1.revision},
         )
 
         # but then!  b2 will break out again if we keep going with it
@@ -724,7 +722,7 @@ class BranchFrom3WayMergepointTest(MigrationTest):
             self.d2.revision,
             (self.d3.revision, self.d1.revision),
             [self.up_(self.d2)],
-            set([self.d3.revision, self.d2.revision, self.d1.revision]),
+            {self.d3.revision, self.d2.revision, self.d1.revision},
         )
 
 
@@ -916,14 +914,14 @@ class DependsOnBranchTestOne(MigrationTest):
         heads = [self.c2.revision, self.d1.revision]
         head = HeadMaintainer(mock.Mock(), heads)
         head.update_to_step(self.down_(self.d1))
-        eq_(head.heads, set([self.c2.revision]))
+        eq_(head.heads, {self.c2.revision})
 
     def test_stamp_across_dependency(self):
         heads = [self.e1.revision, self.c2.revision]
         head = HeadMaintainer(mock.Mock(), heads)
         for step in self.env._stamp_revs(self.b1.revision, heads):
             head.update_to_step(step)
-        eq_(head.heads, set([self.b1.revision]))
+        eq_(head.heads, {self.b1.revision})
 
 
 class DependsOnBranchTestTwo(MigrationTest):
@@ -1010,15 +1008,13 @@ class DependsOnBranchTestTwo(MigrationTest):
             self.b2.revision,
             heads,
             [self.down_(self.bmerge)],
-            set(
-                [
-                    self.amerge.revision,
-                    self.b1.revision,
-                    self.cmerge.revision,
-                    # b2 isn't here, but d1 is, which implies b2. OK!
-                    self.d1.revision,
-                ]
-            ),
+            {
+                self.amerge.revision,
+                self.b1.revision,
+                self.cmerge.revision,
+                # b2 isn't here, but d1 is, which implies b2. OK!
+                self.d1.revision,
+            },
         )
 
         # start with those heads..
@@ -1034,15 +1030,13 @@ class DependsOnBranchTestTwo(MigrationTest):
             "d1@base",
             heads,
             [self.down_(self.d1)],
-            set(
-                [
-                    self.amerge.revision,
-                    self.b1.revision,
-                    # b2 has to be INSERTed, because it was implied by d1
-                    self.b2.revision,
-                    self.cmerge.revision,
-                ]
-            ),
+            {
+                self.amerge.revision,
+                self.b1.revision,
+                # b2 has to be INSERTed, because it was implied by d1
+                self.b2.revision,
+                self.cmerge.revision,
+            },
         )
 
         # start with those heads ...
@@ -1071,7 +1065,7 @@ class DependsOnBranchTestTwo(MigrationTest):
                 self.down_(self.c2),
                 self.down_(self.c3),
             ],
-            set([]),
+            set(),
         )
 
 
@@ -1122,7 +1116,7 @@ class DependsOnBranchTestThree(MigrationTest):
             "b1",
             ["a3", "b2"],
             [self.down_(self.b2)],
-            set(["a3"]),  # we have b1 also, which is implied by a3
+            {"a3"},  # we have b1 also, which is implied by a3
         )
 
 
@@ -1145,7 +1139,7 @@ class DependsOnOwnDownrevTest(MigrationTest):
             self.a2.revision,
             None,
             [self.up_(self.a1), self.up_(self.a2)],
-            set(["a2"]),
+            {"a2"},
         )
 
     def test_traverse_down(self):
@@ -1153,7 +1147,7 @@ class DependsOnOwnDownrevTest(MigrationTest):
             self.a1.revision,
             self.a2.revision,
             [self.down_(self.a2)],
-            set(["a1"]),
+            {"a1"},
         )
 
 
@@ -1190,7 +1184,7 @@ class DependsOnBranchTestFour(MigrationTest):
             heads,
             [self.down_(self.b4)],
             # a3 isn't here, because b3 still implies a3
-            set([self.b3.revision]),
+            {self.b3.revision},
         )
 
 
@@ -1239,7 +1233,7 @@ class DependsOnBranchLabelTest(MigrationTest):
                 self.up_(self.b2),
                 self.up_(self.c2),
             ],
-            set([self.c2.revision]),
+            {self.c2.revision},
         )
 
 
@@ -1276,8 +1270,8 @@ class ForestTest(MigrationTest):
         revs = self.env._stamp_revs("heads", ())
         eq_(len(revs), 2)
         eq_(
-            set(r.to_revisions for r in revs),
-            set([(self.b1.revision,), (self.b2.revision,)]),
+            {r.to_revisions for r in revs},
+            {(self.b1.revision,), (self.b2.revision,)},
         )
 
     def test_stamp_to_heads_no_moves_needed(self):
@@ -1448,19 +1442,19 @@ class BranchedPathTestCrossDependencies(MigrationTest):
         """c2branch depends on c1branch so can be taken down on its own.
         Current behaviour also takes down the dependency unnecessarily."""
         self._assert_downgrade(
-            "c2branch@{}".format(self.b.revision),
+            f"c2branch@{self.b.revision}",
             (self.d1.revision, self.d2.revision),
             [
                 self.down_(self.d2),
                 self.down_(self.c2),
             ],
-            set([self.d1.revision]),
+            {self.d1.revision},
         )
 
     def test_downgrade_branch_dependency(self):
         """c2branch depends on c1branch so taking down c1branch requires taking
         down both"""
-        destination = "c1branch@{}".format(self.b.revision)
+        destination = f"c1branch@{self.b.revision}"
         source = self.d1.revision, self.d2.revision
         revs = self.env._downgrade_revs(destination, source)
         # Drops c1, d1 as requested, also drops d2 due to dependence on d1.
@@ -1483,4 +1477,4 @@ class BranchedPathTestCrossDependencies(MigrationTest):
         head = HeadMaintainer(mock.Mock(), heads)
         for rev in revs:
             head.update_to_step(rev)
-        eq_(head.heads, set([self.c2.revision]))
+        eq_(head.heads, {self.c2.revision})