git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
pure black run + flake8
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Sun, 6 Jan 2019 17:37:53 +0000 (12:37 -0500)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Sun, 6 Jan 2019 18:22:59 +0000 (13:22 -0500)
run black -l 79 against source code, set up for
full flake8 testing.

Change-Id: I4108e1274d49894b9898ec5bd3a1147933a473d7
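For context, the commit message describes a two-step workflow; a minimal sketch of the equivalent commands, assuming they are run from the repository root (the bare invocations are illustrative — the actual flake8 settings are carried by the setup.cfg and tox.ini changes listed below):

    black -l 79 .   # rewrite all source files with black at line length 79
    flake8          # lint the reformatted tree using the project's config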

85 files changed:
alembic/__init__.py
alembic/autogenerate/__init__.py
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/autogenerate/rewriter.py
alembic/command.py
alembic/config.py
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/mysql.py
alembic/ddl/oracle.py
alembic/ddl/postgresql.py
alembic/ddl/sqlite.py
alembic/op.py
alembic/operations/__init__.py
alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/operations/toimpl.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/script/__init__.py
alembic/script/base.py
alembic/script/revision.py
alembic/templates/generic/env.py
alembic/templates/multidb/env.py
alembic/templates/pylons/env.py
alembic/testing/__init__.py
alembic/testing/assertions.py
alembic/testing/compat.py
alembic/testing/config.py
alembic/testing/engines.py
alembic/testing/env.py
alembic/testing/exclusions.py
alembic/testing/fixtures.py
alembic/testing/mock.py
alembic/testing/plugin/bootstrap.py
alembic/testing/plugin/noseplugin.py
alembic/testing/plugin/plugin_base.py
alembic/testing/plugin/pytestplugin.py
alembic/testing/provision.py
alembic/testing/requirements.py
alembic/testing/runner.py
alembic/testing/util.py
alembic/testing/warnings.py
alembic/util/__init__.py
alembic/util/compat.py
alembic/util/langhelpers.py
alembic/util/messaging.py
alembic/util/pyfiles.py
alembic/util/sqla_compat.py
setup.cfg
setup.py
tests/_autogen_fixtures.py
tests/_large_map.py
tests/conftest.py
tests/requirements.py
tests/test_autogen_composition.py
tests/test_autogen_diffs.py
tests/test_autogen_fks.py
tests/test_autogen_indexes.py
tests/test_autogen_render.py
tests/test_batch.py
tests/test_bulk_insert.py
tests/test_command.py
tests/test_config.py
tests/test_environment.py
tests/test_external_dialect.py
tests/test_mssql.py
tests/test_mysql.py
tests/test_offline_environment.py
tests/test_op.py
tests/test_op_naming_convention.py
tests/test_oracle.py
tests/test_postgresql.py
tests/test_revision.py
tests/test_script_consumption.py
tests/test_script_production.py
tests/test_sqlite.py
tests/test_version_table.py
tests/test_version_traversal.py
tox.ini

diff --git a/alembic/__init__.py b/alembic/__init__.py
index 3432a885ef0ef790a492a4de46a059d451bacfa0..a7a2845102138d7675cd564d8e75cc51787fa0b2 100644
--- a/alembic/__init__.py
+++ b/alembic/__init__.py
@@ -1,6 +1,6 @@
 from os import path
 
-__version__ = '1.0.6'
+__version__ = "1.0.6"
 
 package_dir = path.abspath(path.dirname(__file__))
 
@@ -11,5 +11,6 @@ from . import context  # noqa
 import sys
 from .runtime import environment
 from .runtime import migration
-sys.modules['alembic.migration'] = migration
-sys.modules['alembic.environment'] = environment
+
+sys.modules["alembic.migration"] = migration
+sys.modules["alembic.environment"] = environment
diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py
index 142f55d04fee32dc1af56f86b7e0dc13646f448b..ad3e6e1b4ad62067880faa84197260268388cbb2 100644
--- a/alembic/autogenerate/__init__.py
+++ b/alembic/autogenerate/__init__.py
@@ -1,8 +1,10 @@
-from .api import ( # noqa
-    compare_metadata, _render_migration_diffs,
-    produce_migrations, render_python_code,
-    RevisionContext
-    )
+from .api import (  # noqa
+    compare_metadata,
+    _render_migration_diffs,
+    produce_migrations,
+    render_python_code,
+    RevisionContext,
+)
 from .compare import _produce_net_changes, comparators  # noqa
 from .render import render_op_text, renderers  # noqa
-from .rewriter import Rewriter  # noqa
\ No newline at end of file
+from .rewriter import Rewriter  # noqa
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index 15b5b6b0bcb6e141abb858574c24a511be209f7c..cfd6e86250408be1fdf47ea8615cd8b185b57b66 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -136,8 +136,8 @@ def produce_migrations(context, metadata):
 
 def render_python_code(
     up_or_down_op,
-    sqlalchemy_module_prefix='sa.',
-    alembic_module_prefix='op.',
+    sqlalchemy_module_prefix="sa.",
+    alembic_module_prefix="op.",
     render_as_batch=False,
     imports=(),
     render_item=None,
@@ -150,16 +150,17 @@ def render_python_code(
 
     """
     opts = {
-        'sqlalchemy_module_prefix': sqlalchemy_module_prefix,
-        'alembic_module_prefix': alembic_module_prefix,
-        'render_item': render_item,
-        'render_as_batch': render_as_batch,
+        "sqlalchemy_module_prefix": sqlalchemy_module_prefix,
+        "alembic_module_prefix": alembic_module_prefix,
+        "render_item": render_item,
+        "render_as_batch": render_as_batch,
     }
 
     autogen_context = AutogenContext(None, opts=opts)
     autogen_context.imports = set(imports)
-    return render._indent(render._render_cmd_body(
-        up_or_down_op, autogen_context))
+    return render._indent(
+        render._render_cmd_body(up_or_down_op, autogen_context)
+    )
 
 
 def _render_migration_diffs(context, template_args):
@@ -240,42 +241,53 @@ class AutogenContext(object):
     """The :class:`.MigrationContext` established by the ``env.py`` script."""
 
     def __init__(
-            self, migration_context, metadata=None,
-            opts=None, autogenerate=True):
-
-        if autogenerate and \
-                migration_context is not None and migration_context.as_sql:
+        self, migration_context, metadata=None, opts=None, autogenerate=True
+    ):
+
+        if (
+            autogenerate
+            and migration_context is not None
+            and migration_context.as_sql
+        ):
             raise util.CommandError(
                 "autogenerate can't use as_sql=True as it prevents querying "
-                "the database for schema information")
+                "the database for schema information"
+            )
 
         if opts is None:
             opts = migration_context.opts
 
-        self.metadata = metadata = opts.get('target_metadata', None) \
-            if metadata is None else metadata
+        self.metadata = metadata = (
+            opts.get("target_metadata", None) if metadata is None else metadata
+        )
 
-        if autogenerate and metadata is None and \
-                migration_context is not None and \
-                migration_context.script is not None:
+        if (
+            autogenerate
+            and metadata is None
+            and migration_context is not None
+            and migration_context.script is not None
+        ):
             raise util.CommandError(
                 "Can't proceed with --autogenerate option; environment "
                 "script %s does not provide "
-                "a MetaData object or sequence of objects to the context." % (
-                    migration_context.script.env_py_location
-                ))
+                "a MetaData object or sequence of objects to the context."
+                % (migration_context.script.env_py_location)
+            )
 
-        include_symbol = opts.get('include_symbol', None)
-        include_object = opts.get('include_object', None)
+        include_symbol = opts.get("include_symbol", None)
+        include_object = opts.get("include_object", None)
 
         object_filters = []
         if include_symbol:
+
             def include_symbol_filter(
-                    object, name, type_, reflected, compare_to):
+                object, name, type_, reflected, compare_to
+            ):
                 if type_ == "table":
                     return include_symbol(name, object.schema)
                 else:
                     return True
+
             object_filters.append(include_symbol_filter)
         if include_object:
             object_filters.append(include_object)
@@ -357,8 +369,8 @@ class AutogenContext(object):
             if intersect:
                 raise ValueError(
                     "Duplicate table keys across multiple "
-                    "MetaData objects: %s" %
-                    (", ".join('"%s"' % key for key in sorted(intersect)))
+                    "MetaData objects: %s"
+                    (", ".join('"%s"' % key for key in sorted(intersect)))
                 )
 
             result.update(m.tables)
@@ -369,26 +381,29 @@ class RevisionContext(object):
     """Maintains configuration and state that's specific to a revision
     file generation operation."""
 
-    def __init__(self, config, script_directory, command_args,
-                 process_revision_directives=None):
+    def __init__(
+        self,
+        config,
+        script_directory,
+        command_args,
+        process_revision_directives=None,
+    ):
         self.config = config
         self.script_directory = script_directory
         self.command_args = command_args
         self.process_revision_directives = process_revision_directives
         self.template_args = {
-            'config': config  # Let templates use config for
-                              # e.g. multiple databases
+            "config": config  # Let templates use config for
+            # e.g. multiple databases
         }
-        self.generated_revisions = [
-            self._default_revision()
-        ]
+        self.generated_revisions = [self._default_revision()]
 
     def _to_script(self, migration_script):
         template_args = {}
         for k, v in self.template_args.items():
             template_args.setdefault(k, v)
 
-        if getattr(migration_script, '_needs_render', False):
+        if getattr(migration_script, "_needs_render", False):
             autogen_context = self._last_autogen_context
 
             # clear out existing imports if we are doing multiple
@@ -409,7 +424,8 @@ class RevisionContext(object):
             branch_labels=migration_script.branch_label,
             version_path=migration_script.version_path,
             depends_on=migration_script.depends_on,
-            **template_args)
+            **template_args
+        )
 
     def run_autogenerate(self, rev, migration_context):
         self._run_environment(rev, migration_context, True)
@@ -419,21 +435,24 @@ class RevisionContext(object):
 
     def _run_environment(self, rev, migration_context, autogenerate):
         if autogenerate:
-            if self.command_args['sql']:
+            if self.command_args["sql"]:
                 raise util.CommandError(
-                    "Using --sql with --autogenerate does not make any sense")
-            if set(self.script_directory.get_revisions(rev)) != \
-                    set(self.script_directory.get_revisions("heads")):
+                    "Using --sql with --autogenerate does not make any sense"
+                )
+            if set(self.script_directory.get_revisions(rev)) != set(
+                self.script_directory.get_revisions("heads")
+            ):
                 raise util.CommandError("Target database is not up to date.")
 
-        upgrade_token = migration_context.opts['upgrade_token']
-        downgrade_token = migration_context.opts['downgrade_token']
+        upgrade_token = migration_context.opts["upgrade_token"]
+        downgrade_token = migration_context.opts["downgrade_token"]
 
         migration_script = self.generated_revisions[-1]
-        if not getattr(migration_script, '_needs_render', False):
+        if not getattr(migration_script, "_needs_render", False):
             migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
-            migration_script.downgrade_ops_list[-1].downgrade_token = \
-                downgrade_token
+            migration_script.downgrade_ops_list[
+                -1
+            ].downgrade_token = downgrade_token
             migration_script._needs_render = True
         else:
             migration_script._upgrade_ops.append(
@@ -443,18 +462,21 @@ class RevisionContext(object):
                 ops.DowngradeOps([], downgrade_token=downgrade_token)
             )
 
-        self._last_autogen_context = autogen_context = \
-            AutogenContext(migration_context, autogenerate=autogenerate)
+        self._last_autogen_context = autogen_context = AutogenContext(
+            migration_context, autogenerate=autogenerate
+        )
 
         if autogenerate:
             compare._populate_migration_script(
-                autogen_context, migration_script)
+                autogen_context, migration_script
+            )
 
         if self.process_revision_directives:
             self.process_revision_directives(
-                migration_context, rev, self.generated_revisions)
+                migration_context, rev, self.generated_revisions
+            )
 
-        hook = migration_context.opts['process_revision_directives']
+        hook = migration_context.opts["process_revision_directives"]
         if hook:
             hook(migration_context, rev, self.generated_revisions)
 
@@ -463,15 +485,15 @@ class RevisionContext(object):
 
     def _default_revision(self):
         op = ops.MigrationScript(
-            rev_id=self.command_args['rev_id'] or util.rev_id(),
-            message=self.command_args['message'],
+            rev_id=self.command_args["rev_id"] or util.rev_id(),
+            message=self.command_args["message"],
             upgrade_ops=ops.UpgradeOps([]),
             downgrade_ops=ops.DowngradeOps([]),
-            head=self.command_args['head'],
-            splice=self.command_args['splice'],
-            branch_label=self.command_args['branch_label'],
-            version_path=self.command_args['version_path'],
-            depends_on=self.command_args['depends_on']
+            head=self.command_args["head"],
+            splice=self.command_args["splice"],
+            branch_label=self.command_args["branch_label"],
+            version_path=self.command_args["version_path"],
+            depends_on=self.command_args["depends_on"],
         )
         return op
 
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 8b416475c0f446cd29547c1f1a2d7d1ce80f37d2..7ff8be65437b62f3b126f2e1eae934f785e22f65 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -29,7 +29,7 @@ comparators = util.Dispatcher(uselist=True)
 def _produce_net_changes(autogen_context, upgrade_ops):
 
     connection = autogen_context.connection
-    include_schemas = autogen_context.opts.get('include_schemas', False)
+    include_schemas = autogen_context.opts.get("include_schemas", False)
 
     inspector = Inspector.from_engine(connection)
 
@@ -55,8 +55,9 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
 
     conn_table_names = set()
 
-    version_table_schema = \
+    version_table_schema = (
         autogen_context.migration_context.version_table_schema
+    )
     version_table = autogen_context.migration_context.version_table
 
     for s in schemas:
@@ -71,12 +72,22 @@ def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
         [(table.schema, table.name) for table in autogen_context.sorted_tables]
     ).difference([(version_table_schema, version_table)])
 
-    _compare_tables(conn_table_names, metadata_table_names,
-                    inspector, upgrade_ops, autogen_context)
+    _compare_tables(
+        conn_table_names,
+        metadata_table_names,
+        inspector,
+        upgrade_ops,
+        autogen_context,
+    )
 
 
-def _compare_tables(conn_table_names, metadata_table_names,
-                    inspector, upgrade_ops, autogen_context):
+def _compare_tables(
+    conn_table_names,
+    metadata_table_names,
+    inspector,
+    upgrade_ops,
+    autogen_context,
+):
 
     default_schema = inspector.bind.dialect.default_schema_name
 
@@ -85,10 +96,12 @@ def _compare_tables(conn_table_names, metadata_table_names,
     # of table names from local metadata that also have "None" if schema
     # == default_schema_name.  Most setups will be like this anyway but
     # some are not (see #170)
-    metadata_table_names_no_dflt_schema = OrderedSet([
-        (schema if schema != default_schema else None, tname)
-        for schema, tname in metadata_table_names
-    ])
+    metadata_table_names_no_dflt_schema = OrderedSet(
+        [
+            (schema if schema != default_schema else None, tname)
+            for schema, tname in metadata_table_names
+        ]
+    )
 
     # to adjust for the MetaData collection storing the tables either
     # as "schemaname.tablename" or just "tablename", create a new lookup
@@ -97,27 +110,34 @@ def _compare_tables(conn_table_names, metadata_table_names,
         (
             no_dflt_schema,
             autogen_context.table_key_to_table[
-                sa_schema._get_table_key(tname, schema)]
+                sa_schema._get_table_key(tname, schema)
+            ],
         )
         for no_dflt_schema, (schema, tname) in zip(
-            metadata_table_names_no_dflt_schema,
-            metadata_table_names)
+            metadata_table_names_no_dflt_schema, metadata_table_names
+        )
     )
     metadata_table_names = metadata_table_names_no_dflt_schema
 
     for s, tname in metadata_table_names.difference(conn_table_names):
-        name = '%s.%s' % (s, tname) if s else tname
+        name = "%s.%s" % (s, tname) if s else tname
         metadata_table = tname_to_table[(s, tname)]
         if autogen_context.run_filters(
-                metadata_table, tname, "table", False, None):
+            metadata_table, tname, "table", False, None
+        ):
             upgrade_ops.ops.append(
-                ops.CreateTableOp.from_table(metadata_table))
+                ops.CreateTableOp.from_table(metadata_table)
+            )
             log.info("Detected added table %r", name)
             modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
 
             comparators.dispatch("table")(
-                autogen_context, modify_table_ops,
-                s, tname, None, metadata_table
+                autogen_context,
+                modify_table_ops,
+                s,
+                tname,
+                None,
+                metadata_table,
             )
             if not modify_table_ops.is_empty():
                 upgrade_ops.ops.append(modify_table_ops)
@@ -132,23 +152,22 @@ def _compare_tables(conn_table_names, metadata_table_names,
             event.listen(
                 t,
                 "column_reflect",
-                autogen_context.migration_context.impl.
-                _compat_autogen_column_reflect(inspector))
+                autogen_context.migration_context.impl._compat_autogen_column_reflect(
+                    inspector
+                ),
+            )
             inspector.reflecttable(t, None)
         if autogen_context.run_filters(t, tname, "table", True, None):
 
             modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
 
             comparators.dispatch("table")(
-                autogen_context, modify_table_ops,
-                s, tname, t, None
+                autogen_context, modify_table_ops, s, tname, t, None
             )
             if not modify_table_ops.is_empty():
                 upgrade_ops.ops.append(modify_table_ops)
 
-            upgrade_ops.ops.append(
-                ops.DropTableOp.from_table(t)
-            )
+            upgrade_ops.ops.append(ops.DropTableOp.from_table(t))
             log.info("Detected removed table %r", name)
 
     existing_tables = conn_table_names.intersection(metadata_table_names)
@@ -163,31 +182,41 @@ def _compare_tables(conn_table_names, metadata_table_names,
             event.listen(
                 t,
                 "column_reflect",
-                autogen_context.migration_context.impl.
-                _compat_autogen_column_reflect(inspector))
+                autogen_context.migration_context.impl._compat_autogen_column_reflect(
+                    inspector
+                ),
+            )
             inspector.reflecttable(t, None)
         conn_column_info[(s, tname)] = t
 
-    for s, tname in sorted(existing_tables, key=lambda x: (x[0] or '', x[1])):
+    for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
         s = s or None
-        name = '%s.%s' % (s, tname) if s else tname
+        name = "%s.%s" % (s, tname) if s else tname
         metadata_table = tname_to_table[(s, tname)]
         conn_table = existing_metadata.tables[name]
 
         if autogen_context.run_filters(
-                metadata_table, tname, "table", False,
-                conn_table):
+            metadata_table, tname, "table", False, conn_table
+        ):
 
             modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
             with _compare_columns(
-                s, tname,
+                s,
+                tname,
                 conn_table,
                 metadata_table,
-                    modify_table_ops, autogen_context, inspector):
+                modify_table_ops,
+                autogen_context,
+                inspector,
+            ):
 
                 comparators.dispatch("table")(
-                    autogen_context, modify_table_ops,
-                    s, tname, conn_table, metadata_table
+                    autogen_context,
+                    modify_table_ops,
+                    s,
+                    tname,
+                    conn_table,
+                    metadata_table,
                 )
 
             if not modify_table_ops.is_empty():
@@ -196,41 +225,41 @@ def _compare_tables(conn_table_names, metadata_table_names,
 
 def _make_index(params, conn_table):
     ix = sa_schema.Index(
-        params['name'],
-        *[conn_table.c[cname] for cname in params['column_names']],
-        unique=params['unique']
+        params["name"],
+        *[conn_table.c[cname] for cname in params["column_names"]],
+        unique=params["unique"]
     )
-    if 'duplicates_constraint' in params:
-        ix.info['duplicates_constraint'] = params['duplicates_constraint']
+    if "duplicates_constraint" in params:
+        ix.info["duplicates_constraint"] = params["duplicates_constraint"]
     return ix
 
 
 def _make_unique_constraint(params, conn_table):
     uq = sa_schema.UniqueConstraint(
-        *[conn_table.c[cname] for cname in params['column_names']],
-        name=params['name']
+        *[conn_table.c[cname] for cname in params["column_names"]],
+        name=params["name"]
     )
-    if 'duplicates_index' in params:
-        uq.info['duplicates_index'] = params['duplicates_index']
+    if "duplicates_index" in params:
+        uq.info["duplicates_index"] = params["duplicates_index"]
 
     return uq
 
 
 def _make_foreign_key(params, conn_table):
-    tname = params['referred_table']
-    if params['referred_schema']:
-        tname = "%s.%s" % (params['referred_schema'], tname)
+    tname = params["referred_table"]
+    if params["referred_schema"]:
+        tname = "%s.%s" % (params["referred_schema"], tname)
 
-    options = params.get('options', {})
+    options = params.get("options", {})
 
     const = sa_schema.ForeignKeyConstraint(
-        [conn_table.c[cname] for cname in params['constrained_columns']],
-        ["%s.%s" % (tname, n) for n in params['referred_columns']],
-        onupdate=options.get('onupdate'),
-        ondelete=options.get('ondelete'),
-        deferrable=options.get('deferrable'),
-        initially=options.get('initially'),
-        name=params['name']
+        [conn_table.c[cname] for cname in params["constrained_columns"]],
+        ["%s.%s" % (tname, n) for n in params["referred_columns"]],
+        onupdate=options.get("onupdate"),
+        ondelete=options.get("ondelete"),
+        deferrable=options.get("deferrable"),
+        initially=options.get("initially"),
+        name=params["name"],
     )
     # needed by 0.7
     conn_table.append_constraint(const)
@@ -238,21 +267,30 @@ def _make_foreign_key(params, conn_table):
 
 
 @contextlib.contextmanager
-def _compare_columns(schema, tname, conn_table, metadata_table,
-                     modify_table_ops, autogen_context, inspector):
-    name = '%s.%s' % (schema, tname) if schema else tname
+def _compare_columns(
+    schema,
+    tname,
+    conn_table,
+    metadata_table,
+    modify_table_ops,
+    autogen_context,
+    inspector,
+):
+    name = "%s.%s" % (schema, tname) if schema else tname
     metadata_cols_by_name = dict(
-        (c.name, c) for c in metadata_table.c if not c.system)
+        (c.name, c) for c in metadata_table.c if not c.system
+    )
     conn_col_names = dict((c.name, c) for c in conn_table.c)
     metadata_col_names = OrderedSet(sorted(metadata_cols_by_name))
 
     for cname in metadata_col_names.difference(conn_col_names):
         if autogen_context.run_filters(
-                metadata_cols_by_name[cname], cname,
-                "column", False, None):
+            metadata_cols_by_name[cname], cname, "column", False, None
+        ):
             modify_table_ops.ops.append(
                 ops.AddColumnOp.from_column_and_tablename(
-                    schema, tname, metadata_cols_by_name[cname])
+                    schema, tname, metadata_cols_by_name[cname]
+                )
             )
             log.info("Detected added column '%s.%s'", name, cname)
 
@@ -260,15 +298,19 @@ def _compare_columns(schema, tname, conn_table, metadata_table,
         metadata_col = metadata_cols_by_name[colname]
         conn_col = conn_table.c[colname]
         if not autogen_context.run_filters(
-                metadata_col, colname, "column", False,
-                conn_col):
+            metadata_col, colname, "column", False, conn_col
+        ):
             continue
-        alter_column_op = ops.AlterColumnOp(
-            tname, colname, schema=schema)
+        alter_column_op = ops.AlterColumnOp(tname, colname, schema=schema)
 
         comparators.dispatch("column")(
-            autogen_context, alter_column_op,
-            schema, tname, colname, conn_col, metadata_col
+            autogen_context,
+            alter_column_op,
+            schema,
+            tname,
+            colname,
+            conn_col,
+            metadata_col,
         )
 
         if alter_column_op.has_changes():
@@ -278,8 +320,8 @@ def _compare_columns(schema, tname, conn_table, metadata_table,
 
     for cname in set(conn_col_names).difference(metadata_col_names):
         if autogen_context.run_filters(
-                conn_table.c[cname], cname,
-                "column", True, None):
+            conn_table.c[cname], cname, "column", True, None
+        ):
             modify_table_ops.ops.append(
                 ops.DropColumnOp.from_column_and_tablename(
                     schema, tname, conn_table.c[cname]
@@ -289,7 +331,6 @@ def _compare_columns(schema, tname, conn_table, metadata_table,
 
 
 class _constraint_sig(object):
-
     def md_name_to_sql_name(self, context):
         return self.name
 
@@ -340,36 +381,47 @@ class _fk_constraint_sig(_constraint_sig):
         self.name = const.name
 
         (
-            self.source_schema, self.source_table,
-            self.source_columns, self.target_schema, self.target_table,
+            self.source_schema,
+            self.source_table,
+            self.source_columns,
+            self.target_schema,
+            self.target_table,
             self.target_columns,
-            onupdate, ondelete,
-            deferrable, initially) = _fk_spec(const)
+            onupdate,
+            ondelete,
+            deferrable,
+            initially,
+        ) = _fk_spec(const)
 
         self.sig = (
-            self.source_schema, self.source_table, tuple(self.source_columns),
-            self.target_schema, self.target_table, tuple(self.target_columns)
+            self.source_schema,
+            self.source_table,
+            tuple(self.source_columns),
+            self.target_schema,
+            self.target_table,
+            tuple(self.target_columns),
         )
         if include_options:
             self.sig += (
-                (None if onupdate.lower() == 'no action'
-                    else onupdate.lower())
-                if onupdate else None,
-                (None if ondelete.lower() == 'no action'
-                    else ondelete.lower())
-                if ondelete else None,
+                (None if onupdate.lower() == "no action" else onupdate.lower())
+                if onupdate
+                else None,
+                (None if ondelete.lower() == "no action" else ondelete.lower())
+                if ondelete
+                else None,
                 # convert initially + deferrable into one three-state value
                 "initially_deferrable"
                 if initially and initially.lower() == "deferred"
-                else "deferrable" if deferrable
-                else "not deferrable"
+                else "deferrable"
+                if deferrable
+                else "not deferrable",
             )
 
 
 @comparators.dispatch_for("table")
 def _compare_indexes_and_uniques(
-        autogen_context, modify_ops, schema, tname, conn_table,
-        metadata_table):
+    autogen_context, modify_ops, schema, tname, conn_table, metadata_table
+):
 
     inspector = autogen_context.inspector
     is_create_table = conn_table is None
@@ -378,7 +430,8 @@ def _compare_indexes_and_uniques(
     # 1a. get raw indexes and unique constraints from metadata ...
     if metadata_table is not None:
         metadata_unique_constraints = set(
-            uq for uq in metadata_table.constraints
+            uq
+            for uq in metadata_table.constraints
             if isinstance(uq, sa_schema.UniqueConstraint)
         )
         metadata_indexes = set(metadata_table.indexes)
@@ -397,7 +450,8 @@ def _compare_indexes_and_uniques(
         if hasattr(inspector, "get_unique_constraints"):
             try:
                 conn_uniques = inspector.get_unique_constraints(
-                    tname, schema=schema)
+                    tname, schema=schema
+                )
                 supports_unique_constraints = True
             except NotImplementedError:
                 pass
@@ -408,7 +462,7 @@ def _compare_indexes_and_uniques(
                 pass
             else:
                 for uq in conn_uniques:
-                    if uq.get('duplicates_index'):
+                    if uq.get("duplicates_index"):
                         unique_constraints_duplicate_unique_indexes = True
         try:
             conn_indexes = inspector.get_indexes(tname, schema=schema)
@@ -421,8 +475,10 @@ def _compare_indexes_and_uniques(
             # for DROP TABLE uniques are inline, don't need them
             conn_uniques = set()
         else:
-            conn_uniques = set(_make_unique_constraint(uq_def, conn_table)
-                               for uq_def in conn_uniques)
+            conn_uniques = set(
+                _make_unique_constraint(uq_def, conn_table)
+                for uq_def in conn_uniques
+            )
 
         conn_indexes = set(_make_index(ix, conn_table) for ix in conn_indexes)
 
@@ -431,64 +487,71 @@ def _compare_indexes_and_uniques(
 
     if unique_constraints_duplicate_unique_indexes:
         _correct_for_uq_duplicates_uix(
-            conn_uniques, conn_indexes,
+            conn_uniques,
+            conn_indexes,
             metadata_unique_constraints,
-            metadata_indexes
+            metadata_indexes,
         )
 
     # 3. give the dialect a chance to omit indexes and constraints that
     # we know are either added implicitly by the DB or that the DB
     # can't accurately report on
-    autogen_context.migration_context.impl.\
-        correct_for_autogen_constraints(
-            conn_uniques, conn_indexes,
-            metadata_unique_constraints,
-            metadata_indexes)
+    autogen_context.migration_context.impl.correct_for_autogen_constraints(
+        conn_uniques,
+        conn_indexes,
+        metadata_unique_constraints,
+        metadata_indexes,
+    )
 
     # 4. organize the constraints into "signature" collections, the
     # _constraint_sig() objects provide a consistent facade over both
     # Index and UniqueConstraint so we can easily work with them
     # interchangeably
-    metadata_unique_constraints = set(_uq_constraint_sig(uq)
-                                      for uq in metadata_unique_constraints
-                                      )
+    metadata_unique_constraints = set(
+        _uq_constraint_sig(uq) for uq in metadata_unique_constraints
+    )
 
     metadata_indexes = set(_ix_constraint_sig(ix) for ix in metadata_indexes)
 
     conn_unique_constraints = set(
-        _uq_constraint_sig(uq) for uq in conn_uniques)
+        _uq_constraint_sig(uq) for uq in conn_uniques
+    )
 
     conn_indexes = set(_ix_constraint_sig(ix) for ix in conn_indexes)
 
     # 5. index things by name, for those objects that have names
     metadata_names = dict(
-        (c.md_name_to_sql_name(autogen_context), c) for c in
-        metadata_unique_constraints.union(metadata_indexes)
-        if c.name is not None)
+        (c.md_name_to_sql_name(autogen_context), c)
+        for c in metadata_unique_constraints.union(metadata_indexes)
+        if c.name is not None
+    )
 
     conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints)
     conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
 
-    conn_names = dict((c.name, c) for c in
-                      conn_unique_constraints.union(conn_indexes)
-                      if c.name is not None)
+    conn_names = dict(
+        (c.name, c)
+        for c in conn_unique_constraints.union(conn_indexes)
+        if c.name is not None
+    )
 
     doubled_constraints = dict(
         (name, (conn_uniques_by_name[name], conn_indexes_by_name[name]))
-        for name in set(
-            conn_uniques_by_name).intersection(conn_indexes_by_name)
+        for name in set(conn_uniques_by_name).intersection(
+            conn_indexes_by_name
+        )
     )
 
     # 6. index things by "column signature", to help with unnamed unique
     # constraints.
     conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints)
     metadata_uniques_by_sig = dict(
-        (uq.sig, uq) for uq in metadata_unique_constraints)
-    metadata_indexes_by_sig = dict(
-        (ix.sig, ix) for ix in metadata_indexes)
+        (uq.sig, uq) for uq in metadata_unique_constraints
+    )
+    metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes)
     unnamed_metadata_uniques = dict(
-        (uq.sig, uq) for uq in
-        metadata_unique_constraints if uq.name is None)
+        (uq.sig, uq) for uq in metadata_unique_constraints if uq.name is None
+    )
 
     # assumptions:
     # 1. a unique constraint or an index from the connection *always*
@@ -501,14 +564,14 @@ def _compare_indexes_and_uniques(
     def obj_added(obj):
         if obj.is_index:
             if autogen_context.run_filters(
-                    obj.const, obj.name, "index", False, None):
-                modify_ops.ops.append(
-                    ops.CreateIndexOp.from_index(obj.const)
+                obj.const, obj.name, "index", False, None
+            ):
+                modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const))
+                log.info(
+                    "Detected added index '%s' on %s",
+                    obj.name,
+                    ", ".join(["'%s'" % obj.column_names]),
                 )
-                log.info("Detected added index '%s' on %s",
-                         obj.name, ', '.join([
-                             "'%s'" % obj.column_names
-                         ]))
         else:
             if not supports_unique_constraints:
                 # can't report unique indexes as added if we don't
@@ -518,15 +581,16 @@ def _compare_indexes_and_uniques(
                 # unique constraints are created inline with table defs
                 return
             if autogen_context.run_filters(
-                    obj.const, obj.name,
-                    "unique_constraint", False, None):
+                obj.const, obj.name, "unique_constraint", False, None
+            ):
                 modify_ops.ops.append(
                     ops.AddConstraintOp.from_constraint(obj.const)
                 )
-                log.info("Detected added unique constraint '%s' on %s",
-                         obj.name, ', '.join([
-                             "'%s'" % obj.column_names
-                         ]))
+                log.info(
+                    "Detected added unique constraint '%s' on %s",
+                    obj.name,
+                    ", ".join(["'%s'" % obj.column_names]),
+                )
 
     def obj_removed(obj):
         if obj.is_index:
@@ -537,48 +601,52 @@ def _compare_indexes_and_uniques(
                 return
 
             if autogen_context.run_filters(
-                    obj.const, obj.name, "index", True, None):
-                modify_ops.ops.append(
-                    ops.DropIndexOp.from_index(obj.const)
-                )
+                obj.const, obj.name, "index", True, None
+            ):
+                modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const))
                 log.info(
-                    "Detected removed index '%s' on '%s'", obj.name, tname)
+                    "Detected removed index '%s' on '%s'", obj.name, tname
+                )
         else:
             if is_create_table or is_drop_table:
                 # if the whole table is being dropped, we don't need to
                 # consider unique constraint separately
                 return
             if autogen_context.run_filters(
-                    obj.const, obj.name,
-                    "unique_constraint", True, None):
+                obj.const, obj.name, "unique_constraint", True, None
+            ):
                 modify_ops.ops.append(
                     ops.DropConstraintOp.from_constraint(obj.const)
                 )
-                log.info("Detected removed unique constraint '%s' on '%s'",
-                         obj.name, tname
-                         )
+                log.info(
+                    "Detected removed unique constraint '%s' on '%s'",
+                    obj.name,
+                    tname,
+                )
 
     def obj_changed(old, new, msg):
         if old.is_index:
             if autogen_context.run_filters(
-                    new.const, new.name, "index",
-                    False, old.const):
-                log.info("Detected changed index '%s' on '%s':%s",
-                         old.name, tname, ', '.join(msg)
-                         )
-                modify_ops.ops.append(
-                    ops.DropIndexOp.from_index(old.const)
-                )
-                modify_ops.ops.append(
-                    ops.CreateIndexOp.from_index(new.const)
+                new.const, new.name, "index", False, old.const
+            ):
+                log.info(
+                    "Detected changed index '%s' on '%s':%s",
+                    old.name,
+                    tname,
+                    ", ".join(msg),
                 )
+                modify_ops.ops.append(ops.DropIndexOp.from_index(old.const))
+                modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const))
         else:
             if autogen_context.run_filters(
-                    new.const, new.name,
-                    "unique_constraint", False, old.const):
-                log.info("Detected changed unique constraint '%s' on '%s':%s",
-                         old.name, tname, ', '.join(msg)
-                         )
+                new.const, new.name, "unique_constraint", False, old.const
+            ):
+                log.info(
+                    "Detected changed unique constraint '%s' on '%s':%s",
+                    old.name,
+                    tname,
+                    ", ".join(msg),
+                )
                 modify_ops.ops.append(
                     ops.DropConstraintOp.from_constraint(old.const)
                 )
@@ -608,13 +676,14 @@ def _compare_indexes_and_uniques(
         else:
             msg = []
             if conn_obj.is_unique != metadata_obj.is_unique:
-                msg.append(' unique=%r to unique=%r' % (
-                    conn_obj.is_unique, metadata_obj.is_unique
-                ))
+                msg.append(
+                    " unique=%r to unique=%r"
+                    % (conn_obj.is_unique, metadata_obj.is_unique)
+                )
             if conn_obj.sig != metadata_obj.sig:
-                msg.append(' columns %r to %r' % (
-                    conn_obj.sig, metadata_obj.sig
-                ))
+                msg.append(
+                    " columns %r to %r" % (conn_obj.sig, metadata_obj.sig)
+                )
 
             if msg:
                 obj_changed(conn_obj, metadata_obj, msg)
@@ -624,8 +693,10 @@ def _compare_indexes_and_uniques(
         if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
             continue
         elif removed_name in doubled_constraints:
-            if conn_obj.sig not in metadata_indexes_by_sig and \
-                    conn_obj.sig not in metadata_uniques_by_sig:
+            if (
+                conn_obj.sig not in metadata_indexes_by_sig
+                and conn_obj.sig not in metadata_uniques_by_sig
+            ):
                 conn_uq, conn_idx = doubled_constraints[removed_name]
                 obj_removed(conn_uq)
                 obj_removed(conn_idx)
@@ -639,40 +710,51 @@ def _compare_indexes_and_uniques(
 
 def _correct_for_uq_duplicates_uix(
     conn_unique_constraints,
-        conn_indexes,
-        metadata_unique_constraints,
-        metadata_indexes):
+    conn_indexes,
+    metadata_unique_constraints,
+    metadata_indexes,
+):
     # dedupe unique indexes vs. constraints, since MySQL / Oracle
     # doesn't really have unique constraints as a separate construct.
     # but look in the metadata and try to maintain constructs
     # that already seem to be defined one way or the other
     # on that side.  This logic was formerly local to MySQL dialect,
     # generalized to Oracle and others. See #276
-    metadata_uq_names = set([
-        cons.name for cons in metadata_unique_constraints
-        if cons.name is not None])
-
-    unnamed_metadata_uqs = set([
-        _uq_constraint_sig(cons).sig
-        for cons in metadata_unique_constraints
-        if cons.name is None
-    ])
-
-    metadata_ix_names = set([
-        cons.name for cons in metadata_indexes if cons.unique])
+    metadata_uq_names = set(
+        [
+            cons.name
+            for cons in metadata_unique_constraints
+            if cons.name is not None
+        ]
+    )
+
+    unnamed_metadata_uqs = set(
+        [
+            _uq_constraint_sig(cons).sig
+            for cons in metadata_unique_constraints
+            if cons.name is None
+        ]
+    )
+
+    metadata_ix_names = set(
+        [cons.name for cons in metadata_indexes if cons.unique]
+    )
     conn_ix_names = dict(
         (cons.name, cons) for cons in conn_indexes if cons.unique
     )
 
     uqs_dupe_indexes = dict(
-        (cons.name, cons) for cons in conn_unique_constraints
-        if cons.info['duplicates_index']
+        (cons.name, cons)
+        for cons in conn_unique_constraints
+        if cons.info["duplicates_index"]
     )
 
     for overlap in uqs_dupe_indexes:
         if overlap not in metadata_uq_names:
-            if _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig \
-                    not in unnamed_metadata_uqs:
+            if (
+                _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig
+                not in unnamed_metadata_uqs
+            ):
 
                 conn_unique_constraints.discard(uqs_dupe_indexes[overlap])
         elif overlap not in metadata_ix_names:
@@ -681,8 +763,14 @@ def _correct_for_uq_duplicates_uix(
 
 @comparators.dispatch_for("column")
 def _compare_nullable(
-    autogen_context, alter_column_op, schema, tname, cname, conn_col,
-        metadata_col):
+    autogen_context,
+    alter_column_op,
+    schema,
+    tname,
+    cname,
+    conn_col,
+    metadata_col,
+):
 
     # work around SQLAlchemy issue #3023
     if metadata_col.primary_key:
@@ -694,57 +782,83 @@ def _compare_nullable(
 
     if conn_col_nullable is not metadata_col_nullable:
         alter_column_op.modify_nullable = metadata_col_nullable
-        log.info("Detected %s on column '%s.%s'",
-                 "NULL" if metadata_col_nullable else "NOT NULL",
-                 tname,
-                 cname
-                 )
+        log.info(
+            "Detected %s on column '%s.%s'",
+            "NULL" if metadata_col_nullable else "NOT NULL",
+            tname,
+            cname,
+        )
 
 
 @comparators.dispatch_for("column")
 def _setup_autoincrement(
-    autogen_context, alter_column_op, schema, tname, cname, conn_col,
-        metadata_col):
+    autogen_context,
+    alter_column_op,
+    schema,
+    tname,
+    cname,
+    conn_col,
+    metadata_col,
+):
 
     if metadata_col.table._autoincrement_column is metadata_col:
-        alter_column_op.kw['autoincrement'] = True
+        alter_column_op.kw["autoincrement"] = True
     elif util.sqla_110 and metadata_col.autoincrement is True:
-        alter_column_op.kw['autoincrement'] = True
+        alter_column_op.kw["autoincrement"] = True
     elif metadata_col.autoincrement is False:
-        alter_column_op.kw['autoincrement'] = False
+        alter_column_op.kw["autoincrement"] = False
 
 
 @comparators.dispatch_for("column")
 def _compare_type(
-    autogen_context, alter_column_op, schema, tname, cname, conn_col,
-        metadata_col):
+    autogen_context,
+    alter_column_op,
+    schema,
+    tname,
+    cname,
+    conn_col,
+    metadata_col,
+):
 
     conn_type = conn_col.type
     alter_column_op.existing_type = conn_type
     metadata_type = metadata_col.type
     if conn_type._type_affinity is sqltypes.NullType:
-        log.info("Couldn't determine database type "
-                 "for column '%s.%s'", tname, cname)
+        log.info(
+            "Couldn't determine database type " "for column '%s.%s'",
+            tname,
+            cname,
+        )
         return
     if metadata_type._type_affinity is sqltypes.NullType:
-        log.info("Column '%s.%s' has no type within "
-                 "the model; can't compare", tname, cname)
+        log.info(
+            "Column '%s.%s' has no type within " "the model; can't compare",
+            tname,
+            cname,
+        )
         return
 
     isdiff = autogen_context.migration_context._compare_type(
-        conn_col, metadata_col)
+        conn_col, metadata_col
+    )
 
     if isdiff:
         alter_column_op.modify_type = metadata_type
-        log.info("Detected type change from %r to %r on '%s.%s'",
-                 conn_type, metadata_type, tname, cname
-                 )
+        log.info(
+            "Detected type change from %r to %r on '%s.%s'",
+            conn_type,
+            metadata_type,
+            tname,
+            cname,
+        )
 
 
-def _render_server_default_for_compare(metadata_default,
-                                       metadata_col, autogen_context):
+def _render_server_default_for_compare(
+    metadata_default, metadata_col, autogen_context
+):
     rendered = _user_defined_render(
-        "server_default", metadata_default, autogen_context)
+        "server_default", metadata_default, autogen_context
+    )
     if rendered is not False:
         return rendered
 
@@ -752,8 +866,9 @@ def _render_server_default_for_compare(metadata_default,
         if isinstance(metadata_default.arg, compat.string_types):
             metadata_default = metadata_default.arg
         else:
-            metadata_default = str(metadata_default.arg.compile(
-                dialect=autogen_context.dialect))
+            metadata_default = str(
+                metadata_default.arg.compile(dialect=autogen_context.dialect)
+            )
     if isinstance(metadata_default, compat.string_types):
         if metadata_col.type._type_affinity is sqltypes.String:
             metadata_default = re.sub(r"^'|'$", "", metadata_default)
@@ -766,37 +881,49 @@ def _render_server_default_for_compare(metadata_default,
 
 @comparators.dispatch_for("column")
 def _compare_server_default(
-    autogen_context, alter_column_op, schema, tname, cname,
-        conn_col, metadata_col):
+    autogen_context,
+    alter_column_op,
+    schema,
+    tname,
+    cname,
+    conn_col,
+    metadata_col,
+):
 
     metadata_default = metadata_col.server_default
     conn_col_default = conn_col.server_default
     if conn_col_default is None and metadata_default is None:
         return False
     rendered_metadata_default = _render_server_default_for_compare(
-        metadata_default, metadata_col, autogen_context)
+        metadata_default, metadata_col, autogen_context
+    )
 
-    rendered_conn_default = conn_col.server_default.arg.text \
-        if conn_col.server_default else None
+    rendered_conn_default = (
+        conn_col.server_default.arg.text if conn_col.server_default else None
+    )
 
     alter_column_op.existing_server_default = conn_col_default
 
     isdiff = autogen_context.migration_context._compare_server_default(
-        conn_col, metadata_col,
+        conn_col,
+        metadata_col,
         rendered_metadata_default,
-        rendered_conn_default
+        rendered_conn_default,
     )
     if isdiff:
         alter_column_op.modify_server_default = metadata_default
-        log.info(
-            "Detected server default on column '%s.%s'",
-            tname, cname)
+        log.info("Detected server default on column '%s.%s'", tname, cname)
 
 
 @comparators.dispatch_for("table")
 def _compare_foreign_keys(
-    autogen_context, modify_table_ops, schema, tname, conn_table,
-        metadata_table):
+    autogen_context,
+    modify_table_ops,
+    schema,
+    tname,
+    conn_table,
+    metadata_table,
+):
 
     # if we're doing CREATE TABLE, all FKs are created
     # inline within the table def
@@ -805,22 +932,22 @@ def _compare_foreign_keys(
 
     inspector = autogen_context.inspector
     metadata_fks = set(
-        fk for fk in metadata_table.constraints
+        fk
+        for fk in metadata_table.constraints
         if isinstance(fk, sa_schema.ForeignKeyConstraint)
     )
 
     conn_fks = inspector.get_foreign_keys(tname, schema=schema)
 
-    backend_reflects_fk_options = conn_fks and 'options' in conn_fks[0]
+    backend_reflects_fk_options = conn_fks and "options" in conn_fks[0]
 
     conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks)
 
     # give the dialect a chance to correct the FKs to match more
     # closely
-    autogen_context.migration_context.impl.\
-        correct_for_autogen_foreignkeys(
-            conn_fks, metadata_fks,
-        )
+    autogen_context.migration_context.impl.correct_for_autogen_foreignkeys(
+        conn_fks, metadata_fks
+    )
 
     metadata_fks = set(
         _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
@@ -832,12 +959,8 @@ def _compare_foreign_keys(
         for fk in conn_fks
     )
 
-    conn_fks_by_sig = dict(
-        (c.sig, c) for c in conn_fks
-    )
-    metadata_fks_by_sig = dict(
-        (c.sig, c) for c in metadata_fks
-    )
+    conn_fks_by_sig = dict((c.sig, c) for c in conn_fks)
+    metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks)
 
     metadata_fks_by_name = dict(
         (c.name, c) for c in metadata_fks if c.name is not None
@@ -848,8 +971,8 @@ def _compare_foreign_keys(
 
     def _add_fk(obj, compare_to):
         if autogen_context.run_filters(
-                obj.const, obj.name, "foreign_key_constraint", False,
-                compare_to):
+            obj.const, obj.name, "foreign_key_constraint", False, compare_to
+        ):
             modify_table_ops.ops.append(
                 ops.CreateForeignKeyOp.from_constraint(const.const)
             )
@@ -859,12 +982,13 @@ def _compare_foreign_keys(
                 ", ".join(obj.source_columns),
                 ", ".join(obj.target_columns),
                 "%s." % obj.source_schema if obj.source_schema else "",
-                obj.source_table)
+                obj.source_table,
+            )
 
     def _remove_fk(obj, compare_to):
         if autogen_context.run_filters(
-                obj.const, obj.name, "foreign_key_constraint", True,
-                compare_to):
+            obj.const, obj.name, "foreign_key_constraint", True, compare_to
+        ):
             modify_table_ops.ops.append(
                 ops.DropConstraintOp.from_constraint(obj.const)
             )
@@ -873,7 +997,8 @@ def _compare_foreign_keys(
                 ", ".join(obj.source_columns),
                 ", ".join(obj.target_columns),
                 "%s." % obj.source_schema if obj.source_schema else "",
-                obj.source_table)
+                obj.source_table,
+            )
 
     # so far it appears we don't need to do this by name at all.
     # SQLite doesn't preserve constraint names anyway
@@ -881,13 +1006,19 @@ def _compare_foreign_keys(
     for removed_sig in set(conn_fks_by_sig).difference(metadata_fks_by_sig):
         const = conn_fks_by_sig[removed_sig]
         if removed_sig not in metadata_fks_by_sig:
-            compare_to = metadata_fks_by_name[const.name].const \
-                if const.name in metadata_fks_by_name else None
+            compare_to = (
+                metadata_fks_by_name[const.name].const
+                if const.name in metadata_fks_by_name
+                else None
+            )
             _remove_fk(const, compare_to)
 
     for added_sig in set(metadata_fks_by_sig).difference(conn_fks_by_sig):
         const = metadata_fks_by_sig[added_sig]
         if added_sig not in conn_fks_by_sig:
-            compare_to = conn_fks_by_name[const.name].const \
-                if const.name in conn_fks_by_name else None
+            compare_to = (
+                conn_fks_by_name[const.name].const
+                if const.name in conn_fks_by_name
+                else None
+            )
             _add_fk(const, compare_to)
index 4fbe91fd07d539b7ca43c7f8d5289f2d9c500562..573ee02c4b7d1dc9fd62a9b45fa8fb565109c106 100644 (file)
@@ -19,29 +19,35 @@ try:
             return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
         else:
             return name
+
+
 except ImportError:
+
     def _render_gen_name(autogen_context, name):
         return name
 
 
 def _indent(text):
-    text = re.compile(r'^', re.M).sub("    ", text).strip()
-    text = re.compile(r' +$', re.M).sub("", text)
+    text = re.compile(r"^", re.M).sub("    ", text).strip()
+    text = re.compile(r" +$", re.M).sub("", text)
     return text
 
 
 def _render_python_into_templatevars(
-        autogen_context, migration_script, template_args):
+    autogen_context, migration_script, template_args
+):
     imports = autogen_context.imports
 
     for upgrade_ops, downgrade_ops in zip(
-            migration_script.upgrade_ops_list,
-            migration_script.downgrade_ops_list):
+        migration_script.upgrade_ops_list, migration_script.downgrade_ops_list
+    ):
         template_args[upgrade_ops.upgrade_token] = _indent(
-            _render_cmd_body(upgrade_ops, autogen_context))
+            _render_cmd_body(upgrade_ops, autogen_context)
+        )
         template_args[downgrade_ops.downgrade_token] = _indent(
-            _render_cmd_body(downgrade_ops, autogen_context))
-    template_args['imports'] = "\n".join(sorted(imports))
+            _render_cmd_body(downgrade_ops, autogen_context)
+        )
+    template_args["imports"] = "\n".join(sorted(imports))
 
 
 default_renderers = renderers = util.Dispatcher()
@@ -83,7 +89,7 @@ def render_op_text(autogen_context, op):
 @renderers.dispatch_for(ops.ModifyTableOps)
 def _render_modify_table(autogen_context, op):
     opts = autogen_context.opts
-    render_as_batch = opts.get('render_as_batch', False)
+    render_as_batch = opts.get("render_as_batch", False)
 
     if op.ops:
         lines = []
@@ -104,33 +110,39 @@ def _render_modify_table(autogen_context, op):
 
         return lines
     else:
-        return [
-            "pass"
-        ]
+        return ["pass"]
 
 
 @renderers.dispatch_for(ops.CreateTableOp)
 def _add_table(autogen_context, op):
     table = op.to_table()
 
-    args = [col for col in
-            [_render_column(col, autogen_context) for col in table.columns]
-            if col] + \
-        sorted([rcons for rcons in
-                [_render_constraint(cons, autogen_context) for cons in
-                 table.constraints]
-                if rcons is not None
-                ])
+    args = [
+        col
+        for col in [
+            _render_column(col, autogen_context) for col in table.columns
+        ]
+        if col
+    ] + sorted(
+        [
+            rcons
+            for rcons in [
+                _render_constraint(cons, autogen_context)
+                for cons in table.constraints
+            ]
+            if rcons is not None
+        ]
+    )
 
     if len(args) > MAX_PYTHON_ARGS:
-        args = '*[' + ',\n'.join(args) + ']'
+        args = "*[" + ",\n".join(args) + "]"
     else:
-        args = ',\n'.join(args)
+        args = ",\n".join(args)
 
     text = "%(prefix)screate_table(%(tablename)r,\n%(args)s" % {
-        'tablename': _ident(op.table_name),
-        'prefix': _alembic_autogenerate_prefix(autogen_context),
-        'args': args,
+        "tablename": _ident(op.table_name),
+        "prefix": _alembic_autogenerate_prefix(autogen_context),
+        "args": args,
     }
     if op.schema:
         text += ",\nschema=%r" % _ident(op.schema)
@@ -144,7 +156,7 @@ def _add_table(autogen_context, op):
 def _drop_table(autogen_context, op):
     text = "%(prefix)sdrop_table(%(tname)r" % {
         "prefix": _alembic_autogenerate_prefix(autogen_context),
-        "tname": _ident(op.table_name)
+        "tname": _ident(op.table_name),
     }
     if op.schema:
         text += ", schema=%r" % _ident(op.schema)
@@ -159,28 +171,39 @@ def _add_index(autogen_context, op):
     has_batch = autogen_context._has_batch
 
     if has_batch:
-        tmpl = "%(prefix)screate_index(%(name)r, [%(columns)s], "\
+        tmpl = (
+            "%(prefix)screate_index(%(name)r, [%(columns)s], "
             "unique=%(unique)r%(kwargs)s)"
+        )
     else:
-        tmpl = "%(prefix)screate_index(%(name)r, %(table)r, [%(columns)s], "\
+        tmpl = (
+            "%(prefix)screate_index(%(name)r, %(table)r, [%(columns)s], "
             "unique=%(unique)r%(schema)s%(kwargs)s)"
+        )
 
     text = tmpl % {
-        'prefix': _alembic_autogenerate_prefix(autogen_context),
-        'name': _render_gen_name(autogen_context, index.name),
-        'table': _ident(index.table.name),
-        'columns': ", ".join(
-            _get_index_rendered_expressions(index, autogen_context)),
-        'unique': index.unique or False,
-        'schema': (", schema=%r" % _ident(index.table.schema))
-        if index.table.schema else '',
-        'kwargs': (
-            ', ' +
-            ', '.join(
-                ["%s=%s" %
-                 (key, _render_potential_expr(val, autogen_context))
-                 for key, val in index.kwargs.items()]))
-        if len(index.kwargs) else ''
+        "prefix": _alembic_autogenerate_prefix(autogen_context),
+        "name": _render_gen_name(autogen_context, index.name),
+        "table": _ident(index.table.name),
+        "columns": ", ".join(
+            _get_index_rendered_expressions(index, autogen_context)
+        ),
+        "unique": index.unique or False,
+        "schema": (", schema=%r" % _ident(index.table.schema))
+        if index.table.schema
+        else "",
+        "kwargs": (
+            ", "
+            + ", ".join(
+                [
+                    "%s=%s"
+                    % (key, _render_potential_expr(val, autogen_context))
+                    for key, val in index.kwargs.items()
+                ]
+            )
+        )
+        if len(index.kwargs)
+        else "",
     }
     return text
 
@@ -192,15 +215,16 @@ def _drop_index(autogen_context, op):
     if has_batch:
         tmpl = "%(prefix)sdrop_index(%(name)r)"
     else:
-        tmpl = "%(prefix)sdrop_index(%(name)r, "\
+        tmpl = (
+            "%(prefix)sdrop_index(%(name)r, "
             "table_name=%(table_name)r%(schema)s)"
+        )
 
     text = tmpl % {
-        'prefix': _alembic_autogenerate_prefix(autogen_context),
-        'name': _render_gen_name(autogen_context, op.index_name),
-        'table_name': _ident(op.table_name),
-        'schema': ((", schema=%r" % _ident(op.schema))
-                   if op.schema else '')
+        "prefix": _alembic_autogenerate_prefix(autogen_context),
+        "name": _render_gen_name(autogen_context, op.index_name),
+        "table_name": _ident(op.table_name),
+        "schema": ((", schema=%r" % _ident(op.schema)) if op.schema else ""),
     }
     return text
 
@@ -213,30 +237,28 @@ def _add_unique_constraint(autogen_context, op):
 @renderers.dispatch_for(ops.CreateForeignKeyOp)
 def _add_fk_constraint(autogen_context, op):
 
-    args = [
-        repr(
-            _render_gen_name(autogen_context, op.constraint_name)),
-    ]
+    args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
     if not autogen_context._has_batch:
-        args.append(
-            repr(_ident(op.source_table))
-        )
+        args.append(repr(_ident(op.source_table)))
 
     args.extend(
         [
             repr(_ident(op.referent_table)),
             repr([_ident(col) for col in op.local_cols]),
-            repr([_ident(col) for col in op.remote_cols])
+            repr([_ident(col) for col in op.remote_cols]),
         ]
     )
 
     kwargs = [
-        'referent_schema',
-        'onupdate', 'ondelete', 'initially',
-        'deferrable', 'use_alter'
+        "referent_schema",
+        "onupdate",
+        "ondelete",
+        "initially",
+        "deferrable",
+        "use_alter",
     ]
     if not autogen_context._has_batch:
-        kwargs.insert(0, 'source_schema')
+        kwargs.insert(0, "source_schema")
 
     for k in kwargs:
         if k in op.kw:
@@ -245,8 +267,8 @@ def _add_fk_constraint(autogen_context, op):
                 args.append("%s=%r" % (k, value))
 
     return "%(prefix)screate_foreign_key(%(args)s)" % {
-        'prefix': _alembic_autogenerate_prefix(autogen_context),
-        'args': ", ".join(args)
+        "prefix": _alembic_autogenerate_prefix(autogen_context),
+        "args": ", ".join(args),
     }
 
 
@@ -264,20 +286,19 @@ def _add_check_constraint(constraint, autogen_context):
 def _drop_constraint(autogen_context, op):
 
     if autogen_context._has_batch:
-        template = "%(prefix)sdrop_constraint"\
-            "(%(name)r, type_=%(type)r)"
+        template = "%(prefix)sdrop_constraint" "(%(name)r, type_=%(type)r)"
     else:
-        template = "%(prefix)sdrop_constraint"\
+        template = (
+            "%(prefix)sdrop_constraint"
             "(%(name)r, '%(table_name)s'%(schema)s, type_=%(type)r)"
+        )
 
     text = template % {
-        'prefix': _alembic_autogenerate_prefix(autogen_context),
-        'name': _render_gen_name(
-            autogen_context, op.constraint_name),
-        'table_name': _ident(op.table_name),
-        'type': op.constraint_type,
-        'schema': (", schema=%r" % _ident(op.schema))
-        if op.schema else '',
+        "prefix": _alembic_autogenerate_prefix(autogen_context),
+        "name": _render_gen_name(autogen_context, op.constraint_name),
+        "table_name": _ident(op.table_name),
+        "type": op.constraint_type,
+        "schema": (", schema=%r" % _ident(op.schema)) if op.schema else "",
     }
     return text
 
@@ -297,7 +318,7 @@ def _add_column(autogen_context, op):
         "prefix": _alembic_autogenerate_prefix(autogen_context),
         "tname": tname,
         "column": _render_column(column, autogen_context),
-        "schema": schema
+        "schema": schema,
     }
     return text
 
@@ -319,7 +340,7 @@ def _drop_column(autogen_context, op):
         "prefix": _alembic_autogenerate_prefix(autogen_context),
         "tname": _ident(tname),
         "cname": _ident(column_name),
-        "schema": _ident(schema)
+        "schema": _ident(schema),
     }
     return text
 
@@ -332,7 +353,7 @@ def _alter_column(autogen_context, op):
     server_default = op.modify_server_default
     type_ = op.modify_type
     nullable = op.modify_nullable
-    autoincrement = op.kw.get('autoincrement', None)
+    autoincrement = op.kw.get("autoincrement", None)
     existing_type = op.existing_type
     existing_nullable = op.existing_nullable
     existing_server_default = op.existing_server_default
@@ -346,37 +367,32 @@ def _alter_column(autogen_context, op):
         template = "%(prefix)salter_column(%(tname)r, %(cname)r"
 
     text = template % {
-        'prefix': _alembic_autogenerate_prefix(
-            autogen_context),
-        'tname': tname,
-        'cname': cname}
+        "prefix": _alembic_autogenerate_prefix(autogen_context),
+        "tname": tname,
+        "cname": cname,
+    }
     if existing_type is not None:
         text += ",\n%sexisting_type=%s" % (
             indent,
-            _repr_type(existing_type, autogen_context))
+            _repr_type(existing_type, autogen_context),
+        )
     if server_default is not False:
-        rendered = _render_server_default(
-            server_default, autogen_context)
+        rendered = _render_server_default(server_default, autogen_context)
         text += ",\n%sserver_default=%s" % (indent, rendered)
 
     if type_ is not None:
-        text += ",\n%stype_=%s" % (indent,
-                                   _repr_type(type_, autogen_context))
+        text += ",\n%stype_=%s" % (indent, _repr_type(type_, autogen_context))
     if nullable is not None:
-        text += ",\n%snullable=%r" % (
-            indent, nullable,)
+        text += ",\n%snullable=%r" % (indent, nullable)
     if nullable is None and existing_nullable is not None:
-        text += ",\n%sexisting_nullable=%r" % (
-            indent, existing_nullable)
+        text += ",\n%sexisting_nullable=%r" % (indent, existing_nullable)
     if autoincrement is not None:
-        text += ",\n%sautoincrement=%r" % (
-            indent, autoincrement)
+        text += ",\n%sautoincrement=%r" % (indent, autoincrement)
     if server_default is False and existing_server_default:
         rendered = _render_server_default(
-            existing_server_default,
-            autogen_context)
-        text += ",\n%sexisting_server_default=%s" % (
-            indent, rendered)
+            existing_server_default, autogen_context
+        )
+        text += ",\n%sexisting_server_default=%s" % (indent, rendered)
     if schema and not autogen_context._has_batch:
         text += ",\n%sschema=%r" % (indent, schema)
     text += ")"
@@ -384,7 +400,6 @@ def _alter_column(autogen_context, op):
 
 
 class _f_name(object):
-
     def __init__(self, prefix, name):
         self.prefix = prefix
         self.name = name
@@ -410,7 +425,7 @@ def _ident(name):
             # u'' literals only when py2k + SQLA 0.9, in particular
             # makes unit tests testing code generation very difficult
             try:
-                return name.encode('ascii')
+                return name.encode("ascii")
             except UnicodeError:
                 return compat.text_type(name)
         else:
@@ -421,8 +436,9 @@ def _ident(name):
 
 def _render_potential_expr(value, autogen_context, wrap_in_text=True):
     if isinstance(value, sql.ClauseElement):
-        compile_kw = dict(compile_kwargs={
-            'literal_binds': True, "include_table": False})
+        compile_kw = dict(
+            compile_kwargs={"literal_binds": True, "include_table": False}
+        )
 
         if wrap_in_text:
             template = "%(prefix)stext(%(sql)r)"
@@ -432,9 +448,8 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True):
         return template % {
             "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
             "sql": compat.text_type(
-                value.compile(dialect=autogen_context.dialect,
-                              **compile_kw)
-            )
+                value.compile(dialect=autogen_context.dialect, **compile_kw)
+            ),
         }
 
     else:
@@ -442,10 +457,12 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True):
 
 
 def _get_index_rendered_expressions(idx, autogen_context):
-    return [repr(_ident(getattr(exp, "name", None)))
-            if isinstance(exp, sa_schema.Column)
-            else _render_potential_expr(exp, autogen_context)
-            for exp in idx.expressions]
+    return [
+        repr(_ident(getattr(exp, "name", None)))
+        if isinstance(exp, sa_schema.Column)
+        else _render_potential_expr(exp, autogen_context)
+        for exp in idx.expressions
+    ]
 
 
 def _uq_constraint(constraint, autogen_context, alter):
@@ -461,32 +478,30 @@ def _uq_constraint(constraint, autogen_context, alter):
         opts.append(("schema", _ident(constraint.table.schema)))
     if not alter and constraint.name:
         opts.append(
-            ("name",
-             _render_gen_name(autogen_context, constraint.name)))
+            ("name", _render_gen_name(autogen_context, constraint.name))
+        )
 
     if alter:
-        args = [
-            repr(_render_gen_name(
-                autogen_context, constraint.name))]
+        args = [repr(_render_gen_name(autogen_context, constraint.name))]
         if not has_batch:
             args += [repr(_ident(constraint.table.name))]
         args.append(repr([_ident(col.name) for col in constraint.columns]))
         args.extend(["%s=%r" % (k, v) for k, v in opts])
         return "%(prefix)screate_unique_constraint(%(args)s)" % {
-            'prefix': _alembic_autogenerate_prefix(autogen_context),
-            'args': ", ".join(args)
+            "prefix": _alembic_autogenerate_prefix(autogen_context),
+            "args": ", ".join(args),
         }
     else:
         args = [repr(_ident(col.name)) for col in constraint.columns]
         args.extend(["%s=%r" % (k, v) for k, v in opts])
         return "%(prefix)sUniqueConstraint(%(args)s)" % {
             "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
-            "args": ", ".join(args)
+            "args": ", ".join(args),
         }
 
 
 def _user_autogenerate_prefix(autogen_context, target):
-    prefix = autogen_context.opts['user_module_prefix']
+    prefix = autogen_context.opts["user_module_prefix"]
     if prefix is None:
         return "%s." % target.__module__
     else:
@@ -494,19 +509,19 @@ def _user_autogenerate_prefix(autogen_context, target):
 
 
 def _sqlalchemy_autogenerate_prefix(autogen_context):
-    return autogen_context.opts['sqlalchemy_module_prefix'] or ''
+    return autogen_context.opts["sqlalchemy_module_prefix"] or ""
 
 
 def _alembic_autogenerate_prefix(autogen_context):
     if autogen_context._has_batch:
-        return 'batch_op.'
+        return "batch_op."
     else:
-        return autogen_context.opts['alembic_module_prefix'] or ''
+        return autogen_context.opts["alembic_module_prefix"] or ""
 
 
 def _user_defined_render(type_, object_, autogen_context):
-    if 'render_item' in autogen_context.opts:
-        render = autogen_context.opts['render_item']
+    if "render_item" in autogen_context.opts:
+        render = autogen_context.opts["render_item"]
         if render:
             rendered = render(type_, object_, autogen_context)
             if rendered is not False:
@@ -527,8 +542,10 @@ def _render_column(column, autogen_context):
         if rendered:
             opts.append(("server_default", rendered))
 
-    if column.autoincrement is not None and \
-            column.autoincrement != sqla_compat.AUTOINCREMENT_DEFAULT:
+    if (
+        column.autoincrement is not None
+        and column.autoincrement != sqla_compat.AUTOINCREMENT_DEFAULT
+    ):
         opts.append(("autoincrement", column.autoincrement))
 
     if column.nullable is not None:
@@ -539,10 +556,10 @@ def _render_column(column, autogen_context):
 
     # TODO: for non-ascii colname, assign a "key"
     return "%(prefix)sColumn(%(name)r, %(type)s, %(kw)s)" % {
-        'prefix': _sqlalchemy_autogenerate_prefix(autogen_context),
-        'name': _ident(column.name),
-        'type': _repr_type(column.type, autogen_context),
-        'kw': ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts])
+        "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
+        "name": _ident(column.name),
+        "type": _repr_type(column.type, autogen_context),
+        "kw": ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts]),
     }
 
 
@@ -568,9 +585,10 @@ def _repr_type(type_, autogen_context):
     if rendered is not False:
         return rendered
 
-    if hasattr(autogen_context.migration_context, 'impl'):
+    if hasattr(autogen_context.migration_context, "impl"):
         impl_rt = autogen_context.migration_context.impl.render_type(
-            type_, autogen_context)
+            type_, autogen_context
+        )
     else:
         impl_rt = None
 
@@ -587,8 +605,8 @@ def _repr_type(type_, autogen_context):
     elif impl_rt:
         return impl_rt
     elif mod.startswith("sqlalchemy."):
-        if '_render_%s_type' % type_.__visit_name__ in globals():
-            fn = globals()['_render_%s_type' % type_.__visit_name__]
+        if "_render_%s_type" % type_.__visit_name__ in globals():
+            fn = globals()["_render_%s_type" % type_.__visit_name__]
             return fn(type_, autogen_context)
         else:
             prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
@@ -600,12 +618,13 @@ def _repr_type(type_, autogen_context):
 
 def _render_ARRAY_type(type_, autogen_context):
     return _render_type_w_subtype(
-        type_, autogen_context, 'item_type', r'(.+?\()'
+        type_, autogen_context, "item_type", r"(.+?\()"
     )
 
 
 def _render_type_w_subtype(
-        type_, autogen_context, attrname, regexp, prefix=None):
+    type_, autogen_context, attrname, regexp, prefix=None
+):
     outer_repr = repr(type_)
     inner_type = getattr(type_, attrname, None)
     if inner_type is None:
@@ -613,11 +632,9 @@ def _render_type_w_subtype(
 
     inner_repr = repr(inner_type)
 
-    inner_repr = re.sub(r'([\(\)])', r'\\\1', inner_repr)
+    inner_repr = re.sub(r"([\(\)])", r"\\\1", inner_repr)
     sub_type = _repr_type(getattr(type_, attrname), autogen_context)
-    outer_type = re.sub(
-        regexp + inner_repr,
-        r"\1%s" % sub_type, outer_repr)
+    outer_type = re.sub(regexp + inner_repr, r"\1%s" % sub_type, outer_repr)
 
     if prefix:
         return "%s%s" % (prefix, outer_type)
@@ -632,6 +649,7 @@ def _render_type_w_subtype(
     else:
         return None
 
+
 _constraint_renderers = util.Dispatcher()
 
 
@@ -656,13 +674,14 @@ def _render_primary_key(constraint, autogen_context):
 
     opts = []
     if constraint.name:
-        opts.append(("name", repr(
-            _render_gen_name(autogen_context, constraint.name))))
+        opts.append(
+            ("name", repr(_render_gen_name(autogen_context, constraint.name)))
+        )
     return "%(prefix)sPrimaryKeyConstraint(%(args)s)" % {
         "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
         "args": ", ".join(
-            [repr(c.name) for c in constraint.columns] +
-            ["%s=%s" % (kwname, val) for kwname, val in opts]
+            [repr(c.name) for c in constraint.columns]
+            ["%s=%s" % (kwname, val) for kwname, val in opts]
         ),
     }
 
@@ -681,8 +700,11 @@ def _fk_colspec(fk, metadata_schema):
     else:
         table_fullname = ".".join(tokens[0:-1])
 
-    if not fk.link_to_name and \
-            fk.parent is not None and fk.parent.table is not None:
+    if (
+        not fk.link_to_name
+        and fk.parent is not None
+        and fk.parent.table is not None
+    ):
         # try to resolve the remote table in order to adjust for column.key.
         # the FK constraint needs to be rendered in terms of the column
         # name.
@@ -719,23 +741,30 @@ def _render_foreign_key(constraint, autogen_context):
 
     opts = []
     if constraint.name:
-        opts.append(("name", repr(
-            _render_gen_name(autogen_context, constraint.name))))
+        opts.append(
+            ("name", repr(_render_gen_name(autogen_context, constraint.name)))
+        )
 
     _populate_render_fk_opts(constraint, opts)
 
     apply_metadata_schema = constraint.parent.metadata.schema
-    return "%(prefix)sForeignKeyConstraint([%(cols)s], "\
-        "[%(refcols)s], %(args)s)" % {
+    return (
+        "%(prefix)sForeignKeyConstraint([%(cols)s], "
+        "[%(refcols)s], %(args)s)"
+        % {
             "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
             "cols": ", ".join(
-                "%r" % _ident(f.parent.name) for f in constraint.elements),
-            "refcols": ", ".join(repr(_fk_colspec(f, apply_metadata_schema))
-                                 for f in constraint.elements),
+                "%r" % _ident(f.parent.name) for f in constraint.elements
+            ),
+            "refcols": ", ".join(
+                repr(_fk_colspec(f, apply_metadata_schema))
+                for f in constraint.elements
+            ),
             "args": ", ".join(
-                    ["%s=%s" % (kwname, val) for kwname, val in opts]
+                ["%s=%s" % (kwname, val) for kwname, val in opts]
             ),
         }
+    )
 
 
 @_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint)
@@ -757,27 +786,25 @@ def _render_check_constraint(constraint, autogen_context):
     # a parent type which is probably in the Table already.
     # ideally SQLAlchemy would give us more of a first class
     # way to detect this.
-    if constraint._create_rule and \
-        hasattr(constraint._create_rule, 'target') and \
-        isinstance(constraint._create_rule.target,
-                   sqltypes.TypeEngine):
+    if (
+        constraint._create_rule
+        and hasattr(constraint._create_rule, "target")
+        and isinstance(constraint._create_rule.target, sqltypes.TypeEngine)
+    ):
         return None
     opts = []
     if constraint.name:
         opts.append(
-            (
-                "name",
-                repr(
-                    _render_gen_name(
-                        autogen_context, constraint.name))
-            )
+            ("name", repr(_render_gen_name(autogen_context, constraint.name)))
         )
     return "%(prefix)sCheckConstraint(%(sqltext)s%(opts)s)" % {
         "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
-        "opts": ", " + (", ".join("%s=%s" % (k, v)
-                                  for k, v in opts)) if opts else "",
+        "opts": ", " + (", ".join("%s=%s" % (k, v) for k, v in opts))
+        if opts
+        else "",
         "sqltext": _render_potential_expr(
-            constraint.sqltext, autogen_context, wrap_in_text=False)
+            constraint.sqltext, autogen_context, wrap_in_text=False
+        ),
     }
 
 
@@ -788,7 +815,7 @@ def _execute_sql(autogen_context, op):
             "Autogenerate rendering of SQL Expression language constructs "
             "not supported here; please use a plain SQL string"
         )
-    return 'op.execute(%r)' % op.sqltext
+    return "op.execute(%r)" % op.sqltext
 
 
 renderers = default_renderers.branch()
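
The _user_defined_render function in this file consults the render_item
callable that users may pass to EnvironmentContext.configure(); returning
False falls through to the default renderers dispatched above. A hedged
sketch of such a hook, where MySpecialType and mypkg stand in for a user's
custom type and module:

    class MySpecialType(object):
        pass  # stand-in for a user-defined SQLAlchemy type

    def render_item(type_, obj, autogen_context):
        # custom autogenerate rendering; False means "use the default"
        if type_ == "type" and isinstance(obj, MySpecialType):
            autogen_context.imports.add("import mypkg")  # hypothetical module
            return "mypkg.MySpecialType()"
        return False
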
index 941bd4b34c321b677b7193c6c1b03aeb230a7207..1e9522bd401975abe8927072fec7f5e24540ad34 100644 (file)
@@ -95,7 +95,8 @@ class Rewriter(object):
             yield directive
         else:
             for r_directive in util.to_list(
-                    _rewriter(context, revision, directive)):
+                _rewriter(context, revision, directive)
+            ):
                 yield r_directive
 
     def __call__(self, context, revision, directives):
@@ -110,17 +111,20 @@ class Rewriter(object):
             ret = self._traverse_for(context, revision, directive.upgrade_ops)
             if len(ret) != 1:
                 raise ValueError(
-                    "Can only return single object for UpgradeOps traverse")
+                    "Can only return single object for UpgradeOps traverse"
+                )
             upgrade_ops_list.append(ret[0])
         directive.upgrade_ops = upgrade_ops_list
 
         downgrade_ops_list = []
         for downgrade_ops in directive.downgrade_ops_list:
             ret = self._traverse_for(
-                context, revision, directive.downgrade_ops)
+                context, revision, directive.downgrade_ops
+            )
             if len(ret) != 1:
                 raise ValueError(
-                    "Can only return single object for DowngradeOps traverse")
+                    "Can only return single object for DowngradeOps traverse"
+                )
             downgrade_ops_list.append(ret[0])
         directive.downgrade_ops = downgrade_ops_list
 
index cd61fd1314a83146238988b9e52a24a9242f3e77..20027b40fd0663b8922b98e8bfb573f1025cc95a 100644 (file)
@@ -15,10 +15,9 @@ def list_templates(config):
 
     config.print_stdout("Available templates:\n")
     for tempname in os.listdir(config.get_template_directory()):
-        with open(os.path.join(
-                config.get_template_directory(),
-                tempname,
-                'README')) as readme:
+        with open(
+            os.path.join(config.get_template_directory(), tempname, "README")
+        ) as readme:
             synopsis = next(readme)
         config.print_stdout("%s - %s", tempname, synopsis)
 
@@ -26,7 +25,7 @@ def list_templates(config):
     config.print_stdout("\n  alembic init --template generic ./scripts")
 
 
-def init(config, directory, template='generic'):
+def init(config, directory, template="generic"):
     """Initialize a new scripts directory.
 
     :param config: a :class:`.Config` object.
@@ -41,48 +40,58 @@ def init(config, directory, template='generic'):
     if os.access(directory, os.F_OK):
         raise util.CommandError("Directory %s already exists" % directory)
 
-    template_dir = os.path.join(config.get_template_directory(),
-                                template)
+    template_dir = os.path.join(config.get_template_directory(), template)
     if not os.access(template_dir, os.F_OK):
         raise util.CommandError("No such template %r" % template)
 
-    util.status("Creating directory %s" % os.path.abspath(directory),
-                os.makedirs, directory)
+    util.status(
+        "Creating directory %s" % os.path.abspath(directory),
+        os.makedirs,
+        directory,
+    )
 
-    versions = os.path.join(directory, 'versions')
-    util.status("Creating directory %s" % os.path.abspath(versions),
-                os.makedirs, versions)
+    versions = os.path.join(directory, "versions")
+    util.status(
+        "Creating directory %s" % os.path.abspath(versions),
+        os.makedirs,
+        versions,
+    )
 
     script = ScriptDirectory(directory)
 
     for file_ in os.listdir(template_dir):
         file_path = os.path.join(template_dir, file_)
-        if file_ == 'alembic.ini.mako':
+        if file_ == "alembic.ini.mako":
             config_file = os.path.abspath(config.config_file_name)
             if os.access(config_file, os.F_OK):
                 util.msg("File %s already exists, skipping" % config_file)
             else:
                 script._generate_template(
-                    file_path,
-                    config_file,
-                    script_location=directory
+                    file_path, config_file, script_location=directory
                 )
         elif os.path.isfile(file_path):
             output_file = os.path.join(directory, file_)
-            script._copy_file(
-                file_path,
-                output_file
-            )
+            script._copy_file(file_path, output_file)
 
-    util.msg("Please edit configuration/connection/logging "
-             "settings in %r before proceeding." % config_file)
+    util.msg(
+        "Please edit configuration/connection/logging "
+        "settings in %r before proceeding." % config_file
+    )
 
 
 def revision(
-        config, message=None, autogenerate=False, sql=False,
-        head="head", splice=False, branch_label=None,
-        version_path=None, rev_id=None, depends_on=None,
-        process_revision_directives=None):
+    config,
+    message=None,
+    autogenerate=False,
+    sql=False,
+    head="head",
+    splice=False,
+    branch_label=None,
+    version_path=None,
+    rev_id=None,
+    depends_on=None,
+    process_revision_directives=None,
+):
     """Create a new revision file.
 
     :param config: a :class:`.Config` object.
@@ -134,35 +143,46 @@ def revision(
     command_args = dict(
         message=message,
         autogenerate=autogenerate,
-        sql=sql, head=head, splice=splice, branch_label=branch_label,
-        version_path=version_path, rev_id=rev_id, depends_on=depends_on
+        sql=sql,
+        head=head,
+        splice=splice,
+        branch_label=branch_label,
+        version_path=version_path,
+        rev_id=rev_id,
+        depends_on=depends_on,
     )
     revision_context = autogen.RevisionContext(
-        config, script_directory, command_args,
-        process_revision_directives=process_revision_directives)
-
-    environment = util.asbool(
-        config.get_main_option("revision_environment")
+        config,
+        script_directory,
+        command_args,
+        process_revision_directives=process_revision_directives,
     )
 
+    environment = util.asbool(config.get_main_option("revision_environment"))
+
     if autogenerate:
         environment = True
 
         if sql:
             raise util.CommandError(
-                "Using --sql with --autogenerate does not make any sense")
+                "Using --sql with --autogenerate does not make any sense"
+            )
 
         def retrieve_migrations(rev, context):
             revision_context.run_autogenerate(rev, context)
             return []
+
     elif environment:
+
         def retrieve_migrations(rev, context):
             revision_context.run_no_autogenerate(rev, context)
             return []
+
     elif sql:
         raise util.CommandError(
             "Using --sql with the revision command when "
-            "revision_environment is not configured does not make any sense")
+            "revision_environment is not configured does not make any sense"
+        )
 
     if environment:
         with EnvironmentContext(
@@ -171,14 +191,11 @@ def revision(
             fn=retrieve_migrations,
             as_sql=sql,
             template_args=revision_context.template_args,
-            revision_context=revision_context
+            revision_context=revision_context,
         ):
             script_directory.run_env()
 
-    scripts = [
-        script for script in
-        revision_context.generate_scripts()
-    ]
+    scripts = [script for script in revision_context.generate_scripts()]
     if len(scripts) == 1:
         return scripts[0]
     else:
@@ -207,13 +224,17 @@ def merge(config, revisions, message=None, branch_label=None, rev_id=None):
 
     script = ScriptDirectory.from_config(config)
     template_args = {
-        'config': config  # Let templates use config for
-                          # e.g. multiple databases
+        "config": config  # Let templates use config for
+        # e.g. multiple databases
     }
     return script.generate_revision(
-        rev_id or util.rev_id(), message, refresh=True,
-        head=revisions, branch_labels=branch_label,
-        **template_args)
+        rev_id or util.rev_id(),
+        message,
+        refresh=True,
+        head=revisions,
+        branch_labels=branch_label,
+        **template_args
+    )
 
 
 def upgrade(config, revision, sql=False, tag=None):
@@ -237,7 +258,7 @@ def upgrade(config, revision, sql=False, tag=None):
     if ":" in revision:
         if not sql:
             raise util.CommandError("Range revision not allowed")
-        starting_rev, revision = revision.split(':', 2)
+        starting_rev, revision = revision.split(":", 2)
 
     def upgrade(rev, context):
         return script._upgrade_revs(revision, rev)
@@ -249,7 +270,7 @@ def upgrade(config, revision, sql=False, tag=None):
         as_sql=sql,
         starting_rev=starting_rev,
         destination_rev=revision,
-        tag=tag
+        tag=tag,
     ):
         script.run_env()
 
@@ -274,10 +295,11 @@ def downgrade(config, revision, sql=False, tag=None):
     if ":" in revision:
         if not sql:
             raise util.CommandError("Range revision not allowed")
-        starting_rev, revision = revision.split(':', 2)
+        starting_rev, revision = revision.split(":", 2)
     elif sql:
         raise util.CommandError(
-            "downgrade with --sql requires <fromrev>:<torev>")
+            "downgrade with --sql requires <fromrev>:<torev>"
+        )
 
     def downgrade(rev, context):
         return script._downgrade_revs(revision, rev)
@@ -289,7 +311,7 @@ def downgrade(config, revision, sql=False, tag=None):
         as_sql=sql,
         starting_rev=starting_rev,
         destination_rev=revision,
-        tag=tag
+        tag=tag,
     ):
         script.run_env()
 
@@ -306,15 +328,13 @@ def show(config, rev):
     script = ScriptDirectory.from_config(config)
 
     if rev == "current":
+
         def show_current(rev, context):
             for sc in script.get_revisions(rev):
                 config.print_stdout(sc.log_entry)
             return []
-        with EnvironmentContext(
-            config,
-            script,
-            fn=show_current
-        ):
+
+        with EnvironmentContext(config, script, fn=show_current):
             script.run_env()
     else:
         for sc in script.get_revisions(rev):
@@ -340,44 +360,45 @@ def history(config, rev_range=None, verbose=False, indicate_current=False):
     if rev_range is not None:
         if ":" not in rev_range:
             raise util.CommandError(
-                "History range requires [start]:[end], "
-                "[start]:, or :[end]")
+                "History range requires [start]:[end], " "[start]:, or :[end]"
+            )
         base, head = rev_range.strip().split(":")
     else:
         base = head = None
 
-    environment = util.asbool(
-        config.get_main_option("revision_environment")
-    ) or indicate_current
+    environment = (
+        util.asbool(config.get_main_option("revision_environment"))
+        or indicate_current
+    )
 
     def _display_history(config, script, base, head, currents=()):
         for sc in script.walk_revisions(
-                base=base or "base",
-                head=head or "heads"):
+            base=base or "base", head=head or "heads"
+        ):
 
             if indicate_current:
                 sc._db_current_indicator = sc.revision in currents
 
             config.print_stdout(
                 sc.cmd_format(
-                    verbose=verbose, include_branches=True,
-                    include_doc=True, include_parents=True))
+                    verbose=verbose,
+                    include_branches=True,
+                    include_doc=True,
+                    include_parents=True,
+                )
+            )
 
     def _display_history_w_current(config, script, base, head):
         def _display_current_history(rev, context):
-            if head == 'current':
+            if head == "current":
                 _display_history(config, script, base, rev, rev)
-            elif base == 'current':
+            elif base == "current":
                 _display_history(config, script, rev, head, rev)
             else:
                 _display_history(config, script, base, head, rev)
             return []
 
-        with EnvironmentContext(
-            config,
-            script,
-            fn=_display_current_history
-        ):
+        with EnvironmentContext(config, script, fn=_display_current_history):
             script.run_env()
 
     if base == "current" or head == "current" or environment:
@@ -406,7 +427,9 @@ def heads(config, verbose=False, resolve_dependencies=False):
     for rev in heads:
         config.print_stdout(
             rev.cmd_format(
-                verbose, include_branches=True, tree_indicators=False))
+                verbose, include_branches=True, tree_indicators=False
+            )
+        )
 
 
 def branches(config, verbose=False):
@@ -424,13 +447,17 @@ def branches(config, verbose=False):
                 "%s\n%s\n",
                 sc.cmd_format(verbose, include_branches=True),
                 "\n".join(
-                    "%s -> %s" % (
+                    "%s -> %s"
+                    % (
                         " " * len(str(sc.revision)),
                         rev_obj.cmd_format(
-                            False, include_branches=True, include_doc=verbose)
-                    ) for rev_obj in
-                    (script.get_revision(rev) for rev in sc.nextrev)
-                )
+                            False, include_branches=True, include_doc=verbose
+                        ),
+                    )
+                    for rev_obj in (
+                        script.get_revision(rev) for rev in sc.nextrev
+                    )
+                ),
             )
 
 
@@ -454,18 +481,14 @@ def current(config, verbose=False, head_only=False):
         if verbose:
             config.print_stdout(
                 "Current revision(s) for %s:",
-                util.obfuscate_url_pw(context.connection.engine.url)
+                util.obfuscate_url_pw(context.connection.engine.url),
             )
         for rev in script.get_all_current(rev):
             config.print_stdout(rev.cmd_format(verbose))
 
         return []
 
-    with EnvironmentContext(
-        config,
-        script,
-        fn=display_version
-    ):
+    with EnvironmentContext(config, script, fn=display_version):
         script.run_env()
 
 
@@ -491,7 +514,7 @@ def stamp(config, revision, sql=False, tag=None):
     if ":" in revision:
         if not sql:
             raise util.CommandError("Range revision not allowed")
-        starting_rev, revision = revision.split(':', 2)
+        starting_rev, revision = revision.split(":", 2)
 
     def do_stamp(rev, context):
         return script._stamp_revs(revision, rev)
@@ -503,7 +526,7 @@ def stamp(config, revision, sql=False, tag=None):
         as_sql=sql,
         destination_rev=revision,
         starting_rev=starting_rev,
-        tag=tag
+        tag=tag,
     ):
         script.run_env()
 
@@ -520,23 +543,21 @@ def edit(config, rev):
     script = ScriptDirectory.from_config(config)
 
     if rev == "current":
+
         def edit_current(rev, context):
             if not rev:
                 raise util.CommandError("No current revisions")
             for sc in script.get_revisions(rev):
                 util.edit(sc.path)
             return []
-        with EnvironmentContext(
-            config,
-            script,
-            fn=edit_current
-        ):
+
+        with EnvironmentContext(config, script, fn=edit_current):
             script.run_env()
     else:
         revs = script.get_revisions(rev)
         if not revs:
             raise util.CommandError(
-                "No revision files indicated by symbol '%s'" % rev)
+                "No revision files indicated by symbol '%s'" % rev
+            )
         for sc in revs:
             util.edit(sc.path)
-
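
Each command function reformatted above is also the public programmatic API,
taking a Config as its first argument. A minimal sketch, assuming an
alembic.ini in the current directory:

    from alembic.config import Config
    from alembic import command

    cfg = Config("alembic.ini")
    command.revision(cfg, message="add accounts table", autogenerate=True)
    command.upgrade(cfg, "head")
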
index 5856099a9020936ca65362e1d0665c01f78f5c51..915091c3f5db8736501ab8df9153e2e2f12f5763 100644 (file)
@@ -90,9 +90,16 @@ class Config(object):
 
     """
 
-    def __init__(self, file_=None, ini_section='alembic', output_buffer=None,
-                 stdout=sys.stdout, cmd_opts=None,
-                 config_args=util.immutabledict(), attributes=None):
+    def __init__(
+        self,
+        file_=None,
+        ini_section="alembic",
+        output_buffer=None,
+        stdout=sys.stdout,
+        cmd_opts=None,
+        config_args=util.immutabledict(),
+        attributes=None,
+    ):
         """Construct a new :class:`.Config`
 
         """
@@ -167,15 +174,11 @@ class Config(object):
         """
 
         if arg:
-            output = (compat.text_type(text) % arg)
+            output = compat.text_type(text) % arg
         else:
             output = compat.text_type(text)
 
-        util.write_outstream(
-            self.stdout,
-            output,
-            "\n"
-        )
+        util.write_outstream(self.stdout, output, "\n")
 
     @util.memoized_property
     def file_config(self):
@@ -192,7 +195,7 @@ class Config(object):
             here = os.path.abspath(os.path.dirname(self.config_file_name))
         else:
             here = ""
-        self.config_args['here'] = here
+        self.config_args["here"] = here
         file_config = SafeConfigParser(self.config_args)
         if self.config_file_name:
             file_config.read([self.config_file_name])
@@ -207,7 +210,7 @@ class Config(object):
         commands.
 
         """
-        return os.path.join(package_dir, 'templates')
+        return os.path.join(package_dir, "templates")
 
     def get_section(self, name):
         """Return all the configuration options from a given .ini file section
@@ -265,9 +268,10 @@ class Config(object):
 
         """
         if not self.file_config.has_section(section):
-            raise util.CommandError("No config file %r found, or file has no "
-                                    "'[%s]' section" %
-                                    (self.config_file_name, section))
+            raise util.CommandError(
+                "No config file %r found, or file has no "
+                "'[%s]' section" % (self.config_file_name, section)
+            )
         if self.file_config.has_option(section, name):
             return self.file_config.get(section, name)
         else:
@@ -285,140 +289,144 @@ class Config(object):
 
 
 class CommandLine(object):
-
     def __init__(self, prog=None):
         self._generate_args(prog)
 
     def _generate_args(self, prog):
         def add_options(parser, positional, kwargs):
             kwargs_opts = {
-                'template': (
-                    "-t", "--template",
+                "template": (
+                    "-t",
+                    "--template",
                     dict(
-                        default='generic',
+                        default="generic",
                         type=str,
-                        help="Setup template for use with 'init'"
-                    )
+                        help="Setup template for use with 'init'",
+                    ),
                 ),
-                'message': (
-                    "-m", "--message",
+                "message": (
+                    "-m",
+                    "--message",
                     dict(
-                        type=str,
-                        help="Message string to use with 'revision'")
+                        type=str, help="Message string to use with 'revision'"
+                    ),
                 ),
-                'sql': (
+                "sql": (
                     "--sql",
                     dict(
                         action="store_true",
                         help="Don't emit SQL to database - dump to "
                         "standard output/file instead. See docs on "
-                        "offline mode."
-                    )
+                        "offline mode.",
+                    ),
                 ),
-                'tag': (
+                "tag": (
                     "--tag",
                     dict(
                         type=str,
                         help="Arbitrary 'tag' name - can be used by "
-                        "custom env.py scripts.")
+                        "custom env.py scripts.",
+                    ),
                 ),
-                'head': (
+                "head": (
                     "--head",
                     dict(
                         type=str,
                         help="Specify head revision or <branchname>@head "
-                        "to base new revision on."
-                    )
+                        "to base new revision on.",
+                    ),
                 ),
-                'splice': (
+                "splice": (
                     "--splice",
                     dict(
                         action="store_true",
                         help="Allow a non-head revision as the "
-                        "'head' to splice onto"
-                    )
+                        "'head' to splice onto",
+                    ),
                 ),
-                'depends_on': (
+                "depends_on": (
                     "--depends-on",
                     dict(
                         action="append",
                         help="Specify one or more revision identifiers "
-                        "which this revision should depend on."
-                    )
+                        "which this revision should depend on.",
+                    ),
                 ),
-                'rev_id': (
+                "rev_id": (
                     "--rev-id",
                     dict(
                         type=str,
                         help="Specify a hardcoded revision id instead of "
-                        "generating one"
-                    )
+                        "generating one",
+                    ),
                 ),
-                'version_path': (
+                "version_path": (
                     "--version-path",
                     dict(
                         type=str,
                         help="Specify specific path from config for "
-                        "version file"
-                    )
+                        "version file",
+                    ),
                 ),
-                'branch_label': (
+                "branch_label": (
                     "--branch-label",
                     dict(
                         type=str,
                         help="Specify a branch label to apply to the "
-                        "new revision"
-                    )
+                        "new revision",
+                    ),
                 ),
-                'verbose': (
-                    "-v", "--verbose",
-                    dict(
-                        action="store_true",
-                        help="Use more verbose output"
-                    )
+                "verbose": (
+                    "-v",
+                    "--verbose",
+                    dict(action="store_true", help="Use more verbose output"),
                 ),
-                'resolve_dependencies': (
-                    '--resolve-dependencies',
+                "resolve_dependencies": (
+                    "--resolve-dependencies",
                     dict(
                         action="store_true",
-                        help="Treat dependency versions as down revisions"
-                    )
+                        help="Treat dependency versions as down revisions",
+                    ),
                 ),
-                'autogenerate': (
+                "autogenerate": (
                     "--autogenerate",
                     dict(
                         action="store_true",
                         help="Populate revision script with candidate "
                         "migration operations, based on comparison "
-                        "of database to model.")
+                        "of database to model.",
+                    ),
                 ),
-                'head_only': (
+                "head_only": (
                     "--head-only",
                     dict(
                         action="store_true",
                         help="Deprecated.  Use --verbose for "
-                        "additional output")
+                        "additional output",
+                    ),
                 ),
-                'rev_range': (
-                    "-r", "--rev-range",
+                "rev_range": (
+                    "-r",
+                    "--rev-range",
                     dict(
                         action="store",
                         help="Specify a revision range; "
-                        "format is [start]:[end]")
+                        "format is [start]:[end]",
+                    ),
                 ),
-                'indicate_current': (
-                    "-i", "--indicate-current",
+                "indicate_current": (
+                    "-i",
+                    "--indicate-current",
                     dict(
                         action="store_true",
-                        help="Indicate the current revision"
-                    )
-                )
+                        help="Indicate the current revision",
+                    ),
+                ),
             }
             positional_help = {
-                'directory': "location of scripts directory",
-                'revision': "revision identifier",
-                'revisions': "one or more revisions, or 'heads' for all heads"
-
+                "directory": "location of scripts directory",
+                "revision": "revision identifier",
+                "revisions": "one or more revisions, or 'heads' for all heads",
             }
             for arg in kwargs:
                 if arg in kwargs_opts:
@@ -429,44 +437,56 @@ class CommandLine(object):
             for arg in positional:
                 if arg == "revisions":
                     subparser.add_argument(
-                        arg, nargs='+', help=positional_help.get(arg))
+                        arg, nargs="+", help=positional_help.get(arg)
+                    )
                 else:
                     subparser.add_argument(arg, help=positional_help.get(arg))
 
         parser = ArgumentParser(prog=prog)
-        parser.add_argument("-c", "--config",
-                            type=str,
-                            default="alembic.ini",
-                            help="Alternate config file")
-        parser.add_argument("-n", "--name",
-                            type=str,
-                            default="alembic",
-                            help="Name of section in .ini file to "
-                                    "use for Alembic config")
-        parser.add_argument("-x", action="append",
-                            help="Additional arguments consumed by "
-                            "custom env.py scripts, e.g. -x "
-                            "setting1=somesetting -x setting2=somesetting")
-        parser.add_argument("--raiseerr", action="store_true",
-                            help="Raise a full stack trace on error")
+        parser.add_argument(
+            "-c",
+            "--config",
+            type=str,
+            default="alembic.ini",
+            help="Alternate config file",
+        )
+        parser.add_argument(
+            "-n",
+            "--name",
+            type=str,
+            default="alembic",
+            help="Name of section in .ini file to " "use for Alembic config",
+        )
+        parser.add_argument(
+            "-x",
+            action="append",
+            help="Additional arguments consumed by "
+            "custom env.py scripts, e.g. -x "
+            "setting1=somesetting -x setting2=somesetting",
+        )
+        parser.add_argument(
+            "--raiseerr",
+            action="store_true",
+            help="Raise a full stack trace on error",
+        )
         subparsers = parser.add_subparsers()
 
         for fn in [getattr(command, n) for n in dir(command)]:
-            if inspect.isfunction(fn) and \
-                    fn.__name__[0] != '_' and \
-                    fn.__module__ == 'alembic.command':
+            if (
+                inspect.isfunction(fn)
+                and fn.__name__[0] != "_"
+                and fn.__module__ == "alembic.command"
+            ):
 
                 spec = compat.inspect_getargspec(fn)
                 if spec[3]:
-                    positional = spec[0][1:-len(spec[3])]
-                    kwarg = spec[0][-len(spec[3]):]
+                    positional = spec[0][1 : -len(spec[3])]
+                    kwarg = spec[0][-len(spec[3]) :]
                 else:
                     positional = spec[0][1:]
                     kwarg = []
 
-                subparser = subparsers.add_parser(
-                    fn.__name__,
-                    help=fn.__doc__)
+                subparser = subparsers.add_parser(fn.__name__, help=fn.__doc__)
                 add_options(subparser, positional, kwarg)
                 subparser.set_defaults(cmd=(fn, positional, kwarg))
         self.parser = parser
@@ -475,10 +495,11 @@ class CommandLine(object):
         fn, positional, kwarg = options.cmd
 
         try:
-            fn(config,
-               *[getattr(options, k, None) for k in positional],
-               **dict((k, getattr(options, k, None)) for k in kwarg)
-               )
+            fn(
+                config,
+                *[getattr(options, k, None) for k in positional],
+                **dict((k, getattr(options, k, None)) for k in kwarg)
+            )
         except util.CommandError as e:
             if options.raiseerr:
                 raise
@@ -492,8 +513,11 @@ class CommandLine(object):
             # behavior changed incompatibly in py3.3
             self.parser.error("too few arguments")
         else:
-            cfg = Config(file_=options.config,
-                         ini_section=options.name, cmd_opts=options)
+            cfg = Config(
+                file_=options.config,
+                ini_section=options.name,
+                cmd_opts=options,
+            )
             self.run_cmd(cfg, options)
 
 
@@ -502,5 +526,6 @@ def main(argv=None, prog=None, **kwargs):
 
     CommandLine(prog=prog).main(argv=argv)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
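
CommandLine._generate_args builds the argparse tree by introspecting
alembic.command: each public function becomes a subcommand, its positional
arguments become positionals, and its keyword arguments are mapped through
kwargs_opts. A stripped-down sketch of the introspection pattern, not the
actual implementation:

    import argparse
    import inspect

    def build_parser(module):
        parser = argparse.ArgumentParser()
        subs = parser.add_subparsers()
        for name, fn in sorted(vars(module).items()):
            if inspect.isfunction(fn) and not name.startswith("_"):
                sub = subs.add_parser(name, help=fn.__doc__)
                sub.set_defaults(cmd=fn)
        return parser
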
index f4a525f2725ca18a39a1af0504cf68908ce21f69..f177a0760c4f11951e5e932294f0ed2223334ab4 100644 (file)
@@ -9,7 +9,11 @@ from .. import util
 # backwards compat
 from ..util.sqla_compat import (  # noqa
     _table_for_constraint,
-    _columns_for_constraint, _fk_spec, _is_type_bound, _find_columns)
+    _columns_for_constraint,
+    _fk_spec,
+    _is_type_bound,
+    _find_columns,
+)
 
 if util.sqla_09:
     from sqlalchemy.sql.elements import quoted_name
@@ -30,65 +34,63 @@ class AlterTable(DDLElement):
 
 
 class RenameTable(AlterTable):
-
     def __init__(self, old_table_name, new_table_name, schema=None):
         super(RenameTable, self).__init__(old_table_name, schema=schema)
         self.new_table_name = new_table_name
 
 
 class AlterColumn(AlterTable):
-
-    def __init__(self, name, column_name, schema=None,
-                 existing_type=None,
-                 existing_nullable=None,
-                 existing_server_default=None):
+    def __init__(
+        self,
+        name,
+        column_name,
+        schema=None,
+        existing_type=None,
+        existing_nullable=None,
+        existing_server_default=None,
+    ):
         super(AlterColumn, self).__init__(name, schema=schema)
         self.column_name = column_name
-        self.existing_type = sqltypes.to_instance(existing_type) \
-            if existing_type is not None else None
+        self.existing_type = (
+            sqltypes.to_instance(existing_type)
+            if existing_type is not None
+            else None
+        )
         self.existing_nullable = existing_nullable
         self.existing_server_default = existing_server_default
 
 
 class ColumnNullable(AlterColumn):
-
     def __init__(self, name, column_name, nullable, **kw):
-        super(ColumnNullable, self).__init__(name, column_name,
-                                             **kw)
+        super(ColumnNullable, self).__init__(name, column_name, **kw)
         self.nullable = nullable
 
 
 class ColumnType(AlterColumn):
-
     def __init__(self, name, column_name, type_, **kw):
-        super(ColumnType, self).__init__(name, column_name,
-                                         **kw)
+        super(ColumnType, self).__init__(name, column_name, **kw)
         self.type_ = sqltypes.to_instance(type_)
 
 
 class ColumnName(AlterColumn):
-
     def __init__(self, name, column_name, newname, **kw):
         super(ColumnName, self).__init__(name, column_name, **kw)
         self.newname = newname
 
 
 class ColumnDefault(AlterColumn):
-
     def __init__(self, name, column_name, default, **kw):
         super(ColumnDefault, self).__init__(name, column_name, **kw)
         self.default = default
 
 
 class AddColumn(AlterTable):
-
     def __init__(self, name, column, schema=None):
         super(AddColumn, self).__init__(name, schema=schema)
         self.column = column
 
 
 class DropColumn(AlterTable):
-
     def __init__(self, name, column, schema=None):
         super(DropColumn, self).__init__(name, schema=schema)
         self.column = column
@@ -98,7 +100,7 @@ class DropColumn(AlterTable):
 def visit_rename_table(element, compiler, **kw):
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
-        format_table_name(compiler, element.new_table_name, element.schema)
+        format_table_name(compiler, element.new_table_name, element.schema),
     )
 
 
@@ -106,7 +108,7 @@ def visit_rename_table(element, compiler, **kw):
 def visit_add_column(element, compiler, **kw):
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
-        add_column(compiler, element.column, **kw)
+        add_column(compiler, element.column, **kw),
     )
 
 
@@ -114,7 +116,7 @@ def visit_add_column(element, compiler, **kw):
 def visit_drop_column(element, compiler, **kw):
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
-        drop_column(compiler, element.column.name, **kw)
+        drop_column(compiler, element.column.name, **kw),
     )
 
 
@@ -123,7 +125,7 @@ def visit_column_nullable(element, compiler, **kw):
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        "DROP NOT NULL" if element.nullable else "SET NOT NULL"
+        "DROP NOT NULL" if element.nullable else "SET NOT NULL",
     )
 
 
@@ -132,7 +134,7 @@ def visit_column_type(element, compiler, **kw):
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        "TYPE %s" % format_type(compiler, element.type_)
+        "TYPE %s" % format_type(compiler, element.type_),
     )
 
 
@@ -141,7 +143,7 @@ def visit_column_name(element, compiler, **kw):
     return "%s RENAME %s TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
-        format_column_name(compiler, element.newname)
+        format_column_name(compiler, element.newname),
     )
 
 
@@ -150,10 +152,9 @@ def visit_column_default(element, compiler, **kw):
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        "SET DEFAULT %s" %
-        format_server_default(compiler, element.default)
+        "SET DEFAULT %s" % format_server_default(compiler, element.default)
         if element.default is not None
-        else "DROP DEFAULT"
+        else "DROP DEFAULT",
     )
 
 
@@ -162,7 +163,7 @@ def quote_dotted(name, quote):
 
     if util.sqla_09 and isinstance(name, quoted_name):
         return quote(name)
-    result = '.'.join([quote(x) for x in name.split('.')])
+    result = ".".join([quote(x) for x in name.split(".")])
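+    # e.g. quote_dotted("myschema.mytable", quote) gives
+    # '"myschema"."mytable"', assuming a quote function that
+    # double-quotes each identifier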
     return result
 
 
@@ -193,11 +194,11 @@ def alter_table(compiler, name, schema):
 
 
 def drop_column(compiler, name):
-    return 'DROP COLUMN %s' % format_column_name(compiler, name)
+    return "DROP COLUMN %s" % format_column_name(compiler, name)
 
 
 def alter_column(compiler, name):
-    return 'ALTER COLUMN %s' % format_column_name(compiler, name)
+    return "ALTER COLUMN %s" % format_column_name(compiler, name)
 
 
 def add_column(compiler, column, **kw):
index 98be164a9dfc201d8ebc0ad72bc582cf73f981d2..4e3ff04da06aa9d6d15362b3380611a52258a96d 100644 (file)
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -1,22 +1,20 @@
 from sqlalchemy import schema, text
 from sqlalchemy import types as sqltypes
 
-from ..util.compat import (
-    string_types, text_type, with_metaclass
-)
+from ..util.compat import string_types, text_type, with_metaclass
 from ..util import sqla_compat
 from .. import util
 from . import base
 
 
 class ImplMeta(type):
-
     def __init__(cls, classname, bases, dict_):
         newtype = type.__init__(cls, classname, bases, dict_)
-        if '__dialect__' in dict_:
-            _impls[dict_['__dialect__']] = cls
+        if "__dialect__" in dict_:
+            _impls[dict_["__dialect__"]] = cls
         return newtype
 
+
 _impls = {}
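+# illustration (hypothetical names): declaring
+#     class MyImpl(DefaultImpl):
+#         __dialect__ = "mydialect"
+# stores MyImpl in _impls, where DefaultImpl.get_by_dialect() looks it up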
 
 
@@ -33,18 +31,25 @@ class DefaultImpl(with_metaclass(ImplMeta)):
     bulk inserts.
 
     """
-    __dialect__ = 'default'
+
+    __dialect__ = "default"
 
     transactional_ddl = False
     command_terminator = ";"
 
-    def __init__(self, dialect, connection, as_sql,
-                 transactional_ddl, output_buffer,
-                 context_opts):
+    def __init__(
+        self,
+        dialect,
+        connection,
+        as_sql,
+        transactional_ddl,
+        output_buffer,
+        context_opts,
+    ):
         self.dialect = dialect
         self.connection = connection
         self.as_sql = as_sql
-        self.literal_binds = context_opts.get('literal_binds', False)
+        self.literal_binds = context_opts.get("literal_binds", False)
 
         self.output_buffer = output_buffer
         self.memo = {}
@@ -55,7 +60,8 @@ class DefaultImpl(with_metaclass(ImplMeta)):
         if self.literal_binds:
             if not self.as_sql:
                 raise util.CommandError(
-                    "Can't use literal_binds setting without as_sql mode")
+                    "Can't use literal_binds setting without as_sql mode"
+                )
 
     @classmethod
     def get_by_dialect(cls, dialect):
@@ -89,9 +95,13 @@ class DefaultImpl(with_metaclass(ImplMeta)):
     def bind(self):
         return self.connection
 
-    def _exec(self, construct, execution_options=None,
-              multiparams=(),
-              params=util.immutabledict()):
+    def _exec(
+        self,
+        construct,
+        execution_options=None,
+        multiparams=(),
+        params=util.immutabledict(),
+    ):
         if isinstance(construct, string_types):
             construct = text(construct)
         if self.as_sql:
@@ -100,14 +110,20 @@ class DefaultImpl(with_metaclass(ImplMeta)):
                 raise Exception("Execution arguments not allowed with as_sql")
 
             if self.literal_binds and not isinstance(
-                    construct, schema.DDLElement):
+                construct, schema.DDLElement
+            ):
                 compile_kw = dict(compile_kwargs={"literal_binds": True})
             else:
                 compile_kw = {}
 
-            self.static_output(text_type(
-                construct.compile(dialect=self.dialect, **compile_kw)
-            ).replace("\t", "    ").strip() + self.command_terminator)
+            self.static_output(
+                text_type(
+                    construct.compile(dialect=self.dialect, **compile_kw)
+                )
+                .replace("\t", "    ")
+                .strip()
+                + self.command_terminator
+            )
         else:
             conn = self.connection
             if execution_options:
@@ -117,53 +133,75 @@ class DefaultImpl(with_metaclass(ImplMeta)):
     def execute(self, sql, execution_options=None):
         self._exec(sql, execution_options)
 
-    def alter_column(self, table_name, column_name,
-                     nullable=None,
-                     server_default=False,
-                     name=None,
-                     type_=None,
-                     schema=None,
-                     autoincrement=None,
-                     existing_type=None,
-                     existing_server_default=None,
-                     existing_nullable=None,
-                     existing_autoincrement=None
-                     ):
+    def alter_column(
+        self,
+        table_name,
+        column_name,
+        nullable=None,
+        server_default=False,
+        name=None,
+        type_=None,
+        schema=None,
+        autoincrement=None,
+        existing_type=None,
+        existing_server_default=None,
+        existing_nullable=None,
+        existing_autoincrement=None,
+    ):
         if autoincrement is not None or existing_autoincrement is not None:
             util.warn(
                 "autoincrement and existing_autoincrement "
-                "only make sense for MySQL")
+                "only make sense for MySQL"
+            )
         if nullable is not None:
-            self._exec(base.ColumnNullable(
-                table_name, column_name,
-                nullable, schema=schema,
-                existing_type=existing_type,
-                existing_server_default=existing_server_default,
-                existing_nullable=existing_nullable,
-            ))
+            self._exec(
+                base.ColumnNullable(
+                    table_name,
+                    column_name,
+                    nullable,
+                    schema=schema,
+                    existing_type=existing_type,
+                    existing_server_default=existing_server_default,
+                    existing_nullable=existing_nullable,
+                )
+            )
         if server_default is not False:
-            self._exec(base.ColumnDefault(
-                table_name, column_name, server_default,
-                schema=schema,
-                existing_type=existing_type,
-                existing_server_default=existing_server_default,
-                existing_nullable=existing_nullable,
-            ))
+            self._exec(
+                base.ColumnDefault(
+                    table_name,
+                    column_name,
+                    server_default,
+                    schema=schema,
+                    existing_type=existing_type,
+                    existing_server_default=existing_server_default,
+                    existing_nullable=existing_nullable,
+                )
+            )
         if type_ is not None:
-            self._exec(base.ColumnType(
-                table_name, column_name, type_, schema=schema,
-                existing_type=existing_type,
-                existing_server_default=existing_server_default,
-                existing_nullable=existing_nullable,
-            ))
+            self._exec(
+                base.ColumnType(
+                    table_name,
+                    column_name,
+                    type_,
+                    schema=schema,
+                    existing_type=existing_type,
+                    existing_server_default=existing_server_default,
+                    existing_nullable=existing_nullable,
+                )
+            )
         # do the new name last ;)
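+        # (so the preceding nullable/default/type ALTERs can still
+        # address the column by its original name)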
         if name is not None:
-            self._exec(base.ColumnName(
-                table_name, column_name, name, schema=schema,
-                existing_type=existing_type,
-                existing_server_default=existing_server_default,
-                existing_nullable=existing_nullable,
-            ))
+            self._exec(
+                base.ColumnName(
+                    table_name,
+                    column_name,
+                    name,
+                    schema=schema,
+                    existing_type=existing_type,
+                    existing_server_default=existing_server_default,
+                    existing_nullable=existing_nullable,
+                )
+            )
 
     def add_column(self, table_name, column, schema=None):
         self._exec(base.AddColumn(table_name, column, schema=schema))
@@ -172,25 +210,25 @@ class DefaultImpl(with_metaclass(ImplMeta)):
         self._exec(base.DropColumn(table_name, column, schema=schema))
 
     def add_constraint(self, const):
-        if const._create_rule is None or \
-                const._create_rule(self):
+        if const._create_rule is None or const._create_rule(self):
             self._exec(schema.AddConstraint(const))
 
     def drop_constraint(self, const):
         self._exec(schema.DropConstraint(const))
 
     def rename_table(self, old_table_name, new_table_name, schema=None):
-        self._exec(base.RenameTable(old_table_name,
-                                    new_table_name, schema=schema))
+        self._exec(
+            base.RenameTable(old_table_name, new_table_name, schema=schema)
+        )
 
     def create_table(self, table):
-        table.dispatch.before_create(table, self.connection,
-                                     checkfirst=False,
-                                     _ddl_runner=self)
+        table.dispatch.before_create(
+            table, self.connection, checkfirst=False, _ddl_runner=self
+        )
         self._exec(schema.CreateTable(table))
-        table.dispatch.after_create(table, self.connection,
-                                    checkfirst=False,
-                                    _ddl_runner=self)
+        table.dispatch.after_create(
+            table, self.connection, checkfirst=False, _ddl_runner=self
+        )
         for index in table.indexes:
             self._exec(schema.CreateIndex(index))
 
@@ -210,17 +248,26 @@ class DefaultImpl(with_metaclass(ImplMeta)):
             raise TypeError("List of dictionaries expected")
         if self.as_sql:
             for row in rows:
-                self._exec(table.insert(inline=True).values(**dict(
-                    (k,
-                        sqla_compat._literal_bindparam(
-                            k, v, type_=table.c[k].type)
-                        if not isinstance(
-                            v, sqla_compat._literal_bindparam) else v)
-                    for k, v in row.items()
-                )))
+                self._exec(
+                    table.insert(inline=True).values(
+                        **dict(
+                            (
+                                k,
+                                sqla_compat._literal_bindparam(
+                                    k, v, type_=table.c[k].type
+                                )
+                                if not isinstance(
+                                    v, sqla_compat._literal_bindparam
+                                )
+                                else v,
+                            )
+                            for k, v in row.items()
+                        )
+                    )
+                )
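+                # illustration: in offline --sql mode each row becomes its
+                # own INSERT statement with values inlined as SQL literals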
         else:
             # work around http://www.sqlalchemy.org/trac/ticket/2461
-            if not hasattr(table, '_autoincrement_column'):
+            if not hasattr(table, "_autoincrement_column"):
                 table._autoincrement_column = None
             if rows:
                 if multiinsert:
@@ -240,32 +287,38 @@ class DefaultImpl(with_metaclass(ImplMeta)):
 
         # work around SQLAlchemy bug "stale value for type affinity"
         # fixed in 0.7.4
-        metadata_impl.__dict__.pop('_type_affinity', None)
+        metadata_impl.__dict__.pop("_type_affinity", None)
 
         if hasattr(metadata_impl, "compare_against_backend"):
             comparison = metadata_impl.compare_against_backend(
-                self.dialect, conn_type)
+                self.dialect, conn_type
+            )
             if comparison is not None:
                 return not comparison
 
-        if conn_type._compare_type_affinity(
-            metadata_impl
-        ):
+        if conn_type._compare_type_affinity(metadata_impl):
             comparator = _type_comparators.get(conn_type._type_affinity, None)
 
             return comparator and comparator(metadata_impl, conn_type)
         else:
             return True
 
-    def compare_server_default(self, inspector_column,
-                               metadata_column,
-                               rendered_metadata_default,
-                               rendered_inspector_default):
+    def compare_server_default(
+        self,
+        inspector_column,
+        metadata_column,
+        rendered_metadata_default,
+        rendered_inspector_default,
+    ):
         return rendered_inspector_default != rendered_metadata_default
 
-    def correct_for_autogen_constraints(self, conn_uniques, conn_indexes,
-                                        metadata_unique_constraints,
-                                        metadata_indexes):
+    def correct_for_autogen_constraints(
+        self,
+        conn_uniques,
+        conn_indexes,
+        metadata_unique_constraints,
+        metadata_indexes,
+    ):
         pass
 
     def _compat_autogen_column_reflect(self, inspector):
@@ -316,38 +369,37 @@ class DefaultImpl(with_metaclass(ImplMeta)):
 
 
 def _string_compare(t1, t2):
-    return \
-        t1.length is not None and \
-        t1.length != t2.length
+    return t1.length is not None and t1.length != t2.length
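+    # e.g. a model String(50) against a reflected VARCHAR(30) returns True
+    # (length changed); String() with no length returns False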
 
 
 def _numeric_compare(t1, t2):
-    return (
-        t1.precision is not None and
-        t1.precision != t2.precision
-    ) or (
-        t1.precision is not None and
-        t1.scale is not None and
-        t1.scale != t2.scale
+    return (t1.precision is not None and t1.precision != t2.precision) or (
+        t1.precision is not None
+        and t1.scale is not None
+        and t1.scale != t2.scale
     )
 
 
 def _integer_compare(t1, t2):
     t1_small_or_big = (
-        'S' if isinstance(t1, sqltypes.SmallInteger)
-        else 'B' if isinstance(t1, sqltypes.BigInteger) else 'I'
+        "S"
+        if isinstance(t1, sqltypes.SmallInteger)
+        else "B"
+        if isinstance(t1, sqltypes.BigInteger)
+        else "I"
     )
     t2_small_or_big = (
-        'S' if isinstance(t2, sqltypes.SmallInteger)
-        else 'B' if isinstance(t2, sqltypes.BigInteger) else 'I'
+        "S"
+        if isinstance(t2, sqltypes.SmallInteger)
+        else "B"
+        if isinstance(t2, sqltypes.BigInteger)
+        else "I"
     )
     return t1_small_or_big != t2_small_or_big
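+    # e.g. a model BigInteger against a reflected INTEGER yields "B" != "I"
+    # and flags a type change; plain Integer vs. INTEGER does not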
 
 
 def _datetime_compare(t1, t2):
-    return (
-        t1.timezone != t2.timezone
-    )
+    return t1.timezone != t2.timezone
 
 
 _type_comparators = {
index f303be480ad888f3a5dd4614ef824454f515168a..7f43a89147aac952898d3221b97b728e4df80631 100644 (file)
--- a/alembic/ddl/mssql.py
+++ b/alembic/ddl/mssql.py
@@ -2,24 +2,35 @@ from sqlalchemy.ext.compiler import compiles
 
 from .. import util
 from .impl import DefaultImpl
-from .base import alter_table, AddColumn, ColumnName, RenameTable,\
-    format_table_name, format_column_name, ColumnNullable, alter_column,\
-    format_server_default, ColumnDefault, format_type, ColumnType
+from .base import (
+    alter_table,
+    AddColumn,
+    ColumnName,
+    RenameTable,
+    format_table_name,
+    format_column_name,
+    ColumnNullable,
+    alter_column,
+    format_server_default,
+    ColumnDefault,
+    format_type,
+    ColumnType,
+)
 from sqlalchemy.sql.expression import ClauseElement, Executable
 from sqlalchemy.schema import CreateIndex, Column
 from sqlalchemy import types as sqltypes
 
 
 class MSSQLImpl(DefaultImpl):
-    __dialect__ = 'mssql'
+    __dialect__ = "mssql"
     transactional_ddl = True
     batch_separator = "GO"
 
     def __init__(self, *arg, **kw):
         super(MSSQLImpl, self).__init__(*arg, **kw)
         self.batch_separator = self.context_opts.get(
-            "mssql_batch_separator",
-            self.batch_separator)
+            "mssql_batch_separator", self.batch_separator
+        )
 
     def _exec(self, construct, *args, **kw):
         result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
@@ -35,17 +46,20 @@ class MSSQLImpl(DefaultImpl):
         if self.as_sql and self.batch_separator:
             self.static_output(self.batch_separator)
 
-    def alter_column(self, table_name, column_name,
-                     nullable=None,
-                     server_default=False,
-                     name=None,
-                     type_=None,
-                     schema=None,
-                     existing_type=None,
-                     existing_server_default=None,
-                     existing_nullable=None,
-                     **kw
-                     ):
+    def alter_column(
+        self,
+        table_name,
+        column_name,
+        nullable=None,
+        server_default=False,
+        name=None,
+        type_=None,
+        schema=None,
+        existing_type=None,
+        existing_server_default=None,
+        existing_nullable=None,
+        **kw
+    ):
 
         if nullable is not None and existing_type is None:
             if type_ is not None:
@@ -57,10 +71,12 @@ class MSSQLImpl(DefaultImpl):
                 raise util.CommandError(
                     "MS-SQL ALTER COLUMN operations "
                     "with NULL or NOT NULL require the "
-                    "existing_type or a new type_ be passed.")
+                    "existing_type or a new type_ be passed."
+                )
 
         super(MSSQLImpl, self).alter_column(
-            table_name, column_name,
+            table_name,
+            column_name,
             nullable=nullable,
             type_=type_,
             schema=schema,
@@ -70,30 +86,30 @@ class MSSQLImpl(DefaultImpl):
         )
 
         if server_default is not False:
-            if existing_server_default is not False or \
-                    server_default is None:
+            if existing_server_default is not False or server_default is None:
                 self._exec(
                     _ExecDropConstraint(
-                        table_name, column_name,
-                        'sys.default_constraints')
+                        table_name, column_name, "sys.default_constraints"
+                    )
                 )
             if server_default is not None:
                 super(MSSQLImpl, self).alter_column(
-                    table_name, column_name,
+                    table_name,
+                    column_name,
                     schema=schema,
-                    server_default=server_default)
+                    server_default=server_default,
+                )
 
         if name is not None:
             super(MSSQLImpl, self).alter_column(
-                table_name, column_name,
-                schema=schema,
-                name=name)
+                table_name, column_name, schema=schema, name=name
+            )
 
     def create_index(self, index):
         # this likely defaults to None if not present, so get()
         # should normally not return the default value.  being
         # defensive in any case
-        mssql_include = index.kwargs.get('mssql_include', None) or ()
+        mssql_include = index.kwargs.get("mssql_include", None) or ()
         for col in mssql_include:
             if col not in index.table.c:
                 index.table.append_column(Column(col, sqltypes.NullType))
@@ -102,42 +118,39 @@ class MSSQLImpl(DefaultImpl):
     def bulk_insert(self, table, rows, **kw):
         if self.as_sql:
             self._exec(
-                "SET IDENTITY_INSERT %s ON" %
-                self.dialect.identifier_preparer.format_table(table)
+                "SET IDENTITY_INSERT %s ON"
+                % self.dialect.identifier_preparer.format_table(table)
             )
             super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
             self._exec(
-                "SET IDENTITY_INSERT %s OFF" %
-                self.dialect.identifier_preparer.format_table(table)
+                "SET IDENTITY_INSERT %s OFF"
+                % self.dialect.identifier_preparer.format_table(table)
             )
         else:
             super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
 
     def drop_column(self, table_name, column, **kw):
-        drop_default = kw.pop('mssql_drop_default', False)
+        drop_default = kw.pop("mssql_drop_default", False)
         if drop_default:
             self._exec(
                 _ExecDropConstraint(
-                    table_name, column,
-                    'sys.default_constraints')
+                    table_name, column, "sys.default_constraints"
+                )
             )
-        drop_check = kw.pop('mssql_drop_check', False)
+        drop_check = kw.pop("mssql_drop_check", False)
         if drop_check:
             self._exec(
                 _ExecDropConstraint(
-                    table_name, column,
-                    'sys.check_constraints')
+                    table_name, column, "sys.check_constraints"
+                )
             )
-        drop_fks = kw.pop('mssql_drop_foreign_key', False)
+        drop_fks = kw.pop("mssql_drop_foreign_key", False)
         if drop_fks:
-            self._exec(
-                _ExecDropFKConstraint(table_name, column)
-            )
+            self._exec(_ExecDropFKConstraint(table_name, column))
         super(MSSQLImpl, self).drop_column(table_name, column, **kw)
 
 
 class _ExecDropConstraint(Executable, ClauseElement):
-
     def __init__(self, tname, colname, type_):
         self.tname = tname
         self.colname = colname
@@ -145,13 +158,12 @@ class _ExecDropConstraint(Executable, ClauseElement):
 
 
 class _ExecDropFKConstraint(Executable, ClauseElement):
-
     def __init__(self, tname, colname):
         self.tname = tname
         self.colname = colname
 
 
-@compiles(_ExecDropConstraint, 'mssql')
+@compiles(_ExecDropConstraint, "mssql")
 def _exec_drop_col_constraint(element, compiler, **kw):
     tname, colname, type_ = element.tname, element.colname, element.type_
     # from http://www.mssqltips.com/sqlservertip/1425/\
@@ -162,14 +174,14 @@ select @const_name = [name] from %(type)s
 where parent_object_id = object_id('%(tname)s')
 and col_name(parent_object_id, parent_column_id) = '%(colname)s'
 exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
-        'type': type_,
-        'tname': tname,
-        'colname': colname,
-        'tname_quoted': format_table_name(compiler, tname, None),
+        "type": type_,
+        "tname": tname,
+        "colname": colname,
+        "tname_quoted": format_table_name(compiler, tname, None),
     }
 
 
-@compiles(_ExecDropFKConstraint, 'mssql')
+@compiles(_ExecDropFKConstraint, "mssql")
 def _exec_drop_col_fk_constraint(element, compiler, **kw):
     tname, colname = element.tname, element.colname
 
@@ -180,17 +192,17 @@ select @const_name = [name] from
 where fkc.parent_object_id = object_id('%(tname)s')
 and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s'
 exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
-        'tname': tname,
-        'colname': colname,
-        'tname_quoted': format_table_name(compiler, tname, None),
+        "tname": tname,
+        "colname": colname,
+        "tname_quoted": format_table_name(compiler, tname, None),
     }
 
 
-@compiles(AddColumn, 'mssql')
+@compiles(AddColumn, "mssql")
 def visit_add_column(element, compiler, **kw):
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
-        mssql_add_column(compiler, element.column, **kw)
+        mssql_add_column(compiler, element.column, **kw),
     )
 
 
@@ -198,49 +210,48 @@ def mssql_add_column(compiler, column, **kw):
     return "ADD %s" % compiler.get_column_specification(column, **kw)
 
 
-@compiles(ColumnNullable, 'mssql')
+@compiles(ColumnNullable, "mssql")
 def visit_column_nullable(element, compiler, **kw):
     return "%s %s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
         format_type(compiler, element.existing_type),
-        "NULL" if element.nullable else "NOT NULL"
+        "NULL" if element.nullable else "NOT NULL",
     )
 
 
-@compiles(ColumnDefault, 'mssql')
+@compiles(ColumnDefault, "mssql")
 def visit_column_default(element, compiler, **kw):
     # TODO: there can also be a named constraint
     # with ADD CONSTRAINT here
     return "%s ADD DEFAULT %s FOR %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_server_default(compiler, element.default),
-        format_column_name(compiler, element.column_name)
+        format_column_name(compiler, element.column_name),
     )
 
 
-@compiles(ColumnName, 'mssql')
+@compiles(ColumnName, "mssql")
 def visit_rename_column(element, compiler, **kw):
     return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
         format_table_name(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
-        format_column_name(compiler, element.newname)
+        format_column_name(compiler, element.newname),
     )
 
 
-@compiles(ColumnType, 'mssql')
+@compiles(ColumnType, "mssql")
 def visit_column_type(element, compiler, **kw):
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        format_type(compiler, element.type_)
+        format_type(compiler, element.type_),
     )
 
 
-@compiles(RenameTable, 'mssql')
+@compiles(RenameTable, "mssql")
 def visit_rename_table(element, compiler, **kw):
     return "EXEC sp_rename '%s', %s" % (
         format_table_name(compiler, element.table_name, element.schema),
-        format_table_name(compiler, element.new_table_name, None)
+        format_table_name(compiler, element.new_table_name, None),
     )
-
index 1f4b345ac9f4d1c3540d4a1539f997b0b049a527..bc98005aeff63c66e1ba81eb267b66822298e9c3 100644 (file)
--- a/alembic/ddl/mysql.py
+++ b/alembic/ddl/mysql.py
@@ -5,9 +5,15 @@ from sqlalchemy import schema
 from ..util.compat import string_types
 from .. import util
 from .impl import DefaultImpl
-from .base import ColumnNullable, ColumnName, ColumnDefault, \
-    ColumnType, AlterColumn, format_column_name, \
-    format_server_default
+from .base import (
+    ColumnNullable,
+    ColumnName,
+    ColumnDefault,
+    ColumnType,
+    AlterColumn,
+    format_column_name,
+    format_server_default,
+)
 from .base import alter_table
 from ..autogenerate import compare
 from ..util.sqla_compat import _is_type_bound, sqla_100
@@ -15,64 +21,76 @@ import re
 
 
 class MySQLImpl(DefaultImpl):
-    __dialect__ = 'mysql'
+    __dialect__ = "mysql"
 
     transactional_ddl = False
 
-    def alter_column(self, table_name, column_name,
-                     nullable=None,
-                     server_default=False,
-                     name=None,
-                     type_=None,
-                     schema=None,
-                     existing_type=None,
-                     existing_server_default=None,
-                     existing_nullable=None,
-                     autoincrement=None,
-                     existing_autoincrement=None,
-                     **kw
-                     ):
+    def alter_column(
+        self,
+        table_name,
+        column_name,
+        nullable=None,
+        server_default=False,
+        name=None,
+        type_=None,
+        schema=None,
+        existing_type=None,
+        existing_server_default=None,
+        existing_nullable=None,
+        autoincrement=None,
+        existing_autoincrement=None,
+        **kw
+    ):
         if name is not None:
             self._exec(
                 MySQLChangeColumn(
-                    table_name, column_name,
+                    table_name,
+                    column_name,
                     schema=schema,
                     newname=name,
-                    nullable=nullable if nullable is not None else
-                    existing_nullable
+                    nullable=nullable
+                    if nullable is not None
+                    else existing_nullable
                     if existing_nullable is not None
                     else True,
                     type_=type_ if type_ is not None else existing_type,
-                    default=server_default if server_default is not False
+                    default=server_default
+                    if server_default is not False
                     else existing_server_default,
-                    autoincrement=autoincrement if autoincrement is not None
-                    else existing_autoincrement
+                    autoincrement=autoincrement
+                    if autoincrement is not None
+                    else existing_autoincrement,
                 )
             )
-        elif nullable is not None or \
-                type_ is not None or \
-                autoincrement is not None:
+        elif (
+            nullable is not None
+            or type_ is not None
+            or autoincrement is not None
+        ):
             self._exec(
                 MySQLModifyColumn(
-                    table_name, column_name,
+                    table_name,
+                    column_name,
                     schema=schema,
                     newname=name if name is not None else column_name,
-                    nullable=nullable if nullable is not None else
-                    existing_nullable
+                    nullable=nullable
+                    if nullable is not None
+                    else existing_nullable
                     if existing_nullable is not None
                     else True,
                     type_=type_ if type_ is not None else existing_type,
-                    default=server_default if server_default is not False
+                    default=server_default
+                    if server_default is not False
                     else existing_server_default,
-                    autoincrement=autoincrement if autoincrement is not None
-                    else existing_autoincrement
+                    autoincrement=autoincrement
+                    if autoincrement is not None
+                    else existing_autoincrement,
                 )
             )
         elif server_default is not False:
             self._exec(
                 MySQLAlterDefault(
-                    table_name, column_name, server_default,
-                    schema=schema,
+                    table_name, column_name, server_default, schema=schema
                 )
             )
 
@@ -82,41 +100,47 @@ class MySQLImpl(DefaultImpl):
 
         super(MySQLImpl, self).drop_constraint(const)
 
-    def compare_server_default(self, inspector_column,
-                               metadata_column,
-                               rendered_metadata_default,
-                               rendered_inspector_default):
+    def compare_server_default(
+        self,
+        inspector_column,
+        metadata_column,
+        rendered_metadata_default,
+        rendered_inspector_default,
+    ):
         # partially a workaround for SQLAlchemy issue #3023; if the
         # column were created without "NOT NULL", MySQL may have added
         # an implicit default of '0' which we need to skip
         # TODO: this is not really covered anymore ?
-        if metadata_column.type._type_affinity is sqltypes.Integer and \
-            inspector_column.primary_key and \
-                not inspector_column.autoincrement and \
-                not rendered_metadata_default and \
-                rendered_inspector_default == "'0'":
+        if (
+            metadata_column.type._type_affinity is sqltypes.Integer
+            and inspector_column.primary_key
+            and not inspector_column.autoincrement
+            and not rendered_metadata_default
+            and rendered_inspector_default == "'0'"
+        ):
             return False
         elif inspector_column.type._type_affinity is sqltypes.Integer:
             rendered_inspector_default = re.sub(
-                r"^'|'$", '', rendered_inspector_default)
+                r"^'|'$", "", rendered_inspector_default
+            )
             return rendered_inspector_default != rendered_metadata_default
         elif rendered_inspector_default and rendered_metadata_default:
             # adjust for "function()" vs. "FUNCTION"
-            return (
-                re.sub(
-                    r'(.*?)(?:\(\))?$', r'\1',
-                    rendered_inspector_default.lower()) !=
-                re.sub(
-                    r'(.*?)(?:\(\))?$', r'\1',
-                    rendered_metadata_default.lower())
+            return re.sub(
+                r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower()
+            ) != re.sub(
+                r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower()
             )
         else:
             return rendered_inspector_default != rendered_metadata_default
 
-    def correct_for_autogen_constraints(self, conn_unique_constraints,
-                                        conn_indexes,
-                                        metadata_unique_constraints,
-                                        metadata_indexes):
+    def correct_for_autogen_constraints(
+        self,
+        conn_unique_constraints,
+        conn_indexes,
+        metadata_unique_constraints,
+        metadata_indexes,
+    ):
 
         # TODO: if SQLA 1.0, make use of "duplicates_index"
         # metadata
@@ -153,31 +177,41 @@ class MySQLImpl(DefaultImpl):
                 conn_unique_constraints,
                 conn_indexes,
                 metadata_unique_constraints,
-                metadata_indexes
+                metadata_indexes,
             )
 
-    def _legacy_correct_for_dupe_uq_uix(self, conn_unique_constraints,
-                                        conn_indexes,
-                                        metadata_unique_constraints,
-                                        metadata_indexes):
+    def _legacy_correct_for_dupe_uq_uix(
+        self,
+        conn_unique_constraints,
+        conn_indexes,
+        metadata_unique_constraints,
+        metadata_indexes,
+    ):
 
         # then dedupe unique indexes vs. constraints, since MySQL
         # doesn't really have unique constraints as a separate construct.
         # but look in the metadata and try to maintain constructs
         # that already seem to be defined one way or the other
         # on that side.  See #276
-        metadata_uq_names = set([
-            cons.name for cons in metadata_unique_constraints
-            if cons.name is not None])
-
-        unnamed_metadata_uqs = set([
-            compare._uq_constraint_sig(cons).sig
-            for cons in metadata_unique_constraints
-            if cons.name is None
-        ])
-
-        metadata_ix_names = set([
-            cons.name for cons in metadata_indexes if cons.unique])
+        metadata_uq_names = set(
+            [
+                cons.name
+                for cons in metadata_unique_constraints
+                if cons.name is not None
+            ]
+        )
+
+        unnamed_metadata_uqs = set(
+            [
+                compare._uq_constraint_sig(cons).sig
+                for cons in metadata_unique_constraints
+                if cons.name is None
+            ]
+        )
+
+        metadata_ix_names = set(
+            [cons.name for cons in metadata_indexes if cons.unique]
+        )
         conn_uq_names = dict(
             (cons.name, cons) for cons in conn_unique_constraints
         )
@@ -187,8 +221,10 @@ class MySQLImpl(DefaultImpl):
 
         for overlap in set(conn_uq_names).intersection(conn_ix_names):
             if overlap not in metadata_uq_names:
-                if compare._uq_constraint_sig(conn_uq_names[overlap]).sig \
-                        not in unnamed_metadata_uqs:
+                if (
+                    compare._uq_constraint_sig(conn_uq_names[overlap]).sig
+                    not in unnamed_metadata_uqs
+                ):
 
                     conn_unique_constraints.discard(conn_uq_names[overlap])
             elif overlap not in metadata_ix_names:
@@ -208,18 +244,21 @@ class MySQLImpl(DefaultImpl):
             # MySQL considers RESTRICT to be the default and doesn't
             # report on it.  if the model has explicit RESTRICT and
             # the conn FK has None, set it to RESTRICT
-            if mdfk.ondelete is not None and \
-                    mdfk.ondelete.lower() == 'restrict' and \
-                    cnfk.ondelete is None:
-                cnfk.ondelete = 'RESTRICT'
-            if mdfk.onupdate is not None and \
-                    mdfk.onupdate.lower() == 'restrict' and \
-                    cnfk.onupdate is None:
-                cnfk.onupdate = 'RESTRICT'
+            if (
+                mdfk.ondelete is not None
+                and mdfk.ondelete.lower() == "restrict"
+                and cnfk.ondelete is None
+            ):
+                cnfk.ondelete = "RESTRICT"
+            if (
+                mdfk.onupdate is not None
+                and mdfk.onupdate.lower() == "restrict"
+                and cnfk.onupdate is None
+            ):
+                cnfk.onupdate = "RESTRICT"
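+            # illustration: an FK declared with ondelete="RESTRICT" reflects
+            # back from MySQL with ondelete=None; normalizing it here keeps
+            # autogenerate from proposing a needless drop/re-create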
 
 
 class MySQLAlterDefault(AlterColumn):
-
     def __init__(self, name, column_name, default, schema=None):
         super(AlterColumn, self).__init__(name, schema=schema)
         self.column_name = column_name
@@ -227,13 +266,17 @@ class MySQLAlterDefault(AlterColumn):
 
 
 class MySQLChangeColumn(AlterColumn):
-
-    def __init__(self, name, column_name, schema=None,
-                 newname=None,
-                 type_=None,
-                 nullable=None,
-                 default=False,
-                 autoincrement=None):
+    def __init__(
+        self,
+        name,
+        column_name,
+        schema=None,
+        newname=None,
+        type_=None,
+        nullable=None,
+        default=False,
+        autoincrement=None,
+    ):
         super(AlterColumn, self).__init__(name, schema=schema)
         self.column_name = column_name
         self.nullable = nullable
@@ -253,10 +296,10 @@ class MySQLModifyColumn(MySQLChangeColumn):
     pass
 
 
-@compiles(ColumnNullable, 'mysql')
-@compiles(ColumnName, 'mysql')
-@compiles(ColumnDefault, 'mysql')
-@compiles(ColumnType, 'mysql')
+@compiles(ColumnNullable, "mysql")
+@compiles(ColumnName, "mysql")
+@compiles(ColumnDefault, "mysql")
+@compiles(ColumnType, "mysql")
 def _mysql_doesnt_support_individual(element, compiler, **kw):
     raise NotImplementedError(
         "Individual alter column constructs not supported by MySQL"
@@ -270,7 +313,7 @@ def _mysql_alter_default(element, compiler, **kw):
         format_column_name(compiler, element.column_name),
         "SET DEFAULT %s" % format_server_default(compiler, element.default)
         if element.default is not None
-        else "DROP DEFAULT"
+        else "DROP DEFAULT",
     )
 
 
@@ -284,7 +327,7 @@ def _mysql_modify_column(element, compiler, **kw):
             nullable=element.nullable,
             server_default=element.default,
             type_=element.type_,
-            autoincrement=element.autoincrement
+            autoincrement=element.autoincrement,
         ),
     )
 
@@ -300,7 +343,7 @@ def _mysql_change_column(element, compiler, **kw):
             nullable=element.nullable,
             server_default=element.default,
             type_=element.type_,
-            autoincrement=element.autoincrement
+            autoincrement=element.autoincrement,
         ),
     )
 
@@ -312,11 +355,10 @@ def _render_value(compiler, expr):
         return compiler.sql_compiler.process(expr)
 
 
-def _mysql_colspec(compiler, nullable, server_default, type_,
-                   autoincrement):
+def _mysql_colspec(compiler, nullable, server_default, type_, autoincrement):
     spec = "%s %s" % (
         compiler.dialect.type_compiler.process(type_),
-        "NULL" if nullable else "NOT NULL"
+        "NULL" if nullable else "NOT NULL",
     )
     if autoincrement:
         spec += " AUTO_INCREMENT"
@@ -332,21 +374,25 @@ def _mysql_drop_constraint(element, compiler, **kw):
     raise errors for invalid constraint type."""
 
     constraint = element.element
-    if isinstance(constraint, (schema.ForeignKeyConstraint,
-                               schema.PrimaryKeyConstraint,
-                               schema.UniqueConstraint)
-                  ):
+    if isinstance(
+        constraint,
+        (
+            schema.ForeignKeyConstraint,
+            schema.PrimaryKeyConstraint,
+            schema.UniqueConstraint,
+        ),
+    ):
         return compiler.visit_drop_constraint(element, **kw)
     elif isinstance(constraint, schema.CheckConstraint):
         # note that SQLAlchemy as of 1.2 does not yet support
         # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
         # here.
-        return "ALTER TABLE %s DROP CONSTRAINT %s" % \
-            (compiler.preparer.format_table(constraint.table),
-             compiler.preparer.format_constraint(constraint))
+        return "ALTER TABLE %s DROP CONSTRAINT %s" % (
+            compiler.preparer.format_table(constraint.table),
+            compiler.preparer.format_constraint(constraint),
+        )
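+        # e.g. renders "ALTER TABLE t DROP CONSTRAINT ck_positive"
+        # (hypothetical table and constraint names)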
     else:
         raise NotImplementedError(
             "No generic 'DROP CONSTRAINT' in MySQL - "
-            "please specify constraint type")
-
-
+            "please specify constraint type"
+        )
index e528744ccce1bb21e4c5f59131c89739a0806b99..3376155388fd3ed927a466ab5d3f796ea18ec943 100644 (file)
--- a/alembic/ddl/oracle.py
+++ b/alembic/ddl/oracle.py
@@ -1,13 +1,21 @@
 from sqlalchemy.ext.compiler import compiles
 
 from .impl import DefaultImpl
-from .base import alter_table, AddColumn, ColumnName, \
-    format_column_name, ColumnNullable, \
-    format_server_default, ColumnDefault, format_type, ColumnType
+from .base import (
+    alter_table,
+    AddColumn,
+    ColumnName,
+    format_column_name,
+    ColumnNullable,
+    format_server_default,
+    ColumnDefault,
+    format_type,
+    ColumnType,
+)
 
 
 class OracleImpl(DefaultImpl):
-    __dialect__ = 'oracle'
+    __dialect__ = "oracle"
     transactional_ddl = False
     batch_separator = "/"
     command_terminator = ""
@@ -15,8 +23,8 @@ class OracleImpl(DefaultImpl):
     def __init__(self, *arg, **kw):
         super(OracleImpl, self).__init__(*arg, **kw)
         self.batch_separator = self.context_opts.get(
-            "oracle_batch_separator",
-            self.batch_separator)
+            "oracle_batch_separator", self.batch_separator
+        )
 
     def _exec(self, construct, *args, **kw):
         result = super(OracleImpl, self)._exec(construct, *args, **kw)
@@ -31,7 +39,7 @@ class OracleImpl(DefaultImpl):
         self._exec("COMMIT")
 
 
-@compiles(AddColumn, 'oracle')
+@compiles(AddColumn, "oracle")
 def visit_add_column(element, compiler, **kw):
     return "%s %s" % (
         alter_table(compiler, element.table_name, element.schema),
@@ -39,47 +47,46 @@ def visit_add_column(element, compiler, **kw):
     )
 
 
-@compiles(ColumnNullable, 'oracle')
+@compiles(ColumnNullable, "oracle")
 def visit_column_nullable(element, compiler, **kw):
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        "NULL" if element.nullable else "NOT NULL"
+        "NULL" if element.nullable else "NOT NULL",
     )
 
 
-@compiles(ColumnType, 'oracle')
+@compiles(ColumnType, "oracle")
 def visit_column_type(element, compiler, **kw):
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        "%s" % format_type(compiler, element.type_)
+        "%s" % format_type(compiler, element.type_),
     )
 
 
-@compiles(ColumnName, 'oracle')
+@compiles(ColumnName, "oracle")
 def visit_column_name(element, compiler, **kw):
     return "%s RENAME COLUMN %s TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
-        format_column_name(compiler, element.newname)
+        format_column_name(compiler, element.newname),
     )
 
 
-@compiles(ColumnDefault, 'oracle')
+@compiles(ColumnDefault, "oracle")
 def visit_column_default(element, compiler, **kw):
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        "DEFAULT %s" %
-        format_server_default(compiler, element.default)
+        "DEFAULT %s" % format_server_default(compiler, element.default)
         if element.default is not None
-        else "DEFAULT NULL"
+        else "DEFAULT NULL",
     )
 
 
 def alter_column(compiler, name):
-    return 'MODIFY %s' % format_column_name(compiler, name)
+    return "MODIFY %s" % format_column_name(compiler, name)
 
 
 def add_column(compiler, column, **kw):
index d3998336fd550a211223fa0abceba92c193bc234..f133a056c5660c430e6b1e2a0cf2d85fef08431b 100644 (file)
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -2,8 +2,15 @@ import re
 
 from ..util import compat
 from .. import util
-from .base import compiles, alter_column, alter_table, format_table_name, \
-    format_type, AlterColumn, RenameTable
+from .base import (
+    compiles,
+    alter_column,
+    alter_table,
+    format_table_name,
+    format_type,
+    AlterColumn,
+    RenameTable,
+)
 from .impl import DefaultImpl
 from sqlalchemy.dialects.postgresql import INTEGER, BIGINT
 from ..autogenerate import render
@@ -30,7 +37,7 @@ log = logging.getLogger(__name__)
 
 
 class PostgresqlImpl(DefaultImpl):
-    __dialect__ = 'postgresql'
+    __dialect__ = "postgresql"
     transactional_ddl = True
 
     def prep_table_for_batch(self, table):
@@ -38,13 +45,18 @@ class PostgresqlImpl(DefaultImpl):
             if constraint.name is not None:
                 self.drop_constraint(constraint)
 
-    def compare_server_default(self, inspector_column,
-                               metadata_column,
-                               rendered_metadata_default,
-                               rendered_inspector_default):
+    def compare_server_default(
+        self,
+        inspector_column,
+        metadata_column,
+        rendered_metadata_default,
+        rendered_inspector_default,
+    ):
         # don't do defaults for SERIAL columns
-        if metadata_column.primary_key and \
-                metadata_column is metadata_column.table._autoincrement_column:
+        if (
+            metadata_column.primary_key
+            and metadata_column is metadata_column.table._autoincrement_column
+        ):
             return False
 
         conn_col_default = rendered_inspector_default
@@ -56,53 +68,65 @@ class PostgresqlImpl(DefaultImpl):
         if None in (conn_col_default, rendered_metadata_default):
             return not defaults_equal
 
-        if metadata_column.server_default is not None and \
-            isinstance(metadata_column.server_default.arg,
-                       compat.string_types) and \
-                not re.match(r"^'.+'$", rendered_metadata_default) and \
-                not isinstance(inspector_column.type, Numeric):
-                # don't single quote if the column type is float/numeric,
-                # otherwise a comparison such as SELECT 5 = '5.0' will fail
+        if (
+            metadata_column.server_default is not None
+            and isinstance(
+                metadata_column.server_default.arg, compat.string_types
+            )
+            and not re.match(r"^'.+'$", rendered_metadata_default)
+            and not isinstance(inspector_column.type, Numeric)
+        ):
+            # don't single quote if the column type is float/numeric,
+            # otherwise a comparison such as SELECT 5 = '5.0' will fail
             rendered_metadata_default = re.sub(
-                r"^u?'?|'?$", "'", rendered_metadata_default)
+                r"^u?'?|'?$", "'", rendered_metadata_default
+            )
 
         return not self.connection.scalar(
-            "SELECT %s = %s" % (
-                conn_col_default,
-                rendered_metadata_default
-            )
+            "SELECT %s = %s" % (conn_col_default, rendered_metadata_default)
         )
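+        # illustration: equality is delegated to the database itself; a
+        # false or NULL result is reported as a changed default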
 
-    def alter_column(self, table_name, column_name,
-                     nullable=None,
-                     server_default=False,
-                     name=None,
-                     type_=None,
-                     schema=None,
-                     autoincrement=None,
-                     existing_type=None,
-                     existing_server_default=None,
-                     existing_nullable=None,
-                     existing_autoincrement=None,
-                     **kw
-                     ):
-
-        using = kw.pop('postgresql_using', None)
+    def alter_column(
+        self,
+        table_name,
+        column_name,
+        nullable=None,
+        server_default=False,
+        name=None,
+        type_=None,
+        schema=None,
+        autoincrement=None,
+        existing_type=None,
+        existing_server_default=None,
+        existing_nullable=None,
+        existing_autoincrement=None,
+        **kw
+    ):
+
+        using = kw.pop("postgresql_using", None)
 
         if using is not None and type_ is None:
             raise util.CommandError(
-                "postgresql_using must be used with the type_ parameter")
+                "postgresql_using must be used with the type_ parameter"
+            )
 
         if type_ is not None:
-            self._exec(PostgresqlColumnType(
-                table_name, column_name, type_, schema=schema,
-                using=using, existing_type=existing_type,
-                existing_server_default=existing_server_default,
-                existing_nullable=existing_nullable,
-            ))
+            self._exec(
+                PostgresqlColumnType(
+                    table_name,
+                    column_name,
+                    type_,
+                    schema=schema,
+                    using=using,
+                    existing_type=existing_type,
+                    existing_server_default=existing_server_default,
+                    existing_nullable=existing_nullable,
+                )
+            )
 
         super(PostgresqlImpl, self).alter_column(
-            table_name, column_name,
+            table_name,
+            column_name,
             nullable=nullable,
             server_default=server_default,
             name=name,
@@ -112,57 +136,70 @@ class PostgresqlImpl(DefaultImpl):
             existing_server_default=existing_server_default,
             existing_nullable=existing_nullable,
             existing_autoincrement=existing_autoincrement,
-            **kw)
+            **kw
+        )
 
     def autogen_column_reflect(self, inspector, table, column_info):
-        if column_info.get('default') and \
-                isinstance(column_info['type'], (INTEGER, BIGINT)):
+        if column_info.get("default") and isinstance(
+            column_info["type"], (INTEGER, BIGINT)
+        ):
             seq_match = re.match(
-                r"nextval\('(.+?)'::regclass\)",
-                column_info['default'])
+                r"nextval\('(.+?)'::regclass\)", column_info["default"]
+            )
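+            # e.g. a SERIAL column reflects its server default as
+            # "nextval('mytable_id_seq'::regclass)" (hypothetical name);
+            # group(1) captures the owning sequence's name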
             if seq_match:
-                info = inspector.bind.execute(text(
-                    "select c.relname, a.attname "
-                    "from pg_class as c join pg_depend d on d.objid=c.oid and "
-                    "d.classid='pg_class'::regclass and "
-                    "d.refclassid='pg_class'::regclass "
-                    "join pg_class t on t.oid=d.refobjid "
-                    "join pg_attribute a on a.attrelid=t.oid and "
-                    "a.attnum=d.refobjsubid "
-                    "where c.relkind='S' and c.relname=:seqname"
-                ), seqname=seq_match.group(1)).first()
+                info = inspector.bind.execute(
+                    text(
+                        "select c.relname, a.attname "
+                        "from pg_class as c join pg_depend d on d.objid=c.oid and "
+                        "d.classid='pg_class'::regclass and "
+                        "d.refclassid='pg_class'::regclass "
+                        "join pg_class t on t.oid=d.refobjid "
+                        "join pg_attribute a on a.attrelid=t.oid and "
+                        "a.attnum=d.refobjsubid "
+                        "where c.relkind='S' and c.relname=:seqname"
+                    ),
+                    seqname=seq_match.group(1),
+                ).first()
                 if info:
                     seqname, colname = info
-                    if colname == column_info['name']:
+                    if colname == column_info["name"]:
                         log.info(
                             "Detected sequence named '%s' as "
                             "owned by integer column '%s(%s)', "
                             "assuming SERIAL and omitting",
-                            seqname, table.name, colname)
+                            seqname,
+                            table.name,
+                            colname,
+                        )
                         # sequence, and the owner is this column,
                         # it's a SERIAL - whack it!
-                        del column_info['default']
+                        del column_info["default"]
 
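For reference, a minimal sketch of the default-string match performed above, using a hypothetical reflected default value:

    import re

    default = "nextval('account_id_seq'::regclass)"
    seq_match = re.match(r"nextval\('(.+?)'::regclass\)", default)
    # group(1) is the sequence name that is then checked for ownership
    assert seq_match.group(1) == "account_id_seq"
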
-    def correct_for_autogen_constraints(self, conn_unique_constraints,
-                                        conn_indexes,
-                                        metadata_unique_constraints,
-                                        metadata_indexes):
+    def correct_for_autogen_constraints(
+        self,
+        conn_unique_constraints,
+        conn_indexes,
+        metadata_unique_constraints,
+        metadata_indexes,
+    ):
 
         conn_uniques_by_name = dict(
-            (c.name, c) for c in conn_unique_constraints)
-        conn_indexes_by_name = dict(
-            (c.name, c) for c in conn_indexes)
+            (c.name, c) for c in conn_unique_constraints
+        )
+        conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
 
         if not util.sqla_100:
             doubled_constraints = set(
                 conn_indexes_by_name[name]
                 for name in set(conn_uniques_by_name).intersection(
-                    conn_indexes_by_name)
+                    conn_indexes_by_name
+                )
             )
         else:
             doubled_constraints = set(
-                index for index in
-                conn_indexes if index.info.get('duplicates_constraint')
+                index
+                for index in conn_indexes
+                if index.info.get("duplicates_constraint")
             )
 
         for ix in doubled_constraints:
@@ -187,37 +224,36 @@ class PostgresqlImpl(DefaultImpl):
         if not mod.startswith("sqlalchemy.dialects.postgresql"):
             return False
 
-        if hasattr(self, '_render_%s_type' % type_.__visit_name__):
-            meth = getattr(self, '_render_%s_type' % type_.__visit_name__)
+        if hasattr(self, "_render_%s_type" % type_.__visit_name__):
+            meth = getattr(self, "_render_%s_type" % type_.__visit_name__)
             return meth(type_, autogen_context)
 
         return False
 
     def _render_HSTORE_type(self, type_, autogen_context):
         return render._render_type_w_subtype(
-            type_, autogen_context, 'text_type', r'(.+?\(.*text_type=)'
+            type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
         )
 
     def _render_ARRAY_type(self, type_, autogen_context):
         return render._render_type_w_subtype(
-            type_, autogen_context, 'item_type', r'(.+?\()'
+            type_, autogen_context, "item_type", r"(.+?\()"
         )
 
     def _render_JSON_type(self, type_, autogen_context):
         return render._render_type_w_subtype(
-            type_, autogen_context, 'astext_type', r'(.+?\(.*astext_type=)'
+            type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
         )
 
     def _render_JSONB_type(self, type_, autogen_context):
         return render._render_type_w_subtype(
-            type_, autogen_context, 'astext_type', r'(.+?\(.*astext_type=)'
+            type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
         )
 
 
 class PostgresqlColumnType(AlterColumn):
-
     def __init__(self, name, column_name, type_, **kw):
-        using = kw.pop('using', None)
+        using = kw.pop("using", None)
         super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
         self.type_ = sqltypes.to_instance(type_)
         self.using = using
@@ -227,7 +263,7 @@ class PostgresqlColumnType(AlterColumn):
 def visit_rename_table(element, compiler, **kw):
     return "%s RENAME TO %s" % (
         alter_table(compiler, element.table_name, element.schema),
-        format_table_name(compiler, element.new_table_name, None)
+        format_table_name(compiler, element.new_table_name, None),
     )
 
 
@@ -237,13 +273,14 @@ def visit_column_type(element, compiler, **kw):
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
         "TYPE %s" % format_type(compiler, element.type_),
-        "USING %s" % element.using if element.using else ""
+        "USING %s" % element.using if element.using else "",
     )
 
 
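In migration code this rendering path is reached through alter_column() with the type_ and postgresql_using parameters; a hedged usage sketch with hypothetical table and column names:

    from alembic import op
    import sqlalchemy as sa

    def upgrade():
        op.alter_column(
            "account",
            "amount",
            type_=sa.Integer(),
            postgresql_using="amount::integer",
        )
        # renders roughly:
        # ALTER TABLE account ALTER COLUMN amount TYPE INTEGER USING amount::integer
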
 @Operations.register_operation("create_exclude_constraint")
 @BatchOperations.register_operation(
-    "create_exclude_constraint", "batch_create_exclude_constraint")
+    "create_exclude_constraint", "batch_create_exclude_constraint"
+)
 @ops.AddConstraintOp.register_add_constraint("exclude_constraint")
 class CreateExcludeConstraintOp(ops.AddConstraintOp):
     """Represent a create exclude constraint operation."""
@@ -251,9 +288,15 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
     constraint_type = "exclude"
 
     def __init__(
-            self, constraint_name, table_name,
-            elements, where=None, schema=None,
-            _orig_constraint=None, **kw):
+        self,
+        constraint_name,
+        table_name,
+        elements,
+        where=None,
+        schema=None,
+        _orig_constraint=None,
+        **kw
+    ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.elements = elements
@@ -275,13 +318,14 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
             _orig_constraint=constraint,
             deferrable=constraint.deferrable,
             initially=constraint.initially,
-            using=constraint.using
+            using=constraint.using,
         )
 
     def to_constraint(self, migration_context=None):
         if not util.sqla_100:
             raise NotImplementedError(
-                "ExcludeConstraint not supported until SQLAlchemy 1.0")
+                "ExcludeConstraint not supported until SQLAlchemy 1.0"
+            )
         if self._orig_constraint is not None:
             return self._orig_constraint
         schema_obj = schemaobj.SchemaObjects(migration_context)
@@ -299,8 +343,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
 
     @classmethod
     def create_exclude_constraint(
-            cls, operations,
-            constraint_name, table_name, *elements, **kw):
+        cls, operations, constraint_name, table_name, *elements, **kw
+    ):
         """Issue an alter to create an EXCLUDE constraint using the
         current migration context.
 
@@ -344,7 +388,8 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
 
     @classmethod
     def batch_create_exclude_constraint(
-            cls, operations, constraint_name, *elements, **kw):
+        cls, operations, constraint_name, *elements, **kw
+    ):
         """Issue a "create exclude constraint" instruction using the
         current batch migration context.
 
@@ -358,24 +403,23 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
             :meth:`.Operations.create_exclude_constraint`
 
         """
-        kw['schema'] = operations.impl.schema
+        kw["schema"] = operations.impl.schema
         op = cls(constraint_name, operations.impl.table_name, elements, **kw)
         return operations.invoke(op)
 
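A hedged usage sketch for the operation defined above; the constraint name, table, and element expressions are hypothetical:

    from alembic import op

    def upgrade():
        op.create_exclude_constraint(
            "user_excl",
            "user",
            ("period", "&&"),
            where="period != 'empty'",
        )
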
 
 @render.renderers.dispatch_for(CreateExcludeConstraintOp)
 def _add_exclude_constraint(autogen_context, op):
-    return _exclude_constraint(
-        op.to_constraint(),
-        autogen_context,
-        alter=True
-    )
+    return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
+
 
 if util.sqla_100:
+
     @render._constraint_renderers.dispatch_for(ExcludeConstraint)
     def _render_inline_exclude_constraint(constraint, autogen_context):
         rendered = render._user_defined_render(
-            "exclude", constraint, autogen_context)
+            "exclude", constraint, autogen_context
+        )
         if rendered is not False:
             return rendered
 
@@ -405,48 +449,54 @@ def _exclude_constraint(constraint, autogen_context, alter):
         opts.append(("schema", render._ident(constraint.table.schema)))
     if not alter and constraint.name:
         opts.append(
-            ("name",
-             render._render_gen_name(autogen_context, constraint.name)))
+            ("name", render._render_gen_name(autogen_context, constraint.name))
+        )
 
     if alter:
         args = [
-            repr(render._render_gen_name(
-                autogen_context, constraint.name))]
+            repr(render._render_gen_name(autogen_context, constraint.name))
+        ]
         if not has_batch:
             args += [repr(render._ident(constraint.table.name))]
-        args.extend([
-            "(%s, %r)" % (
-                _render_potential_column(sqltext, autogen_context),
-                opstring
-            )
-            for sqltext, name, opstring in constraint._render_exprs
-        ])
+        args.extend(
+            [
+                "(%s, %r)"
+                % (
+                    _render_potential_column(sqltext, autogen_context),
+                    opstring,
+                )
+                for sqltext, name, opstring in constraint._render_exprs
+            ]
+        )
         if constraint.where is not None:
             args.append(
-                "where=%s" % render._render_potential_expr(
-                    constraint.where, autogen_context)
+                "where=%s"
+                % render._render_potential_expr(
+                    constraint.where, autogen_context
+                )
             )
         args.extend(["%s=%r" % (k, v) for k, v in opts])
         return "%(prefix)screate_exclude_constraint(%(args)s)" % {
-            'prefix': render._alembic_autogenerate_prefix(autogen_context),
-            'args': ", ".join(args)
+            "prefix": render._alembic_autogenerate_prefix(autogen_context),
+            "args": ", ".join(args),
         }
     else:
         args = [
-            "(%s, %r)" % (
-                _render_potential_column(sqltext, autogen_context),
-                opstring
-            ) for sqltext, name, opstring in constraint._render_exprs
+            "(%s, %r)"
+            % (_render_potential_column(sqltext, autogen_context), opstring)
+            for sqltext, name, opstring in constraint._render_exprs
         ]
         if constraint.where is not None:
             args.append(
-                "where=%s" % render._render_potential_expr(
-                    constraint.where, autogen_context)
+                "where=%s"
+                % render._render_potential_expr(
+                    constraint.where, autogen_context
+                )
             )
         args.extend(["%s=%r" % (k, v) for k, v in opts])
         return "%(prefix)sExcludeConstraint(%(args)s)" % {
             "prefix": _postgresql_autogenerate_prefix(autogen_context),
-            "args": ", ".join(args)
+            "args": ", ".join(args),
         }
 
 
@@ -456,8 +506,10 @@ def _render_potential_column(value, autogen_context):
 
         return template % {
             "prefix": render._sqlalchemy_autogenerate_prefix(autogen_context),
-            "name": value.name
+            "name": value.name,
         }
 
     else:
-        return render._render_potential_expr(value, autogen_context, wrap_in_text=False)
+        return render._render_potential_expr(
+            value, autogen_context, wrap_in_text=False
+        )
index 5d231b5f744f0761c7340ad48dc418115ee77496..f7699e63f463c425663d0aa8f92bafb38aeae36c 100644 (file)
@@ -4,7 +4,7 @@ import re
 
 
 class SQLiteImpl(DefaultImpl):
-    __dialect__ = 'sqlite'
+    __dialect__ = "sqlite"
 
     transactional_ddl = False
     """SQLite supports transactional DDL, but pysqlite does not:
@@ -21,7 +21,7 @@ class SQLiteImpl(DefaultImpl):
 
         """
         for op in batch_op.batch:
-            if op[0] not in ('add_column', 'create_index', 'drop_index'):
+            if op[0] not in ("add_column", "create_index", "drop_index"):
                 return True
         else:
             return False
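A consequence of this rule, sketched with a hypothetical table: any batch directive other than add_column, create_index, or drop_index pushes SQLite onto the move-and-copy recreate path:

    from alembic import op

    def upgrade():
        with op.batch_alter_table("account") as batch_op:
            # drop_column is not in the pass-through list above, so the
            # whole batch is applied by recreating the table
            batch_op.drop_column("old_field")
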
@@ -31,34 +31,46 @@ class SQLiteImpl(DefaultImpl):
         # auto-gen constraint and an explicit one
         if const._create_rule is None:
             raise NotImplementedError(
-                "No support for ALTER of constraints in SQLite dialect")
+                "No support for ALTER of constraints in SQLite dialect"
+            )
         elif const._create_rule(self):
-            util.warn("Skipping unsupported ALTER for "
-                      "creation of implicit constraint")
+            util.warn(
+                "Skipping unsupported ALTER for "
+                "creation of implicit constraint"
+            )
 
     def drop_constraint(self, const):
         if const._create_rule is None:
             raise NotImplementedError(
-                "No support for ALTER of constraints in SQLite dialect")
+                "No support for ALTER of constraints in SQLite dialect"
+            )
 
-    def compare_server_default(self, inspector_column,
-                               metadata_column,
-                               rendered_metadata_default,
-                               rendered_inspector_default):
+    def compare_server_default(
+        self,
+        inspector_column,
+        metadata_column,
+        rendered_metadata_default,
+        rendered_inspector_default,
+    ):
 
         if rendered_metadata_default is not None:
             rendered_metadata_default = re.sub(
-                r"^\"'|\"'$", "", rendered_metadata_default)
+                r"^\"'|\"'$", "", rendered_metadata_default
+            )
         if rendered_inspector_default is not None:
             rendered_inspector_default = re.sub(
-                r"^\"'|\"'$", "", rendered_inspector_default)
+                r"^\"'|\"'$", "", rendered_inspector_default
+            )
 
         return rendered_inspector_default != rendered_metadata_default
 
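The substitutions above strip a doubled-quoting artifact from both rendered defaults before comparison; a contrived value that mirrors the pattern literally:

    import re

    rendered = "\"'hello\"'"
    stripped = re.sub(r"^\"'|\"'$", "", rendered)
    assert stripped == "hello"
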
     def correct_for_autogen_constraints(
-        self, conn_unique_constraints, conn_indexes,
+        self,
+        conn_unique_constraints,
+        conn_indexes,
         metadata_unique_constraints,
-            metadata_indexes):
+        metadata_indexes,
+    ):
 
         if util.sqla_100:
             return
@@ -70,10 +82,7 @@ class SQLiteImpl(DefaultImpl):
         def uq_sig(uq):
             return tuple(sorted(uq.columns.keys()))
 
-        conn_unique_sigs = set(
-            uq_sig(uq)
-            for uq in conn_unique_constraints
-        )
+        conn_unique_sigs = set(uq_sig(uq) for uq in conn_unique_constraints)
 
         for idx in list(metadata_unique_constraints):
             # SQLite backend can't report on unnamed UNIQUE constraints,
index 1f367a10a9ff1a7a1b70667d23d6fd91aa7bc391..f3f5fac0cf5c1e56d44f42051b6d829f7026c86d 100644 (file)
@@ -3,4 +3,3 @@ from .operations.base import Operations
 # create proxy functions for
 # each method on the Operations class.
 Operations.create_module_class_proxy(globals(), locals())
-
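The proxy call above is what exposes Operations methods as module-level functions on alembic.op; a hedged sketch with a hypothetical table and column:

    from alembic import op
    import sqlalchemy as sa

    def upgrade():
        # add_column is defined on Operations; the module proxy exposes it
        # as a function on alembic.op bound to the current context
        op.add_column("account", sa.Column("ts", sa.DateTime(), nullable=True))
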
index 1f6ee5da19bd583222d44a043a1b71dc94bcc983..e1ff01c3dd16fa9599bc86ef86daa4c5a2d120a7 100644 (file)
@@ -3,4 +3,4 @@ from .ops import MigrateOperation
 from . import toimpl
 
 
-__all__ = ['Operations', 'BatchOperations', 'MigrateOperation']
\ No newline at end of file
+__all__ = ["Operations", "BatchOperations", "MigrateOperation"]
index 1ae95241063441754e97994f4293b5a8895e5f87..2c3408ab546c82d62d135974d665d4deb2eb17fa 100644 (file)
@@ -9,7 +9,7 @@ from ..util.compat import inspect_formatargspec
 from ..util.compat import inspect_getargspec
 import textwrap
 
-__all__ = ('Operations', 'BatchOperations')
+__all__ = ("Operations", "BatchOperations")
 
 try:
     from sqlalchemy.sql.naming import conv
@@ -84,6 +84,7 @@ class Operations(util.ModuleClsProxy):
 
 
         """
+
         def register(op_cls):
             if sourcename is None:
                 fn = getattr(op_cls, name)
@@ -95,45 +96,53 @@ class Operations(util.ModuleClsProxy):
             spec = inspect_getargspec(fn)
 
             name_args = spec[0]
-            assert name_args[0:2] == ['cls', 'operations']
+            assert name_args[0:2] == ["cls", "operations"]
 
-            name_args[0:2] = ['self']
+            name_args[0:2] = ["self"]
 
             args = inspect_formatargspec(*spec)
             num_defaults = len(spec[3]) if spec[3] else 0
             if num_defaults:
-                defaulted_vals = name_args[0 - num_defaults:]
+                defaulted_vals = name_args[0 - num_defaults :]
             else:
                 defaulted_vals = ()
 
             apply_kw = inspect_formatargspec(
-                name_args, spec[1], spec[2],
+                name_args,
+                spec[1],
+                spec[2],
                 defaulted_vals,
-                formatvalue=lambda x: '=' + x)
+                formatvalue=lambda x: "=" + x,
+            )
 
-            func_text = textwrap.dedent("""\
+            func_text = textwrap.dedent(
+                """\
             def %(name)s%(args)s:
                 %(doc)r
                 return op_cls.%(source_name)s%(apply_kw)s
-            """ % {
-                'name': name,
-                'source_name': source_name,
-                'args': args,
-                'apply_kw': apply_kw,
-                'doc': fn.__doc__,
-                'meth': fn.__name__
-            })
-            globals_ = {'op_cls': op_cls}
+            """
+                % {
+                    "name": name,
+                    "source_name": source_name,
+                    "args": args,
+                    "apply_kw": apply_kw,
+                    "doc": fn.__doc__,
+                    "meth": fn.__name__,
+                }
+            )
+            globals_ = {"op_cls": op_cls}
             lcl = {}
             exec_(func_text, globals_, lcl)
             setattr(cls, name, lcl[name])
-            fn.__func__.__doc__ = "This method is proxied on "\
-                "the :class:`.%s` class, via the :meth:`.%s.%s` method." % (
-                    cls.__name__, cls.__name__, name
-                )
-            if hasattr(fn, '_legacy_translations'):
+            fn.__func__.__doc__ = (
+                "This method is proxied on "
+                "the :class:`.%s` class, via the :meth:`.%s.%s` method."
+                % (cls.__name__, cls.__name__, name)
+            )
+            if hasattr(fn, "_legacy_translations"):
                 lcl[name]._legacy_translations = fn._legacy_translations
             return op_cls
+
         return register
 
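register() renders a wrapper from a string template and exec()s it so the generated method keeps the proxied signature; a self-contained sketch of the same codegen technique with illustrative names (plain exec rather than the compat exec_ shim):

    func_text = (
        "def greet(name, punct='!'):\n"
        "    return target(name, punct)\n"
    )
    globals_ = {"target": lambda name, punct: "hello " + name + punct}
    lcl = {}
    exec(func_text, globals_, lcl)
    greet = lcl["greet"]
    assert greet("world") == "hello world!"
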
     @classmethod
@@ -151,6 +160,7 @@ class Operations(util.ModuleClsProxy):
         def decorate(fn):
             cls._to_impl.dispatch_for(op_cls)(fn)
             return fn
+
         return decorate
 
     @classmethod
@@ -163,10 +173,17 @@ class Operations(util.ModuleClsProxy):
 
     @contextmanager
     def batch_alter_table(
-            self, table_name, schema=None, recreate="auto", copy_from=None,
-            table_args=(), table_kwargs=util.immutabledict(),
-            reflect_args=(), reflect_kwargs=util.immutabledict(),
-            naming_convention=None):
+        self,
+        table_name,
+        schema=None,
+        recreate="auto",
+        copy_from=None,
+        table_args=(),
+        table_kwargs=util.immutabledict(),
+        reflect_args=(),
+        reflect_kwargs=util.immutabledict(),
+        naming_convention=None,
+    ):
         """Invoke a series of per-table migrations in batch.
 
         Batch mode allows a series of operations specific to a table
@@ -292,9 +309,17 @@ class Operations(util.ModuleClsProxy):
 
         """
         impl = batch.BatchOperationsImpl(
-            self, table_name, schema, recreate,
-            copy_from, table_args, table_kwargs, reflect_args,
-            reflect_kwargs, naming_convention)
+            self,
+            table_name,
+            schema,
+            recreate,
+            copy_from,
+            table_args,
+            table_kwargs,
+            reflect_args,
+            reflect_kwargs,
+            naming_convention,
+        )
         batch_op = BatchOperations(self.migration_context, impl=impl)
         yield batch_op
         impl.flush()
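A hedged usage sketch of the context manager; the flush on exit is what applies the queued batch operations (table and columns are hypothetical):

    from alembic import op
    import sqlalchemy as sa

    def upgrade():
        with op.batch_alter_table("account") as batch_op:
            batch_op.add_column(sa.Column("last_name", sa.String(50)))
            batch_op.drop_column("middle_name")
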
@@ -315,7 +340,8 @@ class Operations(util.ModuleClsProxy):
 
         """
         fn = self._to_impl.dispatch(
-            operation, self.migration_context.impl.__dialect__)
+            operation, self.migration_context.impl.__dialect__
+        )
         return fn(self, operation)
 
     def f(self, name):
@@ -363,7 +389,8 @@ class Operations(util.ModuleClsProxy):
             return conv(name)
         else:
             raise NotImplementedError(
-                "op.f() feature requires SQLAlchemy 0.9.4 or greater.")
+                "op.f() feature requires SQLAlchemy 0.9.4 or greater."
+            )
 
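Hedged sketch of op.f(): the wrapped name is treated as final and not run through naming conventions again (constraint and table names are hypothetical):

    from alembic import op

    def downgrade():
        # the name is taken as-is, not re-processed by naming conventions
        op.drop_constraint(op.f("ck_account_name_len"), "account", type_="check")
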
     def inline_literal(self, value, type_=None):
         """Produce an 'inline literal' expression, suitable for
@@ -442,4 +469,5 @@ class BatchOperations(Operations):
     def _noop(self, operation):
         raise NotImplementedError(
             "The %s method does not apply to a batch table alter operation."
-            % operation)
+            % operation
+        )
index 79ad533900f0769ef8da8d4830b3baa5867d8ff8..936287672c9ee7eee952bbe61a39c92354135485 100644 (file)
@@ -1,24 +1,48 @@
-from sqlalchemy import Table, MetaData, Index, select, Column, \
-    ForeignKeyConstraint, PrimaryKeyConstraint, cast, CheckConstraint
+from sqlalchemy import (
+    Table,
+    MetaData,
+    Index,
+    select,
+    Column,
+    ForeignKeyConstraint,
+    PrimaryKeyConstraint,
+    cast,
+    CheckConstraint,
+)
 from sqlalchemy import types as sqltypes
 from sqlalchemy import schema as sql_schema
 from sqlalchemy.util import OrderedDict
 from .. import util
 from sqlalchemy.events import SchemaEventTarget
-from ..util.sqla_compat import _columns_for_constraint, \
-    _is_type_bound, _fk_is_self_referential, _remove_column_from_collection
+from ..util.sqla_compat import (
+    _columns_for_constraint,
+    _is_type_bound,
+    _fk_is_self_referential,
+    _remove_column_from_collection,
+)
 
 
 class BatchOperationsImpl(object):
-    def __init__(self, operations, table_name, schema, recreate,
-                 copy_from, table_args, table_kwargs,
-                 reflect_args, reflect_kwargs, naming_convention):
+    def __init__(
+        self,
+        operations,
+        table_name,
+        schema,
+        recreate,
+        copy_from,
+        table_args,
+        table_kwargs,
+        reflect_args,
+        reflect_kwargs,
+        naming_convention,
+    ):
         self.operations = operations
         self.table_name = table_name
         self.schema = schema
-        if recreate not in ('auto', 'always', 'never'):
+        if recreate not in ("auto", "always", "never"):
             raise ValueError(
-                "recreate may be one of 'auto', 'always', or 'never'.")
+                "recreate may be one of 'auto', 'always', or 'never'."
+            )
         self.recreate = recreate
         self.copy_from = copy_from
         self.table_args = table_args
@@ -37,9 +61,9 @@ class BatchOperationsImpl(object):
         return self.operations.impl
 
     def _should_recreate(self):
-        if self.recreate == 'auto':
+        if self.recreate == "auto":
             return self.operations.impl.requires_recreate_in_batch(self)
-        elif self.recreate == 'always':
+        elif self.recreate == "always":
             return True
         else:
             return False
@@ -62,15 +86,19 @@ class BatchOperationsImpl(object):
                 reflected = False
             else:
                 existing_table = Table(
-                    self.table_name, m1,
+                    self.table_name,
+                    m1,
                     schema=self.schema,
                     autoload=True,
                     autoload_with=self.operations.get_bind(),
-                    *self.reflect_args, **self.reflect_kwargs)
+                    *self.reflect_args,
+                    **self.reflect_kwargs
+                )
                 reflected = True
 
             batch_impl = ApplyBatchImpl(
-                existing_table, self.table_args, self.table_kwargs, reflected)
+                existing_table, self.table_args, self.table_kwargs, reflected
+            )
             for opname, arg, kw in self.batch:
                 fn = getattr(batch_impl, opname)
                 fn(*arg, **kw)
@@ -90,7 +118,7 @@ class BatchOperationsImpl(object):
         self.batch.append(("add_constraint", (const,), {}))
 
     def drop_constraint(self, const):
-        self.batch.append(("drop_constraint", (const, ), {}))
+        self.batch.append(("drop_constraint", (const,), {}))
 
     def rename_table(self, *arg, **kw):
         self.batch.append(("rename_table", arg, kw))
@@ -116,7 +144,7 @@ class ApplyBatchImpl(object):
         self.temp_table_name = self._calc_temp_name(table.name)
         self.new_table = None
         self.column_transfers = OrderedDict(
-            (c.name, {'expr': c}) for c in self.table.c
+            (c.name, {"expr": c}) for c in self.table.c
         )
         self.reflected = reflected
         self._grab_table_elements()
@@ -165,16 +193,20 @@ class ApplyBatchImpl(object):
         schema = self.table.schema
 
         self.new_table = new_table = Table(
-            self.temp_table_name, m,
+            self.temp_table_name,
+            m,
             *(list(self.columns.values()) + list(self.table_args)),
             schema=schema,
-            **self.table_kwargs)
+            **self.table_kwargs
+        )
 
-        for const in list(self.named_constraints.values()) + \
-                self.unnamed_constraints:
+        for const in (
+            list(self.named_constraints.values()) + self.unnamed_constraints
+        ):
 
-            const_columns = set([
-                c.key for c in _columns_for_constraint(const)])
+            const_columns = set(
+                [c.key for c in _columns_for_constraint(const)]
+            )
 
             if not const_columns.issubset(self.column_transfers):
                 continue
@@ -188,7 +220,8 @@ class ApplyBatchImpl(object):
                     # no foreign keys just keeps the names unchanged, so
                     # when we rename back, they match again.
                     const_copy = const.copy(
-                        schema=schema, target_table=self.table)
+                        schema=schema, target_table=self.table
+                    )
                 else:
                     # "target_table" for ForeignKeyConstraint.copy() is
                     # only used if the FK is detected as being
@@ -209,7 +242,8 @@ class ApplyBatchImpl(object):
                     index.name,
                     unique=index.unique,
                     *[self.new_table.c[col] for col in index.columns.keys()],
-                    **index.kwargs)
+                    **index.kwargs
+                )
             )
         return idx
 
@@ -229,16 +263,20 @@ class ApplyBatchImpl(object):
                 for elem in constraint.elements:
                     colname = elem._get_colspec().split(".")[-1]
                     if not t.c.contains_column(colname):
-                        t.append_column(
-                            Column(colname, sqltypes.NULLTYPE)
-                        )
+                        t.append_column(Column(colname, sqltypes.NULLTYPE))
             else:
                 Table(
-                    tname, metadata,
-                    *[Column(n, sqltypes.NULLTYPE) for n in
-                        [elem._get_colspec().split(".")[-1]
-                         for elem in constraint.elements]],
-                    schema=referent_schema)
+                    tname,
+                    metadata,
+                    *[
+                        Column(n, sqltypes.NULLTYPE)
+                        for n in [
+                            elem._get_colspec().split(".")[-1]
+                            for elem in constraint.elements
+                        ]
+                    ],
+                    schema=referent_schema
+                )
 
     def _create(self, op_impl):
         self._transfer_elements_to_new_table()
@@ -249,13 +287,18 @@ class ApplyBatchImpl(object):
         try:
             op_impl._exec(
                 self.new_table.insert(inline=True).from_select(
-                    list(k for k, transfer in
-                         self.column_transfers.items() if 'expr' in transfer),
-                    select([
-                        transfer['expr']
-                        for transfer in self.column_transfers.values()
-                        if 'expr' in transfer
-                    ])
+                    list(
+                        k
+                        for k, transfer in self.column_transfers.items()
+                        if "expr" in transfer
+                    ),
+                    select(
+                        [
+                            transfer["expr"]
+                            for transfer in self.column_transfers.values()
+                            if "expr" in transfer
+                        ]
+                    ),
                 )
             )
             op_impl.drop_table(self.table)
@@ -264,9 +307,7 @@ class ApplyBatchImpl(object):
             raise
         else:
             op_impl.rename_table(
-                self.temp_table_name,
-                self.table.name,
-                schema=self.table.schema
+                self.temp_table_name, self.table.name, schema=self.table.schema
             )
             self.new_table.name = self.table.name
             try:
@@ -275,14 +316,17 @@ class ApplyBatchImpl(object):
             finally:
                 self.new_table.name = self.temp_table_name
 
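The recreate sequence driven by _create() above, sketched as the approximate SQL it issues; the table is hypothetical, and alembic derives the temporary name from the table name:

    # approximate statement sequence for the move-and-copy strategy;
    # real DDL depends on the table definition and backend
    statements = [
        "CREATE TABLE _alembic_tmp_account (id INTEGER, name VARCHAR)",
        "INSERT INTO _alembic_tmp_account (id, name) "
        "SELECT id, name FROM account",
        "DROP TABLE account",
        "ALTER TABLE _alembic_tmp_account RENAME TO account",
    ]
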
-    def alter_column(self, table_name, column_name,
-                     nullable=None,
-                     server_default=False,
-                     name=None,
-                     type_=None,
-                     autoincrement=None,
-                     **kw
-                     ):
+    def alter_column(
+        self,
+        table_name,
+        column_name,
+        nullable=None,
+        server_default=False,
+        name=None,
+        type_=None,
+        autoincrement=None,
+        **kw
+    ):
         existing = self.columns[column_name]
         existing_transfer = self.column_transfers[column_name]
         if name is not None and name != column_name:
@@ -299,12 +343,14 @@ class ApplyBatchImpl(object):
             # we also ignore the drop_constraint that will come here from
             # Operations.implementation_for(alter_column)
             if isinstance(existing.type, SchemaEventTarget):
-                existing.type._create_events = \
-                    existing.type.create_constraint = False
+                existing.type._create_events = (
+                    existing.type.create_constraint
+                ) = False
 
             if existing.type._type_affinity is not type_._type_affinity:
                 existing_transfer["expr"] = cast(
-                    existing_transfer["expr"], type_)
+                    existing_transfer["expr"], type_
+                )
 
             existing.type = type_
 
@@ -332,8 +378,7 @@ class ApplyBatchImpl(object):
     def drop_column(self, table_name, column, **kw):
         if column.name in self.table.primary_key.columns:
             _remove_column_from_collection(
-                self.table.primary_key.columns,
-                column
+                self.table.primary_key.columns, column
             )
         del self.columns[column.name]
         del self.column_transfers[column.name]
index ade1cb31a02eb44ee80d8096d86601d8437ac324..5824469a6492d6f54c5244702a0714e4219a6128 100644 (file)
@@ -46,12 +46,14 @@ class AddConstraintOp(MigrateOperation):
         def go(klass):
             cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint)
             return klass
+
         return go
 
     @classmethod
     def from_constraint(cls, constraint):
-        return cls.add_constraint_ops.dispatch(
-            constraint.__visit_name__)(constraint)
+        return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
+            constraint
+        )
 
     def reverse(self):
         return DropConstraintOp.from_constraint(self.to_constraint())
@@ -66,9 +68,13 @@ class DropConstraintOp(MigrateOperation):
     """Represent a drop constraint operation."""
 
     def __init__(
-            self,
-            constraint_name, table_name, type_=None, schema=None,
-            _orig_constraint=None):
+        self,
+        constraint_name,
+        table_name,
+        type_=None,
+        schema=None,
+        _orig_constraint=None,
+    ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.constraint_type = type_
@@ -79,7 +85,8 @@ class DropConstraintOp(MigrateOperation):
         if self._orig_constraint is None:
             raise ValueError(
                 "operation is not reversible; "
-                "original constraint is not present")
+                "original constraint is not present"
+            )
         return AddConstraintOp.from_constraint(self._orig_constraint)
 
     def to_diff_tuple(self):
@@ -104,7 +111,7 @@ class DropConstraintOp(MigrateOperation):
             constraint_table.name,
             schema=constraint_table.schema,
             type_=types[constraint.__visit_name__],
-            _orig_constraint=constraint
+            _orig_constraint=constraint,
         )
 
     def to_constraint(self):
@@ -113,16 +120,14 @@ class DropConstraintOp(MigrateOperation):
         else:
             raise ValueError(
                 "constraint cannot be produced; "
-                "original constraint is not present")
+                "original constraint is not present"
+            )
 
     @classmethod
-    @util._with_legacy_names([
-        ("type", "type_"),
-        ("name", "constraint_name"),
-    ])
+    @util._with_legacy_names([("type", "type_"), ("name", "constraint_name")])
     def drop_constraint(
-            cls, operations, constraint_name, table_name,
-            type_=None, schema=None):
+        cls, operations, constraint_name, table_name, type_=None, schema=None
+    ):
         """Drop a constraint of the given name, typically via DROP CONSTRAINT.
 
         :param constraint_name: name of the constraint.
@@ -166,15 +171,18 @@ class DropConstraintOp(MigrateOperation):
 
         """
         op = cls(
-            constraint_name, operations.impl.table_name,
-            type_=type_, schema=operations.impl.schema
+            constraint_name,
+            operations.impl.table_name,
+            type_=type_,
+            schema=operations.impl.schema,
         )
         return operations.invoke(op)
 
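A hedged usage sketch with hypothetical names; type_ matters on backends such as MySQL that have no generic DROP CONSTRAINT:

    from alembic import op

    def upgrade():
        op.drop_constraint("fk_account_user_id", "account", type_="foreignkey")
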
 
 @Operations.register_operation("create_primary_key")
 @BatchOperations.register_operation(
-    "create_primary_key", "batch_create_primary_key")
+    "create_primary_key", "batch_create_primary_key"
+)
 @AddConstraintOp.register_add_constraint("primary_key_constraint")
 class CreatePrimaryKeyOp(AddConstraintOp):
     """Represent a create primary key operation."""
@@ -182,8 +190,14 @@ class CreatePrimaryKeyOp(AddConstraintOp):
     constraint_type = "primarykey"
 
     def __init__(
-            self, constraint_name, table_name, columns,
-            schema=None, _orig_constraint=None, **kw):
+        self,
+        constraint_name,
+        table_name,
+        columns,
+        schema=None,
+        _orig_constraint=None,
+        **kw
+    ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.columns = columns
@@ -200,7 +214,7 @@ class CreatePrimaryKeyOp(AddConstraintOp):
             constraint_table.name,
             constraint.columns,
             schema=constraint_table.schema,
-            _orig_constraint=constraint
+            _orig_constraint=constraint,
         )
 
     def to_constraint(self, migration_context=None):
@@ -209,17 +223,19 @@ class CreatePrimaryKeyOp(AddConstraintOp):
 
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.primary_key_constraint(
-            self.constraint_name, self.table_name,
-            self.columns, schema=self.schema)
+            self.constraint_name,
+            self.table_name,
+            self.columns,
+            schema=self.schema,
+        )
 
     @classmethod
-    @util._with_legacy_names([
-        ('name', 'constraint_name'),
-        ('cols', 'columns')
-    ])
+    @util._with_legacy_names(
+        [("name", "constraint_name"), ("cols", "columns")]
+    )
     def create_primary_key(
-            cls, operations,
-            constraint_name, table_name, columns, schema=None):
+        cls, operations, constraint_name, table_name, columns, schema=None
+    ):
         """Issue a "create primary key" instruction using the current
         migration context.
 
@@ -282,15 +298,18 @@ class CreatePrimaryKeyOp(AddConstraintOp):
 
         """
         op = cls(
-            constraint_name, operations.impl.table_name, columns,
-            schema=operations.impl.schema
+            constraint_name,
+            operations.impl.table_name,
+            columns,
+            schema=operations.impl.schema,
         )
         return operations.invoke(op)
 
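A hedged usage sketch with hypothetical constraint, table, and column names:

    from alembic import op

    def upgrade():
        op.create_primary_key("pk_account", "account", ["id"])
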
 
 @Operations.register_operation("create_unique_constraint")
 @BatchOperations.register_operation(
-    "create_unique_constraint", "batch_create_unique_constraint")
+    "create_unique_constraint", "batch_create_unique_constraint"
+)
 @AddConstraintOp.register_add_constraint("unique_constraint")
 class CreateUniqueConstraintOp(AddConstraintOp):
     """Represent a create unique constraint operation."""
@@ -298,8 +317,14 @@ class CreateUniqueConstraintOp(AddConstraintOp):
     constraint_type = "unique"
 
     def __init__(
-            self, constraint_name, table_name,
-            columns, schema=None, _orig_constraint=None, **kw):
+        self,
+        constraint_name,
+        table_name,
+        columns,
+        schema=None,
+        _orig_constraint=None,
+        **kw
+    ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.columns = columns
@@ -313,9 +338,9 @@ class CreateUniqueConstraintOp(AddConstraintOp):
 
         kw = {}
         if constraint.deferrable:
-            kw['deferrable'] = constraint.deferrable
+            kw["deferrable"] = constraint.deferrable
         if constraint.initially:
-            kw['initially'] = constraint.initially
+            kw["initially"] = constraint.initially
 
         return cls(
             constraint.name,
@@ -332,18 +357,30 @@ class CreateUniqueConstraintOp(AddConstraintOp):
 
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.unique_constraint(
-            self.constraint_name, self.table_name, self.columns,
-            schema=self.schema, **self.kw)
+            self.constraint_name,
+            self.table_name,
+            self.columns,
+            schema=self.schema,
+            **self.kw
+        )
 
     @classmethod
-    @util._with_legacy_names([
-        ('name', 'constraint_name'),
-        ('source', 'table_name'),
-        ('local_cols', 'columns'),
-    ])
+    @util._with_legacy_names(
+        [
+            ("name", "constraint_name"),
+            ("source", "table_name"),
+            ("local_cols", "columns"),
+        ]
+    )
     def create_unique_constraint(
-            cls, operations, constraint_name, table_name, columns,
-            schema=None, **kw):
+        cls,
+        operations,
+        constraint_name,
+        table_name,
+        columns,
+        schema=None,
+        **kw
+    ):
         """Issue a "create unique constraint" instruction using the
         current migration context.
 
@@ -392,16 +429,14 @@ class CreateUniqueConstraintOp(AddConstraintOp):
 
         """
 
-        op = cls(
-            constraint_name, table_name, columns,
-            schema=schema, **kw
-        )
+        op = cls(constraint_name, table_name, columns, schema=schema, **kw)
         return operations.invoke(op)
 
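A hedged usage sketch with hypothetical names:

    from alembic import op

    def upgrade():
        op.create_unique_constraint("uq_account_email", "account", ["email"])
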
     @classmethod
-    @util._with_legacy_names([('name', 'constraint_name')])
+    @util._with_legacy_names([("name", "constraint_name")])
     def batch_create_unique_constraint(
-            cls, operations, constraint_name, columns, **kw):
+        cls, operations, constraint_name, columns, **kw
+    ):
         """Issue a "create unique constraint" instruction using the
         current batch migration context.
 
@@ -418,17 +453,15 @@ class CreateUniqueConstraintOp(AddConstraintOp):
            * name -> constraint_name
 
         """
-        kw['schema'] = operations.impl.schema
-        op = cls(
-            constraint_name, operations.impl.table_name, columns,
-            **kw
-        )
+        kw["schema"] = operations.impl.schema
+        op = cls(constraint_name, operations.impl.table_name, columns, **kw)
         return operations.invoke(op)
 
 
 @Operations.register_operation("create_foreign_key")
 @BatchOperations.register_operation(
-    "create_foreign_key", "batch_create_foreign_key")
+    "create_foreign_key", "batch_create_foreign_key"
+)
 @AddConstraintOp.register_add_constraint("foreign_key_constraint")
 class CreateForeignKeyOp(AddConstraintOp):
     """Represent a create foreign key constraint operation."""
@@ -436,8 +469,15 @@ class CreateForeignKeyOp(AddConstraintOp):
     constraint_type = "foreignkey"
 
     def __init__(
-            self, constraint_name, source_table, referent_table, local_cols,
-            remote_cols, _orig_constraint=None, **kw):
+        self,
+        constraint_name,
+        source_table,
+        referent_table,
+        local_cols,
+        remote_cols,
+        _orig_constraint=None,
+        **kw
+    ):
         self.constraint_name = constraint_name
         self.source_table = source_table
         self.referent_table = referent_table
@@ -453,24 +493,22 @@ class CreateForeignKeyOp(AddConstraintOp):
     def from_constraint(cls, constraint):
         kw = {}
         if constraint.onupdate:
-            kw['onupdate'] = constraint.onupdate
+            kw["onupdate"] = constraint.onupdate
         if constraint.ondelete:
-            kw['ondelete'] = constraint.ondelete
+            kw["ondelete"] = constraint.ondelete
         if constraint.initially:
-            kw['initially'] = constraint.initially
+            kw["initially"] = constraint.initially
         if constraint.deferrable:
-            kw['deferrable'] = constraint.deferrable
+            kw["deferrable"] = constraint.deferrable
         if constraint.use_alter:
-            kw['use_alter'] = constraint.use_alter
+            kw["use_alter"] = constraint.use_alter
 
-        source_schema, source_table, \
-            source_columns, target_schema, \
-            target_table, target_columns,\
-            onupdate, ondelete, deferrable, initially \
-            = sqla_compat._fk_spec(constraint)
+        source_schema, source_table, source_columns, target_schema, target_table, target_columns, onupdate, ondelete, deferrable, initially = sqla_compat._fk_spec(
+            constraint
+        )
 
-        kw['source_schema'] = source_schema
-        kw['referent_schema'] = target_schema
+        kw["source_schema"] = source_schema
+        kw["referent_schema"] = target_schema
 
         return cls(
             constraint.name,
@@ -488,22 +526,38 @@ class CreateForeignKeyOp(AddConstraintOp):
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.foreign_key_constraint(
             self.constraint_name,
-            self.source_table, self.referent_table,
-            self.local_cols, self.remote_cols,
-            **self.kw)
+            self.source_table,
+            self.referent_table,
+            self.local_cols,
+            self.remote_cols,
+            **self.kw
+        )
 
     @classmethod
-    @util._with_legacy_names([
-        ('name', 'constraint_name'),
-        ('source', 'source_table'),
-        ('referent', 'referent_table'),
-    ])
-    def create_foreign_key(cls, operations, constraint_name,
-                           source_table, referent_table, local_cols,
-                           remote_cols, onupdate=None, ondelete=None,
-                           deferrable=None, initially=None, match=None,
-                           source_schema=None, referent_schema=None,
-                           **dialect_kw):
+    @util._with_legacy_names(
+        [
+            ("name", "constraint_name"),
+            ("source", "source_table"),
+            ("referent", "referent_table"),
+        ]
+    )
+    def create_foreign_key(
+        cls,
+        operations,
+        constraint_name,
+        source_table,
+        referent_table,
+        local_cols,
+        remote_cols,
+        onupdate=None,
+        ondelete=None,
+        deferrable=None,
+        initially=None,
+        match=None,
+        source_schema=None,
+        referent_schema=None,
+        **dialect_kw
+    ):
         """Issue a "create foreign key" instruction using the
         current migration context.
 
@@ -558,29 +612,40 @@ class CreateForeignKeyOp(AddConstraintOp):
 
         op = cls(
             constraint_name,
-            source_table, referent_table,
-            local_cols, remote_cols,
-            onupdate=onupdate, ondelete=ondelete,
+            source_table,
+            referent_table,
+            local_cols,
+            remote_cols,
+            onupdate=onupdate,
+            ondelete=ondelete,
             deferrable=deferrable,
             source_schema=source_schema,
             referent_schema=referent_schema,
-            initially=initially, match=match,
+            initially=initially,
+            match=match,
             **dialect_kw
         )
         return operations.invoke(op)
 
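A hedged usage sketch with hypothetical tables and columns:

    from alembic import op

    def upgrade():
        op.create_foreign_key(
            "fk_account_user",
            "account",
            "user",
            ["user_id"],
            ["id"],
            ondelete="CASCADE",
        )
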
     @classmethod
-    @util._with_legacy_names([
-        ('name', 'constraint_name'),
-        ('referent', 'referent_table')
-    ])
+    @util._with_legacy_names(
+        [("name", "constraint_name"), ("referent", "referent_table")]
+    )
     def batch_create_foreign_key(
-            cls, operations, constraint_name, referent_table,
-            local_cols, remote_cols,
-            referent_schema=None,
-            onupdate=None, ondelete=None,
-            deferrable=None, initially=None, match=None,
-            **dialect_kw):
+        cls,
+        operations,
+        constraint_name,
+        referent_table,
+        local_cols,
+        remote_cols,
+        referent_schema=None,
+        onupdate=None,
+        ondelete=None,
+        deferrable=None,
+        initially=None,
+        match=None,
+        **dialect_kw
+    ):
         """Issue a "create foreign key" instruction using the
         current batch migration context.
 
@@ -607,13 +672,17 @@ class CreateForeignKeyOp(AddConstraintOp):
         """
         op = cls(
             constraint_name,
-            operations.impl.table_name, referent_table,
-            local_cols, remote_cols,
-            onupdate=onupdate, ondelete=ondelete,
+            operations.impl.table_name,
+            referent_table,
+            local_cols,
+            remote_cols,
+            onupdate=onupdate,
+            ondelete=ondelete,
             deferrable=deferrable,
             source_schema=operations.impl.schema,
             referent_schema=referent_schema,
-            initially=initially, match=match,
+            initially=initially,
+            match=match,
             **dialect_kw
         )
         return operations.invoke(op)
@@ -621,7 +690,8 @@ class CreateForeignKeyOp(AddConstraintOp):
 
 @Operations.register_operation("create_check_constraint")
 @BatchOperations.register_operation(
-    "create_check_constraint", "batch_create_check_constraint")
+    "create_check_constraint", "batch_create_check_constraint"
+)
 @AddConstraintOp.register_add_constraint("check_constraint")
 @AddConstraintOp.register_add_constraint("column_check_constraint")
 class CreateCheckConstraintOp(AddConstraintOp):
@@ -630,8 +700,14 @@ class CreateCheckConstraintOp(AddConstraintOp):
     constraint_type = "check"
 
     def __init__(
-            self, constraint_name, table_name,
-            condition, schema=None, _orig_constraint=None, **kw):
+        self,
+        constraint_name,
+        table_name,
+        condition,
+        schema=None,
+        _orig_constraint=None,
+        **kw
+    ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.condition = condition
@@ -648,7 +724,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
             constraint_table.name,
             constraint.sqltext,
             schema=constraint_table.schema,
-            _orig_constraint=constraint
+            _orig_constraint=constraint,
         )
 
     def to_constraint(self, migration_context=None):
@@ -656,18 +732,26 @@ class CreateCheckConstraintOp(AddConstraintOp):
             return self._orig_constraint
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.check_constraint(
-            self.constraint_name, self.table_name,
-            self.condition, schema=self.schema, **self.kw)
+            self.constraint_name,
+            self.table_name,
+            self.condition,
+            schema=self.schema,
+            **self.kw
+        )
 
     @classmethod
-    @util._with_legacy_names([
-        ('name', 'constraint_name'),
-        ('source', 'table_name')
-    ])
+    @util._with_legacy_names(
+        [("name", "constraint_name"), ("source", "table_name")]
+    )
     def create_check_constraint(
-            cls, operations,
-            constraint_name, table_name, condition,
-            schema=None, **kw):
+        cls,
+        operations,
+        constraint_name,
+        table_name,
+        condition,
+        schema=None,
+        **kw
+    ):
         """Issue a "create check constraint" instruction using the
         current migration context.
 
@@ -721,9 +805,10 @@ class CreateCheckConstraintOp(AddConstraintOp):
         return operations.invoke(op)
 
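A hedged usage sketch with a hypothetical condition expression:

    from alembic import op
    import sqlalchemy as sa

    def upgrade():
        op.create_check_constraint(
            "ck_account_name_len",
            "account",
            sa.func.length(sa.column("name")) > 0,
        )
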
     @classmethod
-    @util._with_legacy_names([('name', 'constraint_name')])
+    @util._with_legacy_names([("name", "constraint_name")])
     def batch_create_check_constraint(
-            cls, operations, constraint_name, condition, **kw):
+        cls, operations, constraint_name, condition, **kw
+    ):
         """Issue a "create check constraint" instruction using the
         current batch migration context.
 
@@ -741,8 +826,12 @@ class CreateCheckConstraintOp(AddConstraintOp):
 
         """
         op = cls(
-            constraint_name, operations.impl.table_name,
-            condition, schema=operations.impl.schema, **kw)
+            constraint_name,
+            operations.impl.table_name,
+            condition,
+            schema=operations.impl.schema,
+            **kw
+        )
         return operations.invoke(op)
 
 
@@ -752,8 +841,15 @@ class CreateIndexOp(MigrateOperation):
     """Represent a create index operation."""
 
     def __init__(
-            self, index_name, table_name, columns, schema=None,
-            unique=False, _orig_index=None, **kw):
+        self,
+        index_name,
+        table_name,
+        columns,
+        schema=None,
+        unique=False,
+        _orig_index=None,
+        **kw
+    ):
         self.index_name = index_name
         self.table_name = table_name
         self.columns = columns
@@ -785,15 +881,26 @@ class CreateIndexOp(MigrateOperation):
             return self._orig_index
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.index(
-            self.index_name, self.table_name, self.columns, schema=self.schema,
-            unique=self.unique, **self.kw)
+            self.index_name,
+            self.table_name,
+            self.columns,
+            schema=self.schema,
+            unique=self.unique,
+            **self.kw
+        )
 
     @classmethod
-    @util._with_legacy_names([('name', 'index_name')])
+    @util._with_legacy_names([("name", "index_name")])
     def create_index(
-            cls, operations,
-            index_name, table_name, columns, schema=None,
-            unique=False, **kw):
+        cls,
+        operations,
+        index_name,
+        table_name,
+        columns,
+        schema=None,
+        unique=False,
+        **kw
+    ):
         r"""Issue a "create index" instruction using the current
         migration context.
 
@@ -851,8 +958,7 @@ class CreateIndexOp(MigrateOperation):
 
         """
         op = cls(
-            index_name, table_name, columns, schema=schema,
-            unique=unique, **kw
+            index_name, table_name, columns, schema=schema, unique=unique, **kw
         )
         return operations.invoke(op)
 
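A hedged usage sketch with hypothetical names:

    from alembic import op

    def upgrade():
        op.create_index("ix_account_email", "account", ["email"], unique=True)
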
@@ -868,8 +974,11 @@ class CreateIndexOp(MigrateOperation):
         """
 
         op = cls(
-            index_name, operations.impl.table_name, columns,
-            schema=operations.impl.schema, **kw
+            index_name,
+            operations.impl.table_name,
+            columns,
+            schema=operations.impl.schema,
+            **kw
         )
         return operations.invoke(op)
 
@@ -880,8 +989,8 @@ class DropIndexOp(MigrateOperation):
     """Represent a drop index operation."""
 
     def __init__(
-            self, index_name, table_name=None,
-            schema=None, _orig_index=None, **kw):
+        self, index_name, table_name=None, schema=None, _orig_index=None, **kw
+    ):
         self.index_name = index_name
         self.table_name = table_name
         self.schema = schema
@@ -894,8 +1003,8 @@ class DropIndexOp(MigrateOperation):
     def reverse(self):
         if self._orig_index is None:
             raise ValueError(
-                "operation is not reversible; "
-                "original index is not present")
+                "operation is not reversible; " "original index is not present"
+            )
         return CreateIndexOp.from_index(self._orig_index)
 
     @classmethod
@@ -917,16 +1026,20 @@ class DropIndexOp(MigrateOperation):
         # need a dummy column name here since SQLAlchemy
         # 0.7.6 and later raises on Index with no columns
         return schema_obj.index(
-            self.index_name, self.table_name, ['x'],
-            schema=self.schema, **self.kw)
+            self.index_name,
+            self.table_name,
+            ["x"],
+            schema=self.schema,
+            **self.kw
+        )
 
     @classmethod
-    @util._with_legacy_names([
-        ('name', 'index_name'),
-        ('tablename', 'table_name')
-    ])
-    def drop_index(cls, operations, index_name,
-                   table_name=None, schema=None, **kw):
+    @util._with_legacy_names(
+        [("name", "index_name"), ("tablename", "table_name")]
+    )
+    def drop_index(
+        cls, operations, index_name, table_name=None, schema=None, **kw
+    ):
         r"""Issue a "drop index" instruction using the current
         migration context.
 
@@ -964,7 +1077,7 @@ class DropIndexOp(MigrateOperation):
         return operations.invoke(op)
 
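A hedged usage sketch with hypothetical names:

    from alembic import op

    def downgrade():
        op.drop_index("ix_account_email", table_name="account")
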
     @classmethod
-    @util._with_legacy_names([('name', 'index_name')])
+    @util._with_legacy_names([("name", "index_name")])
     def batch_drop_index(cls, operations, index_name, **kw):
         """Issue a "drop index" instruction using the
         current batch migration context.
@@ -981,8 +1094,10 @@ class DropIndexOp(MigrateOperation):
         """
 
         op = cls(
-            index_name, table_name=operations.impl.table_name,
-            schema=operations.impl.schema, **kw
+            index_name,
+            table_name=operations.impl.table_name,
+            schema=operations.impl.schema,
+            **kw
         )
         return operations.invoke(op)
 
@@ -992,7 +1107,8 @@ class CreateTableOp(MigrateOperation):
     """Represent a create table operation."""
 
     def __init__(
-            self, table_name, columns, schema=None, _orig_table=None, **kw):
+        self, table_name, columns, schema=None, _orig_table=None, **kw
+    ):
         self.table_name = table_name
         self.columns = columns
         self.schema = schema
@@ -1025,7 +1141,7 @@ class CreateTableOp(MigrateOperation):
         )
 
     @classmethod
-    @util._with_legacy_names([('name', 'table_name')])
+    @util._with_legacy_names([("name", "table_name")])
     def create_table(cls, operations, table_name, *columns, **kw):
         r"""Issue a "create table" instruction using the current migration
         context.
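
The corresponding public directive is op.create_table(). A small illustrative
sketch, with the table and columns invented for the example:

    from alembic import op
    import sqlalchemy as sa

    def upgrade():
        op.create_table(
            "account",
            sa.Column("id", sa.Integer, primary_key=True),
            sa.Column("email", sa.String(120), nullable=False),
        )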
@@ -1125,7 +1241,8 @@ class DropTableOp(MigrateOperation):
     """Represent a drop table operation."""
 
     def __init__(
-            self, table_name, schema=None, table_kw=None, _orig_table=None):
+        self, table_name, schema=None, table_kw=None, _orig_table=None
+    ):
         self.table_name = table_name
         self.schema = schema
         self.table_kw = table_kw or {}
@@ -1137,8 +1254,8 @@ class DropTableOp(MigrateOperation):
     def reverse(self):
         if self._orig_table is None:
             raise ValueError(
-                "operation is not reversible; "
-                "original table is not present")
+                "operation is not reversible; " "original table is not present"
+            )
         return CreateTableOp.from_table(self._orig_table)
 
     @classmethod
@@ -1150,12 +1267,11 @@ class DropTableOp(MigrateOperation):
             return self._orig_table
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.table(
-            self.table_name,
-            schema=self.schema,
-            **self.table_kw)
+            self.table_name, schema=self.schema, **self.table_kw
+        )
 
     @classmethod
-    @util._with_legacy_names([('name', 'table_name')])
+    @util._with_legacy_names([("name", "table_name")])
     def drop_table(cls, operations, table_name, schema=None, **kw):
         r"""Issue a "drop table" instruction using the current
         migration context.
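
DropTableOp's public face, sketched with the same hypothetical table; the
schema keyword is optional:

    from alembic import op

    def downgrade():
        op.drop_table("account")  # optionally: schema="sometenant"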
@@ -1205,7 +1321,8 @@ class RenameTableOp(AlterTableOp):
 
     @classmethod
     def rename_table(
-            cls, operations, old_table_name, new_table_name, schema=None):
+        cls, operations, old_table_name, new_table_name, schema=None
+    ):
         """Emit an ALTER TABLE to rename a table.
 
         :param old_table_name: old name.
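
A sketch of the rename directive (names hypothetical):

    from alembic import op

    def upgrade():
        op.rename_table("account", "customer_account")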
@@ -1229,16 +1346,18 @@ class AlterColumnOp(AlterTableOp):
     """Represent an alter column operation."""
 
     def __init__(
-            self, table_name, column_name, schema=None,
-            existing_type=None,
-            existing_server_default=False,
-            existing_nullable=None,
-            modify_nullable=None,
-            modify_server_default=False,
-            modify_name=None,
-            modify_type=None,
-            **kw
-
+        self,
+        table_name,
+        column_name,
+        schema=None,
+        existing_type=None,
+        existing_server_default=False,
+        existing_nullable=None,
+        modify_nullable=None,
+        modify_server_default=False,
+        modify_name=None,
+        modify_type=None,
+        **kw
     ):
         super(AlterColumnOp, self).__init__(table_name, schema=schema)
         self.column_name = column_name
@@ -1257,47 +1376,64 @@ class AlterColumnOp(AlterTableOp):
 
         if self.modify_type is not None:
             col_diff.append(
-                ("modify_type", schema, tname, cname,
-                 {
-                     "existing_nullable": self.existing_nullable,
-                     "existing_server_default": self.existing_server_default,
-                 },
-                 self.existing_type,
-                 self.modify_type)
+                (
+                    "modify_type",
+                    schema,
+                    tname,
+                    cname,
+                    {
+                        "existing_nullable": self.existing_nullable,
+                        "existing_server_default": self.existing_server_default,
+                    },
+                    self.existing_type,
+                    self.modify_type,
+                )
             )
 
         if self.modify_nullable is not None:
             col_diff.append(
-                ("modify_nullable", schema, tname, cname,
+                (
+                    "modify_nullable",
+                    schema,
+                    tname,
+                    cname,
                     {
                         "existing_type": self.existing_type,
-                        "existing_server_default": self.existing_server_default
+                        "existing_server_default": self.existing_server_default,
                     },
                     self.existing_nullable,
-                    self.modify_nullable)
+                    self.modify_nullable,
+                )
             )
 
         if self.modify_server_default is not False:
             col_diff.append(
-                ("modify_default", schema, tname, cname,
-                 {
-                     "existing_nullable": self.existing_nullable,
-                     "existing_type": self.existing_type
-                 },
-                 self.existing_server_default,
-                 self.modify_server_default)
+                (
+                    "modify_default",
+                    schema,
+                    tname,
+                    cname,
+                    {
+                        "existing_nullable": self.existing_nullable,
+                        "existing_type": self.existing_type,
+                    },
+                    self.existing_server_default,
+                    self.modify_server_default,
+                )
             )
 
         return col_diff
 
     def has_changes(self):
-        hc1 = self.modify_nullable is not None or \
-            self.modify_server_default is not False or \
-            self.modify_type is not None
+        hc1 = (
+            self.modify_nullable is not None
+            or self.modify_server_default is not False
+            or self.modify_type is not None
+        )
         if hc1:
             return True
         for kw in self.kw:
-            if kw.startswith('modify_'):
+            if kw.startswith("modify_"):
                 return True
         else:
             return False
@@ -1305,37 +1441,40 @@ class AlterColumnOp(AlterTableOp):
     def reverse(self):
 
         kw = self.kw.copy()
-        kw['existing_type'] = self.existing_type
-        kw['existing_nullable'] = self.existing_nullable
-        kw['existing_server_default'] = self.existing_server_default
+        kw["existing_type"] = self.existing_type
+        kw["existing_nullable"] = self.existing_nullable
+        kw["existing_server_default"] = self.existing_server_default
         if self.modify_type is not None:
-            kw['modify_type'] = self.modify_type
+            kw["modify_type"] = self.modify_type
         if self.modify_nullable is not None:
-            kw['modify_nullable'] = self.modify_nullable
+            kw["modify_nullable"] = self.modify_nullable
         if self.modify_server_default is not False:
-            kw['modify_server_default'] = self.modify_server_default
+            kw["modify_server_default"] = self.modify_server_default
 
         # TODO: make this a little simpler
-        all_keys = set(m.group(1) for m in [
-            re.match(r'^(?:existing_|modify_)(.+)$', k)
-            for k in kw
-        ] if m)
+        all_keys = set(
+            m.group(1)
+            for m in [re.match(r"^(?:existing_|modify_)(.+)$", k) for k in kw]
+            if m
+        )
 
         for k in all_keys:
-            if 'modify_%s' % k in kw:
-                swap = kw['existing_%s' % k]
-                kw['existing_%s' % k] = kw['modify_%s' % k]
-                kw['modify_%s' % k] = swap
+            if "modify_%s" % k in kw:
+                swap = kw["existing_%s" % k]
+                kw["existing_%s" % k] = kw["modify_%s" % k]
+                kw["modify_%s" % k] = swap
 
         return self.__class__(
-            self.table_name, self.column_name, schema=self.schema,
-            **kw
+            self.table_name, self.column_name, schema=self.schema, **kw
         )
 
     @classmethod
-    @util._with_legacy_names([('name', 'new_column_name')])
+    @util._with_legacy_names([("name", "new_column_name")])
     def alter_column(
-        cls, operations, table_name, column_name,
+        cls,
+        operations,
+        table_name,
+        column_name,
         nullable=None,
         server_default=False,
         new_column_name=None,
@@ -1343,7 +1482,8 @@ class AlterColumnOp(AlterTableOp):
         existing_type=None,
         existing_server_default=False,
         existing_nullable=None,
-        schema=None, **kw
+        schema=None,
+        **kw
     ):
         """Issue an "alter column" instruction using the
         current migration context.
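
A hedged sketch of op.alter_column(); the existing_* arguments matter in
offline (--sql) mode and on backends such as MySQL that must restate the full
column definition:

    from alembic import op
    import sqlalchemy as sa

    def upgrade():
        # widen a hypothetical account.email column
        op.alter_column(
            "account",
            "email",
            existing_type=sa.String(120),
            type_=sa.String(255),
            existing_nullable=False,
        )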
@@ -1430,7 +1570,9 @@ class AlterColumnOp(AlterTableOp):
         """
 
         alt = cls(
-            table_name, column_name, schema=schema,
+            table_name,
+            column_name,
+            schema=schema,
             existing_type=existing_type,
             existing_server_default=existing_server_default,
             existing_nullable=existing_nullable,
@@ -1445,7 +1587,9 @@ class AlterColumnOp(AlterTableOp):
 
     @classmethod
     def batch_alter_column(
-        cls, operations, column_name,
+        cls,
+        operations,
+        column_name,
         nullable=None,
         server_default=False,
         new_column_name=None,
@@ -1464,7 +1608,8 @@ class AlterColumnOp(AlterTableOp):
 
         """
         alt = cls(
-            operations.impl.table_name, column_name,
+            operations.impl.table_name,
+            column_name,
             schema=operations.impl.schema,
             existing_type=existing_type,
             existing_server_default=existing_server_default,
@@ -1490,7 +1635,8 @@ class AddColumnOp(AlterTableOp):
 
     def reverse(self):
         return DropColumnOp.from_column_and_tablename(
-            self.schema, self.table_name, self.column)
+            self.schema, self.table_name, self.column
+        )
 
     def to_diff_tuple(self):
         return ("add_column", self.schema, self.table_name, self.column)
@@ -1575,8 +1721,7 @@ class AddColumnOp(AlterTableOp):
 
         """
         op = cls(
-            operations.impl.table_name, column,
-            schema=operations.impl.schema
+            operations.impl.table_name, column, schema=operations.impl.schema
         )
         return operations.invoke(op)
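
The batch_* classmethods above are reached through op.batch_alter_table(),
which recreates the table on backends with limited ALTER support, notably
SQLite. A sketch with invented names:

    from alembic import op
    import sqlalchemy as sa

    def upgrade():
        with op.batch_alter_table("account") as batch_op:
            batch_op.add_column(sa.Column("nickname", sa.String(50)))
            batch_op.alter_column(
                "email", existing_type=sa.String(255), nullable=True
            )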
 
@@ -1587,8 +1732,8 @@ class DropColumnOp(AlterTableOp):
     """Represent a drop column operation."""
 
     def __init__(
-            self, table_name, column_name, schema=None,
-            _orig_column=None, **kw):
+        self, table_name, column_name, schema=None, _orig_column=None, **kw
+    ):
         super(DropColumnOp, self).__init__(table_name, schema=schema)
         self.column_name = column_name
         self.kw = kw
@@ -1596,16 +1741,22 @@ class DropColumnOp(AlterTableOp):
 
     def to_diff_tuple(self):
         return (
-            "remove_column", self.schema, self.table_name, self.to_column())
+            "remove_column",
+            self.schema,
+            self.table_name,
+            self.to_column(),
+        )
 
     def reverse(self):
         if self._orig_column is None:
             raise ValueError(
                 "operation is not reversible; "
-                "original column is not present")
+                "original column is not present"
+            )
 
         return AddColumnOp.from_column_and_tablename(
-            self.schema, self.table_name, self._orig_column)
+            self.schema, self.table_name, self._orig_column
+        )
 
     @classmethod
     def from_column_and_tablename(cls, schema, tname, col):
@@ -1619,7 +1770,8 @@ class DropColumnOp(AlterTableOp):
 
     @classmethod
     def drop_column(
-            cls, operations, table_name, column_name, schema=None, **kw):
+        cls, operations, table_name, column_name, schema=None, **kw
+    ):
         """Issue a "drop column" instruction using the current
         migration context.
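
The matching public directive, again with hypothetical names:

    from alembic import op

    def downgrade():
        op.drop_column("account", "nickname")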
 
@@ -1677,8 +1829,11 @@ class DropColumnOp(AlterTableOp):
 
         """
         op = cls(
-            operations.impl.table_name, column_name,
-            schema=operations.impl.schema, **kw)
+            operations.impl.table_name,
+            column_name,
+            schema=operations.impl.schema,
+            **kw
+        )
         return operations.invoke(op)
 
 
@@ -1877,6 +2032,7 @@ class ExecuteSQLOp(MigrateOperation):
 
 class OpContainer(MigrateOperation):
     """Represent a sequence of operations operation."""
+
     def __init__(self, ops=()):
         self.ops = ops
 
@@ -1889,7 +2045,7 @@ class OpContainer(MigrateOperation):
     @classmethod
     def _ops_as_diffs(cls, migrations):
         for op in migrations.ops:
-            if hasattr(op, 'ops'):
+            if hasattr(op, "ops"):
                 for sub_op in cls._ops_as_diffs(op):
                     yield sub_op
             else:
@@ -1907,10 +2063,8 @@ class ModifyTableOps(OpContainer):
     def reverse(self):
         return ModifyTableOps(
             self.table_name,
-            ops=list(reversed(
-                [op.reverse() for op in self.ops]
-            )),
-            schema=self.schema
+            ops=list(reversed([op.reverse() for op in self.ops])),
+            schema=self.schema,
         )
 
 
@@ -1929,9 +2083,9 @@ class UpgradeOps(OpContainer):
         self.upgrade_token = upgrade_token
 
     def reverse_into(self, downgrade_ops):
-        downgrade_ops.ops[:] = list(reversed(
-            [op.reverse() for op in self.ops]
-        ))
+        downgrade_ops.ops[:] = list(
+            reversed([op.reverse() for op in self.ops])
+        )
         return downgrade_ops
 
     def reverse(self):
@@ -1954,9 +2108,7 @@ class DowngradeOps(OpContainer):
 
     def reverse(self):
         return UpgradeOps(
-            ops=list(reversed(
-                [op.reverse() for op in self.ops]
-            ))
+            ops=list(reversed([op.reverse() for op in self.ops]))
         )
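
These reverse() implementations are what let a recorded upgrade stream be
turned into a downgrade stream. A rough illustration driving the op objects
directly, which a migration normally never does; it assumes AddColumnOp's
(table_name, column) constructor signature as defined earlier in this file:

    import sqlalchemy as sa
    from alembic.operations import ops

    add = ops.AddColumnOp("account", sa.Column("nickname", sa.String(50)))
    drop = add.reverse()  # yields a DropColumnOp for the same column
    assert isinstance(drop, ops.DropColumnOp)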
 
 
@@ -1990,10 +2142,18 @@ class MigrationScript(MigrateOperation):
     """
 
     def __init__(
-            self, rev_id, upgrade_ops, downgrade_ops,
-            message=None,
-            imports=set(), head=None, splice=None,
-            branch_label=None, version_path=None, depends_on=None):
+        self,
+        rev_id,
+        upgrade_ops,
+        downgrade_ops,
+        message=None,
+        imports=set(),
+        head=None,
+        splice=None,
+        branch_label=None,
+        version_path=None,
+        depends_on=None,
+    ):
         self.rev_id = rev_id
         self.message = message
         self.imports = imports
@@ -2017,7 +2177,8 @@ class MigrationScript(MigrateOperation):
             raise ValueError(
                 "This MigrationScript instance has a multiple-entry "
                 "list for UpgradeOps; please use the "
-                "upgrade_ops_list attribute.")
+                "upgrade_ops_list attribute."
+            )
         elif not self._upgrade_ops:
             return None
         else:
@@ -2041,7 +2202,8 @@ class MigrationScript(MigrateOperation):
             raise ValueError(
                 "This MigrationScript instance has a multiple-entry "
                 "list for DowngradeOps; please use the "
-                "downgrade_ops_list attribute.")
+                "downgrade_ops_list attribute."
+            )
         elif not self._downgrade_ops:
             return None
         else:
@@ -2078,4 +2240,3 @@ class MigrationScript(MigrateOperation):
 
         """
         return self._downgrade_ops
-
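
MigrationScript is also the object handed to a process_revision_directives
hook configured in env.py. The documented cookbook pattern for suppressing
empty autogenerate revisions, as a sketch:

    def process_revision_directives(context, revision, directives):
        # directives[0] is the MigrationScript instance
        script = directives[0]
        if script.upgrade_ops.is_empty():
            directives[:] = []  # emit no revision file at all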
index 1014ace27bd4af418e9520960246498df729eb79..548b6c5ad5dfb5cec4baa967268949a6d4fb6a5f 100644 (file)
@@ -5,69 +5,82 @@ from .. import util
 
 
 class SchemaObjects(object):
-
     def __init__(self, migration_context=None):
         self.migration_context = migration_context
 
     def primary_key_constraint(self, name, table_name, cols, schema=None):
         m = self.metadata()
         columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
-        t = sa_schema.Table(
-            table_name, m,
-            *columns,
-            schema=schema)
-        p = sa_schema.PrimaryKeyConstraint(
-            *[t.c[n] for n in cols], name=name)
+        t = sa_schema.Table(table_name, m, *columns, schema=schema)
+        p = sa_schema.PrimaryKeyConstraint(*[t.c[n] for n in cols], name=name)
         t.append_constraint(p)
         return p
 
     def foreign_key_constraint(
-        self, name, source, referent,
-        local_cols, remote_cols,
-        onupdate=None, ondelete=None,
-        deferrable=None, source_schema=None,
-        referent_schema=None, initially=None,
-            match=None, **dialect_kw):
+        self,
+        name,
+        source,
+        referent,
+        local_cols,
+        remote_cols,
+        onupdate=None,
+        ondelete=None,
+        deferrable=None,
+        source_schema=None,
+        referent_schema=None,
+        initially=None,
+        match=None,
+        **dialect_kw
+    ):
         m = self.metadata()
         if source == referent and source_schema == referent_schema:
             t1_cols = local_cols + remote_cols
         else:
             t1_cols = local_cols
             sa_schema.Table(
-                referent, m,
+                referent,
+                m,
                 *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
-                schema=referent_schema)
+                schema=referent_schema
+            )
 
         t1 = sa_schema.Table(
-            source, m,
+            source,
+            m,
             *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
-            schema=source_schema)
-
-        tname = "%s.%s" % (referent_schema, referent) if referent_schema \
-                else referent
-
-        dialect_kw['match'] = match
-
-        f = sa_schema.ForeignKeyConstraint(local_cols,
-                                           ["%s.%s" % (tname, n)
-                                            for n in remote_cols],
-                                           name=name,
-                                           onupdate=onupdate,
-                                           ondelete=ondelete,
-                                           deferrable=deferrable,
-                                           initially=initially,
-                                           **dialect_kw
-                                           )
+            schema=source_schema
+        )
+
+        tname = (
+            "%s.%s" % (referent_schema, referent)
+            if referent_schema
+            else referent
+        )
+
+        dialect_kw["match"] = match
+
+        f = sa_schema.ForeignKeyConstraint(
+            local_cols,
+            ["%s.%s" % (tname, n) for n in remote_cols],
+            name=name,
+            onupdate=onupdate,
+            ondelete=ondelete,
+            deferrable=deferrable,
+            initially=initially,
+            **dialect_kw
+        )
         t1.append_constraint(f)
 
         return f
 
     def unique_constraint(self, name, source, local_cols, schema=None, **kw):
         t = sa_schema.Table(
-            source, self.metadata(),
+            source,
+            self.metadata(),
             *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
-            schema=schema)
-        kw['name'] = name
+            schema=schema
+        )
+        kw["name"] = name
         uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
         # TODO: need event tests to ensure the event
         # is fired off here
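
SchemaObjects builds throwaway Table objects so that constraint directives can
work without reflecting the real table. The public directives that end up
here, sketched with invented names:

    from alembic import op

    def upgrade():
        op.create_foreign_key(
            "fk_address_user", "address", "user",
            ["user_id"], ["id"], ondelete="CASCADE",
        )
        op.create_unique_constraint("uq_user_email", "user", ["email"])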
@@ -75,8 +88,12 @@ class SchemaObjects(object):
         return uq
 
     def check_constraint(self, name, source, condition, schema=None, **kw):
-        t = sa_schema.Table(source, self.metadata(),
-                            sa_schema.Column('x', Integer), schema=schema)
+        t = sa_schema.Table(
+            source,
+            self.metadata(),
+            sa_schema.Column("x", Integer),
+            schema=schema,
+        )
         ck = sa_schema.CheckConstraint(condition, name=name, **kw)
         t.append_constraint(ck)
         return ck
@@ -84,18 +101,21 @@ class SchemaObjects(object):
     def generic_constraint(self, name, table_name, type_, schema=None, **kw):
         t = self.table(table_name, schema=schema)
         types = {
-            'foreignkey': lambda name: sa_schema.ForeignKeyConstraint(
-                [], [], name=name),
-            'primary': sa_schema.PrimaryKeyConstraint,
-            'unique': sa_schema.UniqueConstraint,
-            'check': lambda name: sa_schema.CheckConstraint("", name=name),
-            None: sa_schema.Constraint
+            "foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
+                [], [], name=name
+            ),
+            "primary": sa_schema.PrimaryKeyConstraint,
+            "unique": sa_schema.UniqueConstraint,
+            "check": lambda name: sa_schema.CheckConstraint("", name=name),
+            None: sa_schema.Constraint,
         }
         try:
             const = types[type_]
         except KeyError:
-            raise TypeError("'type' can be one of %s" %
-                            ", ".join(sorted(repr(x) for x in types)))
+            raise TypeError(
+                "'type' can be one of %s"
+                % ", ".join(sorted(repr(x) for x in types))
+            )
         else:
             const = const(name=name)
             t.append_constraint(const)
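
generic_constraint() backs op.drop_constraint(), where type_ selects one of
the keys in the mapping above; this matters on backends such as MySQL that
emit distinct DROP syntax per constraint type. A sketch:

    from alembic import op

    def downgrade():
        op.drop_constraint("uq_user_email", "user", type_="unique")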
@@ -103,11 +123,13 @@ class SchemaObjects(object):
 
     def metadata(self):
         kw = {}
-        if self.migration_context is not None and \
-                'target_metadata' in self.migration_context.opts:
-            mt = self.migration_context.opts['target_metadata']
-            if hasattr(mt, 'naming_convention'):
-                kw['naming_convention'] = mt.naming_convention
+        if (
+            self.migration_context is not None
+            and "target_metadata" in self.migration_context.opts
+        ):
+            mt = self.migration_context.opts["target_metadata"]
+            if hasattr(mt, "naming_convention"):
+                kw["naming_convention"] = mt.naming_convention
         return sa_schema.MetaData(**kw)
 
     def table(self, name, *columns, **kw):
@@ -122,18 +144,18 @@ class SchemaObjects(object):
 
     def index(self, name, tablename, columns, schema=None, **kw):
         t = sa_schema.Table(
-            tablename or 'no_table', self.metadata(),
-            schema=schema
+            tablename or "no_table", self.metadata(), schema=schema
         )
         idx = sa_schema.Index(
             name,
             *[util.sqla_compat._textual_index_column(t, n) for n in columns],
-            **kw)
+            **kw
+        )
         return idx
 
     def _parse_table_key(self, table_key):
-        if '.' in table_key:
-            tokens = table_key.split('.')
+        if "." in table_key:
+            tokens = table_key.split(".")
             sname = ".".join(tokens[0:-1])
             tname = tokens[-1]
         else:
@@ -147,7 +169,7 @@ class SchemaObjects(object):
 
         """
         if isinstance(fk._colspec, string_types):
-            table_key, cname = fk._colspec.rsplit('.', 1)
+            table_key, cname = fk._colspec.rsplit(".", 1)
             sname, tname = self._parse_table_key(table_key)
             if table_key not in metadata.tables:
                 rel_t = sa_schema.Table(tname, metadata, schema=sname)
index 13273673e22e404875d9ee6790a10880d251206b..1635a4244feb698517fb25cab251d5c922adcaad 100644 (file)
@@ -8,8 +8,7 @@ from sqlalchemy import schema as sa_schema
 def alter_column(operations, operation):
 
     compiler = operations.impl.dialect.statement_compiler(
-        operations.impl.dialect,
-        None
+        operations.impl.dialect, None
     )
 
     existing_type = operation.existing_type
@@ -24,24 +23,23 @@ def alter_column(operations, operation):
     nullable = operation.modify_nullable
 
     def _count_constraint(constraint):
-        return not isinstance(
-            constraint,
-            sa_schema.PrimaryKeyConstraint) and \
-            (not constraint._create_rule or
-                constraint._create_rule(compiler))
+        return not isinstance(constraint, sa_schema.PrimaryKeyConstraint) and (
+            not constraint._create_rule or constraint._create_rule(compiler)
+        )
 
     if existing_type and type_:
         t = operations.schema_obj.table(
             table_name,
             sa_schema.Column(column_name, existing_type),
-            schema=schema
+            schema=schema,
         )
         for constraint in t.constraints:
             if _count_constraint(constraint):
                 operations.impl.drop_constraint(constraint)
 
     operations.impl.alter_column(
-        table_name, column_name,
+        table_name,
+        column_name,
         nullable=nullable,
         server_default=server_default,
         name=new_column_name,
@@ -57,7 +55,7 @@ def alter_column(operations, operation):
         t = operations.schema_obj.table(
             table_name,
             operations.schema_obj.column(column_name, type_),
-            schema=schema
+            schema=schema,
         )
         for constraint in t.constraints:
             if _count_constraint(constraint):
@@ -75,10 +73,7 @@ def drop_table(operations, operation):
 def drop_column(operations, operation):
     column = operation.to_column(operations.migration_context)
     operations.impl.drop_column(
-        operation.table_name,
-        column,
-        schema=operation.schema,
-        **operation.kw
+        operation.table_name, column, schema=operation.schema, **operation.kw
     )
 
 
@@ -105,9 +100,8 @@ def create_table(operations, operation):
 @Operations.implementation_for(ops.RenameTableOp)
 def rename_table(operations, operation):
     operations.impl.rename_table(
-        operation.table_name,
-        operation.new_table_name,
-        schema=operation.schema)
+        operation.table_name, operation.new_table_name, schema=operation.schema
+    )
 
 
 @Operations.implementation_for(ops.AddColumnOp)
@@ -117,11 +111,7 @@ def add_column(operations, operation):
     schema = operation.schema
 
     t = operations.schema_obj.table(table_name, column, schema=schema)
-    operations.impl.add_column(
-        table_name,
-        column,
-        schema=schema
-    )
+    operations.impl.add_column(table_name, column, schema=schema)
     for constraint in t.constraints:
         if not isinstance(constraint, sa_schema.PrimaryKeyConstraint):
             operations.impl.add_constraint(constraint)
@@ -151,12 +141,12 @@ def drop_constraint(operations, operation):
 @Operations.implementation_for(ops.BulkInsertOp)
 def bulk_insert(operations, operation):
     operations.impl.bulk_insert(
-        operation.table, operation.rows, multiinsert=operation.multiinsert)
+        operation.table, operation.rows, multiinsert=operation.multiinsert
+    )
 
 
 @Operations.implementation_for(ops.ExecuteSQLOp)
 def execute_sql(operations, operation):
     operations.migration_context.impl.execute(
-        operation.sqltext,
-        execution_options=operation.execution_options
+        operation.sqltext, execution_options=operation.execution_options
     )
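
The last two implementations correspond to op.bulk_insert() and op.execute().
A sketch of both, using an ad-hoc table construct; all names are invented:

    from alembic import op
    from sqlalchemy.sql import column, table
    import sqlalchemy as sa

    def upgrade():
        account = table(
            "account",
            column("id", sa.Integer),
            column("email", sa.String),
        )
        op.bulk_insert(
            account,
            [
                {"id": 1, "email": "a@example.com"},
                {"id": 2, "email": "b@example.com"},
            ],
        )
        op.execute("UPDATE account SET email = lower(email)")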
index ce9be63785e648663003678f767cc06194fe1920..32db3ae51ce73f171f5b070f92595c1063146b11 100644 (file)
@@ -120,7 +120,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         has been configured.
 
         """
-        return self.context_opts.get('as_sql', False)
+        return self.context_opts.get("as_sql", False)
 
     def is_transactional_ddl(self):
         """Return True if the context is configured to expect a
@@ -182,17 +182,20 @@ class EnvironmentContext(util.ModuleClsProxy):
         """
         if self._migration_context is not None:
             return self.script.as_revision_number(
-                self.get_context()._start_from_rev)
-        elif 'starting_rev' in self.context_opts:
+                self.get_context()._start_from_rev
+            )
+        elif "starting_rev" in self.context_opts:
             return self.script.as_revision_number(
-                self.context_opts['starting_rev'])
+                self.context_opts["starting_rev"]
+            )
         else:
             # this should raise only in the case that a command
             # is being run where the "starting rev" is never applicable;
             # this is to catch scripts which rely upon this in
             # non-sql mode or similar
             raise util.CommandError(
-                "No starting revision argument is available.")
+                "No starting revision argument is available."
+            )
 
     def get_revision_argument(self):
         """Get the 'destination' revision argument.
@@ -209,7 +212,8 @@ class EnvironmentContext(util.ModuleClsProxy):
 
         """
         return self.script.as_revision_number(
-            self.context_opts['destination_rev'])
+            self.context_opts["destination_rev"]
+        )
 
     def get_tag_argument(self):
         """Return the value passed for the ``--tag`` argument, if any.
@@ -229,7 +233,7 @@ class EnvironmentContext(util.ModuleClsProxy):
             line.
 
         """
-        return self.context_opts.get('tag', None)
+        return self.context_opts.get("tag", None)
 
     def get_x_argument(self, as_dictionary=False):
         """Return the value(s) passed for the ``-x`` argument, if any.
@@ -277,39 +281,38 @@ class EnvironmentContext(util.ModuleClsProxy):
         else:
             value = []
         if as_dictionary:
-            value = dict(
-                arg.split('=', 1) for arg in value
-            )
+            value = dict(arg.split("=", 1) for arg in value)
         return value
 
-    def configure(self,
-                  connection=None,
-                  url=None,
-                  dialect_name=None,
-                  transactional_ddl=None,
-                  transaction_per_migration=False,
-                  output_buffer=None,
-                  starting_rev=None,
-                  tag=None,
-                  template_args=None,
-                  render_as_batch=False,
-                  target_metadata=None,
-                  include_symbol=None,
-                  include_object=None,
-                  include_schemas=False,
-                  process_revision_directives=None,
-                  compare_type=False,
-                  compare_server_default=False,
-                  render_item=None,
-                  literal_binds=False,
-                  upgrade_token="upgrades",
-                  downgrade_token="downgrades",
-                  alembic_module_prefix="op.",
-                  sqlalchemy_module_prefix="sa.",
-                  user_module_prefix=None,
-                  on_version_apply=None,
-                  **kw
-                  ):
+    def configure(
+        self,
+        connection=None,
+        url=None,
+        dialect_name=None,
+        transactional_ddl=None,
+        transaction_per_migration=False,
+        output_buffer=None,
+        starting_rev=None,
+        tag=None,
+        template_args=None,
+        render_as_batch=False,
+        target_metadata=None,
+        include_symbol=None,
+        include_object=None,
+        include_schemas=False,
+        process_revision_directives=None,
+        compare_type=False,
+        compare_server_default=False,
+        render_item=None,
+        literal_binds=False,
+        upgrade_token="upgrades",
+        downgrade_token="downgrades",
+        alembic_module_prefix="op.",
+        sqlalchemy_module_prefix="sa.",
+        user_module_prefix=None,
+        on_version_apply=None,
+        **kw
+    ):
         """Configure a :class:`.MigrationContext` within this
         :class:`.EnvironmentContext` which will provide database
         connectivity and other configuration to a series of
@@ -774,33 +777,33 @@ class EnvironmentContext(util.ModuleClsProxy):
         elif self.config.output_buffer is not None:
             opts["output_buffer"] = self.config.output_buffer
         if starting_rev:
-            opts['starting_rev'] = starting_rev
+            opts["starting_rev"] = starting_rev
         if tag:
-            opts['tag'] = tag
-        if template_args and 'template_args' in opts:
-            opts['template_args'].update(template_args)
+            opts["tag"] = tag
+        if template_args and "template_args" in opts:
+            opts["template_args"].update(template_args)
         opts["transaction_per_migration"] = transaction_per_migration
-        opts['target_metadata'] = target_metadata
-        opts['include_symbol'] = include_symbol
-        opts['include_object'] = include_object
-        opts['include_schemas'] = include_schemas
-        opts['render_as_batch'] = render_as_batch
-        opts['upgrade_token'] = upgrade_token
-        opts['downgrade_token'] = downgrade_token
-        opts['sqlalchemy_module_prefix'] = sqlalchemy_module_prefix
-        opts['alembic_module_prefix'] = alembic_module_prefix
-        opts['user_module_prefix'] = user_module_prefix
-        opts['literal_binds'] = literal_binds
-        opts['process_revision_directives'] = process_revision_directives
-        opts['on_version_apply'] = util.to_tuple(on_version_apply, default=())
+        opts["target_metadata"] = target_metadata
+        opts["include_symbol"] = include_symbol
+        opts["include_object"] = include_object
+        opts["include_schemas"] = include_schemas
+        opts["render_as_batch"] = render_as_batch
+        opts["upgrade_token"] = upgrade_token
+        opts["downgrade_token"] = downgrade_token
+        opts["sqlalchemy_module_prefix"] = sqlalchemy_module_prefix
+        opts["alembic_module_prefix"] = alembic_module_prefix
+        opts["user_module_prefix"] = user_module_prefix
+        opts["literal_binds"] = literal_binds
+        opts["process_revision_directives"] = process_revision_directives
+        opts["on_version_apply"] = util.to_tuple(on_version_apply, default=())
 
         if render_item is not None:
-            opts['render_item'] = render_item
+            opts["render_item"] = render_item
         if compare_type is not None:
-            opts['compare_type'] = compare_type
+            opts["compare_type"] = compare_type
         if compare_server_default is not None:
-            opts['compare_server_default'] = compare_server_default
-        opts['script'] = self.script
+            opts["compare_server_default"] = compare_server_default
+        opts["script"] = self.script
 
         opts.update(kw)
 
@@ -809,7 +812,7 @@ class EnvironmentContext(util.ModuleClsProxy):
             url=url,
             dialect_name=dialect_name,
             environment_context=self,
-            opts=opts
+            opts=opts,
         )
 
     def run_migrations(self, **kw):
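
Taken together, configure() and run_migrations() form the core of the standard
online-mode env.py. A condensed sketch; connectable and target_metadata are
assumed to be set up earlier in env.py:

    from alembic import context

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True,  # opt in to column type comparison
        )
        with context.begin_transaction():
            context.run_migrations()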
@@ -847,8 +850,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         first been made available via :meth:`.configure`.
 
         """
-        self.get_context().execute(sql,
-                                   execution_options=execution_options)
+        self.get_context().execute(sql, execution_options=execution_options)
 
     def static_output(self, text):
         """Emit text directly to the "offline" SQL stream.
index 17cc2265438be65600b42e83e0f6028955b8a588..80dc8ff096143af4701f264c984d66aed7876cd2 100644 (file)
@@ -2,8 +2,14 @@ import logging
 import sys
 from contextlib import contextmanager
 
-from sqlalchemy import MetaData, Table, Column, String, literal_column,\
-    PrimaryKeyConstraint
+from sqlalchemy import (
+    MetaData,
+    Table,
+    Column,
+    String,
+    literal_column,
+    PrimaryKeyConstraint,
+)
 from sqlalchemy.engine.strategies import MockEngineStrategy
 from sqlalchemy.engine import url as sqla_url
 from sqlalchemy.engine import Connection
@@ -65,71 +71,82 @@ class MigrationContext(object):
         self.environment_context = environment_context
         self.opts = opts
         self.dialect = dialect
-        self.script = opts.get('script')
-        as_sql = opts.get('as_sql', False)
+        self.script = opts.get("script")
+        as_sql = opts.get("as_sql", False)
         transactional_ddl = opts.get("transactional_ddl")
         self._transaction_per_migration = opts.get(
-            "transaction_per_migration", False)
-        self.on_version_apply_callbacks = opts.get('on_version_apply', ())
+            "transaction_per_migration", False
+        )
+        self.on_version_apply_callbacks = opts.get("on_version_apply", ())
 
         if as_sql:
             self.connection = self._stdout_connection(connection)
             assert self.connection is not None
         else:
             self.connection = connection
-        self._migrations_fn = opts.get('fn')
+        self._migrations_fn = opts.get("fn")
         self.as_sql = as_sql
 
         if "output_encoding" in opts:
             self.output_buffer = EncodedIO(
                 opts.get("output_buffer") or sys.stdout,
-                opts['output_encoding']
+                opts["output_encoding"],
             )
         else:
             self.output_buffer = opts.get("output_buffer", sys.stdout)
 
-        self._user_compare_type = opts.get('compare_type', False)
+        self._user_compare_type = opts.get("compare_type", False)
         self._user_compare_server_default = opts.get(
-            'compare_server_default',
-            False)
+            "compare_server_default", False
+        )
         self.version_table = version_table = opts.get(
-            'version_table', 'alembic_version')
-        self.version_table_schema = version_table_schema = \
-            opts.get('version_table_schema', None)
+            "version_table", "alembic_version"
+        )
+        self.version_table_schema = version_table_schema = opts.get(
+            "version_table_schema", None
+        )
         self._version = Table(
-            version_table, MetaData(),
-            Column('version_num', String(32), nullable=False),
-            schema=version_table_schema)
+            version_table,
+            MetaData(),
+            Column("version_num", String(32), nullable=False),
+            schema=version_table_schema,
+        )
         if opts.get("version_table_pk", True):
             self._version.append_constraint(
                 PrimaryKeyConstraint(
-                    'version_num', name="%s_pkc" % version_table
+                    "version_num", name="%s_pkc" % version_table
                 )
             )
 
         self._start_from_rev = opts.get("starting_rev")
         self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
-            dialect, self.connection, self.as_sql,
+            dialect,
+            self.connection,
+            self.as_sql,
             transactional_ddl,
             self.output_buffer,
-            opts
+            opts,
         )
         log.info("Context impl %s.", self.impl.__class__.__name__)
         if self.as_sql:
             log.info("Generating static SQL")
-        log.info("Will assume %s DDL.",
-                 "transactional" if self.impl.transactional_ddl
-                 else "non-transactional")
+        log.info(
+            "Will assume %s DDL.",
+            "transactional"
+            if self.impl.transactional_ddl
+            else "non-transactional",
+        )
 
     @classmethod
-    def configure(cls,
-                  connection=None,
-                  url=None,
-                  dialect_name=None,
-                  dialect=None,
-                  environment_context=None,
-                  opts=None,
-                  ):
+    def configure(
+        cls,
+        connection=None,
+        url=None,
+        dialect_name=None,
+        dialect=None,
+        environment_context=None,
+        opts=None,
+    ):
         """Create a new :class:`.MigrationContext`.
 
         This is a factory method usually called
@@ -158,7 +175,8 @@ class MigrationContext(object):
                 util.warn(
                     "'connection' argument to configure() is expected "
                     "to be a sqlalchemy.engine.Connection instance, "
-                    "got %r" % connection)
+                    "got %r" % connection
+                )
             dialect = connection.dialect
         elif url:
             url = sqla_url.make_url(url)
@@ -175,22 +193,28 @@ class MigrationContext(object):
         transaction_now = _per_migration == self._transaction_per_migration
 
         if not transaction_now:
+
             @contextmanager
             def do_nothing():
                 yield
+
             return do_nothing()
 
         elif not self.impl.transactional_ddl:
+
             @contextmanager
             def do_nothing():
                 yield
+
             return do_nothing()
         elif self.as_sql:
+
             @contextmanager
             def begin_commit():
                 self.impl.emit_begin()
                 yield
                 self.impl.emit_commit()
+
             return begin_commit()
         else:
             return self.bind.begin()
@@ -217,7 +241,8 @@ class MigrationContext(object):
         elif len(heads) > 1:
             raise util.CommandError(
                 "Version table '%s' has more than one head present; "
-                "please use get_current_heads()" % self.version_table)
+                "please use get_current_heads()" % self.version_table
+            )
         else:
             return heads[0]
 
@@ -243,18 +268,20 @@ class MigrationContext(object):
         """
         if self.as_sql:
             start_from_rev = self._start_from_rev
-            if start_from_rev == 'base':
+            if start_from_rev == "base":
                 start_from_rev = None
             elif start_from_rev is not None and self.script:
-                start_from_rev = \
-                    self.script.get_revision(start_from_rev).revision
+                start_from_rev = self.script.get_revision(
+                    start_from_rev
+                ).revision
 
             return util.to_tuple(start_from_rev, default=())
         else:
             if self._start_from_rev:
                 raise util.CommandError(
                     "Can't specify current_rev to context "
-                    "when using a database connection")
+                    "when using a database connection"
+                )
             if not self._has_version_table():
                 return ()
         return tuple(
@@ -266,7 +293,8 @@ class MigrationContext(object):
 
     def _has_version_table(self):
         return self.connection.dialect.has_table(
-            self.connection, self.version_table, self.version_table_schema)
+            self.connection, self.version_table, self.version_table_schema
+        )
 
     def stamp(self, script_directory, revision):
         """Stamp the version table with a specific revision.
@@ -315,8 +343,9 @@ class MigrationContext(object):
 
         head_maintainer = HeadMaintainer(self, heads)
 
-        starting_in_transaction = not self.as_sql and \
-            self._in_connection_transaction()
+        starting_in_transaction = (
+            not self.as_sql and self._in_connection_transaction()
+        )
 
         for step in self._migrations_fn(heads, self):
             with self.begin_transaction(_per_migration=True):
@@ -326,7 +355,9 @@ class MigrationContext(object):
                     self._version.create(self.connection)
                 log.info("Running %s", step)
                 if self.as_sql:
-                    self.impl.static_output("-- Running %s" % (step.short_log,))
+                    self.impl.static_output(
+                        "-- Running %s" % (step.short_log,)
+                    )
                 step.migration_fn(**kw)
 
                 # previously, we wouldn't stamp per migration
@@ -336,19 +367,24 @@ class MigrationContext(object):
                 # just to run the operations on every version
                 head_maintainer.update_to_step(step)
                 for callback in self.on_version_apply_callbacks:
-                    callback(ctx=self,
-                             step=step.info,
-                             heads=set(head_maintainer.heads),
-                             run_args=kw)
-
-            if not starting_in_transaction and not self.as_sql and \
-                not self.impl.transactional_ddl and \
-                    self._in_connection_transaction():
+                    callback(
+                        ctx=self,
+                        step=step.info,
+                        heads=set(head_maintainer.heads),
+                        run_args=kw,
+                    )
+
+            if (
+                not starting_in_transaction
+                and not self.as_sql
+                and not self.impl.transactional_ddl
+                and self._in_connection_transaction()
+            ):
                 raise util.CommandError(
-                    "Migration \"%s\" has left an uncommitted "
+                    'Migration "%s" has left an uncommitted '
                     "transaction opened; transactional_ddl is False so "
-                    "Alembic is not committing transactions"
-                    % step)
+                    "Alembic is not committing transactions" % step
+                )
 
         if self.as_sql and not head_maintainer.heads:
             self._version.drop(self.connection)
@@ -421,19 +457,20 @@ class MigrationContext(object):
                 inspector_column,
                 metadata_column,
                 inspector_column.type,
-                metadata_column.type
+                metadata_column.type,
             )
             if user_value is not None:
                 return user_value
 
-        return self.impl.compare_type(
-            inspector_column,
-            metadata_column)
+        return self.impl.compare_type(inspector_column, metadata_column)
 
-    def _compare_server_default(self, inspector_column,
-                                metadata_column,
-                                rendered_metadata_default,
-                                rendered_column_default):
+    def _compare_server_default(
+        self,
+        inspector_column,
+        metadata_column,
+        rendered_metadata_default,
+        rendered_column_default,
+    ):
 
         if self._user_compare_server_default is False:
             return False
@@ -445,7 +482,7 @@ class MigrationContext(object):
                 metadata_column,
                 rendered_column_default,
                 metadata_column.server_default,
-                rendered_metadata_default
+                rendered_metadata_default,
             )
             if user_value is not None:
                 return user_value
@@ -454,7 +491,8 @@ class MigrationContext(object):
             inspector_column,
             metadata_column,
             rendered_metadata_default,
-            rendered_column_default)
+            rendered_column_default,
+        )
 
 
 class HeadMaintainer(object):
@@ -467,8 +505,7 @@ class HeadMaintainer(object):
         self.heads.add(version)
 
         self.context.impl._exec(
-            self.context._version.insert().
-            values(
+            self.context._version.insert().values(
                 version_num=literal_column("'%s'" % version)
             )
         )
@@ -478,15 +515,17 @@ class HeadMaintainer(object):
 
         ret = self.context.impl._exec(
             self.context._version.delete().where(
-                self.context._version.c.version_num ==
-                literal_column("'%s'" % version)))
+                self.context._version.c.version_num
+                == literal_column("'%s'" % version)
+            )
+        )
         if not self.context.as_sql and ret.rowcount != 1:
             raise util.CommandError(
                 "Online migration expected to match one "
                 "row when deleting '%s' in '%s'; "
                 "%d found"
-                % (version,
-                   self.context.version_table, ret.rowcount))
+                % (version, self.context.version_table, ret.rowcount)
+            )
 
     def _update_version(self, from_, to_):
         assert to_ not in self.heads
@@ -494,17 +533,20 @@ class HeadMaintainer(object):
         self.heads.add(to_)
 
         ret = self.context.impl._exec(
-            self.context._version.update().
-            values(version_num=literal_column("'%s'" % to_)).where(
+            self.context._version.update()
+            .values(version_num=literal_column("'%s'" % to_))
+            .where(
                 self.context._version.c.version_num
-                == literal_column("'%s'" % from_))
+                == literal_column("'%s'" % from_)
+            )
         )
         if not self.context.as_sql and ret.rowcount != 1:
             raise util.CommandError(
                 "Online migration expected to match one "
                 "row when updating '%s' to '%s' in '%s'; "
                 "%d found"
-                % (from_, to_, self.context.version_table, ret.rowcount))
+                % (from_, to_, self.context.version_table, ret.rowcount)
+            )
 
     def update_to_step(self, step):
         if step.should_delete_branch(self.heads):
@@ -517,20 +559,32 @@ class HeadMaintainer(object):
             self._insert_version(vers)
         elif step.should_merge_branches(self.heads):
             # delete revs, update from rev, update to rev
-            (delete_revs, update_from_rev,
-             update_to_rev) = step.merge_branch_idents(self.heads)
+            (
+                delete_revs,
+                update_from_rev,
+                update_to_rev,
+            ) = step.merge_branch_idents(self.heads)
             log.debug(
                 "merge, delete %s, update %s to %s",
-                delete_revs, update_from_rev, update_to_rev)
+                delete_revs,
+                update_from_rev,
+                update_to_rev,
+            )
             for delrev in delete_revs:
                 self._delete_version(delrev)
             self._update_version(update_from_rev, update_to_rev)
         elif step.should_unmerge_branches(self.heads):
-            (update_from_rev, update_to_rev,
-             insert_revs) = step.unmerge_branch_idents(self.heads)
+            (
+                update_from_rev,
+                update_to_rev,
+                insert_revs,
+            ) = step.unmerge_branch_idents(self.heads)
             log.debug(
                 "unmerge, insert %s, update %s to %s",
-                insert_revs, update_from_rev, update_to_rev)
+                insert_revs,
+                update_from_rev,
+                update_to_rev,
+            )
             for insrev in insert_revs:
                 self._insert_version(insrev)
             self._update_version(update_from_rev, update_to_rev)
@@ -597,8 +651,9 @@ class MigrationInfo(object):
     revision_map = None
     """The revision map inside of which this operation occurs."""
 
-    def __init__(self, revision_map, is_upgrade, is_stamp, up_revisions,
-                 down_revisions):
+    def __init__(
+        self, revision_map, is_upgrade, is_stamp, up_revisions, down_revisions
+    ):
         self.revision_map = revision_map
         self.is_upgrade = is_upgrade
         self.is_stamp = is_stamp
@@ -625,14 +680,16 @@ class MigrationInfo(object):
     @property
     def source_revision_ids(self):
         """Active revisions before this migration step is applied."""
-        return self.down_revision_ids if self.is_upgrade \
-            else self.up_revision_ids
+        return (
+            self.down_revision_ids if self.is_upgrade else self.up_revision_ids
+        )
 
     @property
     def destination_revision_ids(self):
         """Active revisions after this migration step is applied."""
-        return self.up_revision_ids if self.is_upgrade \
-            else self.down_revision_ids
+        return (
+            self.up_revision_ids if self.is_upgrade else self.down_revision_ids
+        )
 
     @property
     def up_revision(self):
@@ -689,7 +746,7 @@ class MigrationStep(object):
         return "%s %s -> %s" % (
             self.name,
             util.format_as_comma(self.from_revisions_no_deps),
-            util.format_as_comma(self.to_revisions_no_deps)
+            util.format_as_comma(self.to_revisions_no_deps),
         )
 
     def __str__(self):
@@ -698,7 +755,7 @@ class MigrationStep(object):
                 self.name,
                 util.format_as_comma(self.from_revisions_no_deps),
                 util.format_as_comma(self.to_revisions_no_deps),
-                self.doc
+                self.doc,
             )
         else:
             return self.short_log
@@ -716,13 +773,16 @@ class RevisionStep(MigrationStep):
 
     def __repr__(self):
         return "RevisionStep(%r, is_upgrade=%r)" % (
-            self.revision.revision, self.is_upgrade
+            self.revision.revision,
+            self.is_upgrade,
         )
 
     def __eq__(self, other):
-        return isinstance(other, RevisionStep) and \
-            other.revision == self.revision and \
-            self.is_upgrade == other.is_upgrade
+        return (
+            isinstance(other, RevisionStep)
+            and other.revision == self.revision
+            and self.is_upgrade == other.is_upgrade
+        )
 
     @property
     def doc(self):
@@ -733,26 +793,26 @@ class RevisionStep(MigrationStep):
         if self.is_upgrade:
             return self.revision._all_down_revisions
         else:
-            return (self.revision.revision, )
+            return (self.revision.revision,)
 
     @property
     def from_revisions_no_deps(self):
         if self.is_upgrade:
             return self.revision._versioned_down_revisions
         else:
-            return (self.revision.revision, )
+            return (self.revision.revision,)
 
     @property
     def to_revisions(self):
         if self.is_upgrade:
-            return (self.revision.revision, )
+            return (self.revision.revision,)
         else:
             return self.revision._all_down_revisions
 
     @property
     def to_revisions_no_deps(self):
         if self.is_upgrade:
-            return (self.revision.revision, )
+            return (self.revision.revision,)
         else:
             return self.revision._versioned_down_revisions
 
@@ -788,31 +848,31 @@ class RevisionStep(MigrationStep):
 
         if other_heads:
             ancestors = set(
-                r.revision for r in
-                self.revision_map._get_ancestor_nodes(
-                    self.revision_map.get_revisions(other_heads),
-                    check=False
+                r.revision
+                for r in self.revision_map._get_ancestor_nodes(
+                    self.revision_map.get_revisions(other_heads), check=False
                 )
             )
             from_revisions = list(
-                set(self.from_revisions).difference(ancestors))
+                set(self.from_revisions).difference(ancestors)
+            )
         else:
             from_revisions = list(self.from_revisions)
 
         return (
             # delete revs, update from rev, update to rev
-            list(from_revisions[0:-1]), from_revisions[-1],
-            self.to_revisions[0]
+            list(from_revisions[0:-1]),
+            from_revisions[-1],
+            self.to_revisions[0],
         )
 
     def _unmerge_to_revisions(self, heads):
         other_heads = set(heads).difference([self.revision.revision])
         if other_heads:
             ancestors = set(
-                r.revision for r in
-                self.revision_map._get_ancestor_nodes(
-                    self.revision_map.get_revisions(other_heads),
-                    check=False
+                r.revision
+                for r in self.revision_map._get_ancestor_nodes(
+                    self.revision_map.get_revisions(other_heads), check=False
                 )
             )
             return list(set(self.to_revisions).difference(ancestors))
@@ -824,8 +884,9 @@ class RevisionStep(MigrationStep):
 
         return (
             # update from rev, update to rev, insert revs
-            self.from_revisions[0], to_revisions[-1],
-            to_revisions[0:-1]
+            self.from_revisions[0],
+            to_revisions[-1],
+            to_revisions[0:-1],
         )
 
     def should_create_branch(self, heads):
@@ -853,8 +914,7 @@ class RevisionStep(MigrationStep):
 
         downrevs = self.revision._all_down_revisions
 
-        if len(downrevs) > 1 and \
-                len(heads.intersection(downrevs)) > 1:
+        if len(downrevs) > 1 and len(heads.intersection(downrevs)) > 1:
             return True
 
         return False
@@ -873,8 +933,9 @@ class RevisionStep(MigrationStep):
     def update_version_num(self, heads):
         if not self._has_scalar_down_revision:
             downrev = heads.intersection(self.revision._all_down_revisions)
-            assert len(downrev) == 1, \
-                "Can't do an UPDATE because downrevision is ambiguous"
+            assert (
+                len(downrev) == 1
+            ), "Can't do an UPDATE because downrevision is ambiguous"
             down_revision = list(downrev)[0]
         else:
             down_revision = self.revision._all_down_revisions[0]
@@ -894,10 +955,13 @@ class RevisionStep(MigrationStep):
 
     @property
     def info(self):
-        return MigrationInfo(revision_map=self.revision_map,
-                             up_revisions=self.revision.revision,
-                             down_revisions=self.revision._all_down_revisions,
-                             is_upgrade=self.is_upgrade, is_stamp=False)
+        return MigrationInfo(
+            revision_map=self.revision_map,
+            up_revisions=self.revision.revision,
+            down_revisions=self.revision._all_down_revisions,
+            is_upgrade=self.is_upgrade,
+            is_stamp=False,
+        )
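
The MigrationInfo rewrite above is black's standard treatment of a call that overflows the 79-character limit: the argument list is exploded one argument per line and given a trailing comma, so later additions diff as single lines. A rough sketch with a stand-in function (names hypothetical, only to make the call site runnable):

    def migration_info(
        revision_map, up_revisions, down_revisions, is_upgrade, is_stamp
    ):
        return locals()  # stand-in for MigrationInfo, for illustration only

    info = migration_info(
        revision_map={},
        up_revisions="ae1027a6acf",
        down_revisions=("1975ea83b712",),
        is_upgrade=True,
        is_stamp=False,
    )
    assert info["is_upgrade"] and info["down_revisions"] == ("1975ea83b712",)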
 
 
 class StampStep(MigrationStep):
@@ -915,11 +979,13 @@ class StampStep(MigrationStep):
         return None
 
     def __eq__(self, other):
-        return isinstance(other, StampStep) and \
-            other.from_revisions == self.revisions and \
-            other.to_revisions == self.to_revisions and \
-            other.branch_move == self.branch_move and \
-            self.is_upgrade == other.is_upgrade
+        return (
+            isinstance(other, StampStep)
+            and other.from_revisions == self.revisions
+            and other.to_revisions == self.to_revisions
+            and other.branch_move == self.branch_move
+            and self.is_upgrade == other.is_upgrade
+        )
 
     @property
     def from_revisions(self):
@@ -955,15 +1021,17 @@ class StampStep(MigrationStep):
     def merge_branch_idents(self, heads):
         return (
             # delete revs, update from rev, update to rev
-            list(self.from_[0:-1]), self.from_[-1],
-            self.to_[0]
+            list(self.from_[0:-1]),
+            self.from_[-1],
+            self.to_[0],
         )
 
     def unmerge_branch_idents(self, heads):
         return (
             # update from rev, update to rev, insert revs
-            self.from_[0], self.to_[-1],
-            list(self.to_[0:-1])
+            self.from_[0],
+            self.to_[-1],
+            list(self.to_[0:-1]),
         )
 
     def should_delete_branch(self, heads):
@@ -980,10 +1048,15 @@ class StampStep(MigrationStep):
 
     @property
     def info(self):
-        up, down = (self.to_, self.from_) if self.is_upgrade \
+        up, down = (
+            (self.to_, self.from_)
+            if self.is_upgrade
             else (self.from_, self.to_)
-        return MigrationInfo(revision_map=self.revision_map,
-                             up_revisions=up,
-                             down_revisions=down,
-                             is_upgrade=self.is_upgrade,
-                             is_stamp=True)
+        )
+        return MigrationInfo(
+            revision_map=self.revision_map,
+            up_revisions=up,
+            down_revisions=down,
+            is_upgrade=self.is_upgrade,
+            is_stamp=True,
+        )
index cae294f8220757c224ea4da3a079d659743ff247..65562b48c0e872626465a0d6f804d2779971ecf6 100644 (file)
@@ -1,3 +1,3 @@
 from .base import ScriptDirectory, Script  # noqa
 
-__all__ = ['ScriptDirectory', 'Script']
+__all__ = ["ScriptDirectory", "Script"]
index 12e15108a0b140177965a7a6a30e8048d28a207d..1c63e086197f33107ace9a0c7a0ae3eef531d2cf 100644 (file)
@@ -10,13 +10,13 @@ from ..runtime import migration
 
 from contextlib import contextmanager
 
-_sourceless_rev_file = re.compile(r'(?!\.\#|__init__)(.*\.py)(c|o)?$')
-_only_source_rev_file = re.compile(r'(?!\.\#|__init__)(.*\.py)$')
-_legacy_rev = re.compile(r'([a-f0-9]+)\.py$')
-_mod_def_re = re.compile(r'(upgrade|downgrade)_([a-z0-9]+)')
-_slug_re = re.compile(r'\w+')
+_sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
+_only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")
+_legacy_rev = re.compile(r"([a-f0-9]+)\.py$")
+_mod_def_re = re.compile(r"(upgrade|downgrade)_([a-z0-9]+)")
+_slug_re = re.compile(r"\w+")
 _default_file_template = "%(rev)s_%(slug)s"
-_split_on_space_comma = re.compile(r',|(?: +)')
+_split_on_space_comma = re.compile(r",|(?: +)")
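
These module-level patterns change only in quote style under black. For orientation, `_sourceless_rev_file` accepts revision sources and their compiled `.pyc`/`.pyo` forms while skipping `__init__` and `.#` editor lock files; a quick check (filenames hypothetical):

    import re

    _sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")

    for name in (
        "1975ea83b712_add_users.py",
        "1975ea83b712_add_users.pyc",
        "__init__.py",
    ):
        m = _sourceless_rev_file.match(name)
        print(name, "->", m.group(1) if m else "skipped")
    # 1975ea83b712_add_users.py -> 1975ea83b712_add_users.py
    # 1975ea83b712_add_users.pyc -> 1975ea83b712_add_users.py
    # __init__.py -> skipped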
 
 
 class ScriptDirectory(object):
@@ -40,11 +40,16 @@ class ScriptDirectory(object):
 
     """
 
-    def __init__(self, dir, file_template=_default_file_template,
-                 truncate_slug_length=40,
-                 version_locations=None,
-                 sourceless=False, output_encoding="utf-8",
-                 timezone=None):
+    def __init__(
+        self,
+        dir,
+        file_template=_default_file_template,
+        truncate_slug_length=40,
+        version_locations=None,
+        sourceless=False,
+        output_encoding="utf-8",
+        timezone=None,
+    ):
         self.dir = dir
         self.file_template = file_template
         self.version_locations = version_locations
@@ -55,9 +60,11 @@ class ScriptDirectory(object):
         self.timezone = timezone
 
         if not os.access(dir, os.F_OK):
-            raise util.CommandError("Path doesn't exist: %r.  Please use "
-                                    "the 'init' command to create a new "
-                                    "scripts folder." % dir)
+            raise util.CommandError(
+                "Path doesn't exist: %r.  Please use "
+                "the 'init' command to create a new "
+                "scripts folder." % dir
+            )
 
     @property
     def versions(self):
@@ -75,13 +82,15 @@ class ScriptDirectory(object):
                 for location in self.version_locations
             ]
         else:
-            return (os.path.abspath(os.path.join(self.dir, 'versions')),)
+            return (os.path.abspath(os.path.join(self.dir, "versions")),)
 
     def _load_revisions(self):
         if self.version_locations:
             paths = [
-                vers for vers in self._version_locations
-                if os.path.exists(vers)]
+                vers
+                for vers in self._version_locations
+                if os.path.exists(vers)
+            ]
         else:
             paths = [self.versions]
 
@@ -110,10 +119,11 @@ class ScriptDirectory(object):
         present.
 
         """
-        script_location = config.get_main_option('script_location')
+        script_location = config.get_main_option("script_location")
         if script_location is None:
-            raise util.CommandError("No 'script_location' key "
-                                    "found in configuration.")
+            raise util.CommandError(
+                "No 'script_location' key " "found in configuration."
+            )
         truncate_slug_length = config.get_main_option("truncate_slug_length")
         if truncate_slug_length is not None:
             truncate_slug_length = int(truncate_slug_length)
@@ -125,20 +135,24 @@ class ScriptDirectory(object):
         return ScriptDirectory(
             util.coerce_resource_to_filename(script_location),
             file_template=config.get_main_option(
-                'file_template',
-                _default_file_template),
+                "file_template", _default_file_template
+            ),
             truncate_slug_length=truncate_slug_length,
             sourceless=config.get_main_option("sourceless") == "true",
             output_encoding=config.get_main_option("output_encoding", "utf-8"),
             version_locations=version_locations,
-            timezone=config.get_main_option("timezone")
+            timezone=config.get_main_option("timezone"),
         )
 
     @contextmanager
     def _catch_revision_errors(
-            self,
-            ancestor=None, multiple_heads=None, start=None, end=None,
-            resolution=None):
+        self,
+        ancestor=None,
+        multiple_heads=None,
+        start=None,
+        end=None,
+        resolution=None,
+    ):
         try:
             yield
         except revision.RangeNotAncestorError as rna:
@@ -160,10 +174,11 @@ class ScriptDirectory(object):
                     "argument '%(head_arg)s'; please "
                     "specify a specific target revision, "
                     "'<branchname>@%(head_arg)s' to "
-                    "narrow to a specific head, or 'heads' for all heads")
+                    "narrow to a specific head, or 'heads' for all heads"
+                )
             multiple_heads = multiple_heads % {
                 "head_arg": end or mh.argument,
-                "heads": util.format_as_comma(mh.heads)
+                "heads": util.format_as_comma(mh.heads),
             }
             compat.raise_from_cause(util.CommandError(multiple_heads))
         except revision.ResolutionError as re:
@@ -192,7 +207,8 @@ class ScriptDirectory(object):
         """
         with self._catch_revision_errors(start=base, end=head):
             for rev in self.revision_map.iterate_revisions(
-                    head, base, inclusive=True, assert_relative_length=False):
+                head, base, inclusive=True, assert_relative_length=False
+            ):
                 yield rev
 
     def get_revisions(self, id_):
@@ -210,7 +226,8 @@ class ScriptDirectory(object):
             top_revs = set(self.revision_map.get_revisions(id_))
             top_revs.update(
                 self.revision_map._get_ancestor_nodes(
-                    list(top_revs), include_dependencies=True)
+                    list(top_revs), include_dependencies=True
+                )
             )
             top_revs = self.revision_map._filter_into_branch_heads(top_revs)
             return top_revs
@@ -275,11 +292,13 @@ class ScriptDirectory(object):
             :meth:`.ScriptDirectory.get_heads`
 
         """
-        with self._catch_revision_errors(multiple_heads=(
-                'The script directory has multiple heads (due to branching).'
-                'Please use get_heads(), or merge the branches using '
-                'alembic merge.'
-        )):
+        with self._catch_revision_errors(
+            multiple_heads=(
+                "The script directory has multiple heads (due to branching)."
+                "Please use get_heads(), or merge the branches using "
+                "alembic merge."
+            )
+        ):
             return self.revision_map.get_current_head()
 
     def get_heads(self):
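
black wraps long messages as adjacent string literals, which CPython concatenates at compile time; any separator has to live inside one of the literals, so a literal that ends without a trailing space (as in the multiple-heads message above) renders the sentences run together. A small demonstration:

    msg = (
        "The script directory has multiple heads (due to branching)."
        "Please use get_heads(), or merge the branches using "
        "alembic merge."
    )
    # note the missing space after "(due to branching)." in the output
    print(msg)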
@@ -310,7 +329,8 @@ class ScriptDirectory(object):
         if len(bases) > 1:
             raise util.CommandError(
                 "The script directory has multiple bases. "
-                "Please use get_bases().")
+                "Please use get_bases()."
+            )
         elif bases:
             return bases[0]
         else:
@@ -329,40 +349,50 @@ class ScriptDirectory(object):
 
     def _upgrade_revs(self, destination, current_rev):
         with self._catch_revision_errors(
-                ancestor="Destination %(end)s is not a valid upgrade "
-                "target from current head(s)", end=destination):
+            ancestor="Destination %(end)s is not a valid upgrade "
+            "target from current head(s)",
+            end=destination,
+        ):
             revs = self.revision_map.iterate_revisions(
-                destination, current_rev, implicit_base=True)
+                destination, current_rev, implicit_base=True
+            )
             revs = list(revs)
             return [
                 migration.MigrationStep.upgrade_from_script(
-                    self.revision_map, script)
+                    self.revision_map, script
+                )
                 for script in reversed(list(revs))
             ]
 
     def _downgrade_revs(self, destination, current_rev):
         with self._catch_revision_errors(
-                ancestor="Destination %(end)s is not a valid downgrade "
-                "target from current head(s)", end=destination):
+            ancestor="Destination %(end)s is not a valid downgrade "
+            "target from current head(s)",
+            end=destination,
+        ):
             revs = self.revision_map.iterate_revisions(
-                current_rev, destination, select_for_downgrade=True)
+                current_rev, destination, select_for_downgrade=True
+            )
             return [
                 migration.MigrationStep.downgrade_from_script(
-                    self.revision_map, script)
+                    self.revision_map, script
+                )
                 for script in revs
             ]
 
     def _stamp_revs(self, revision, heads):
         with self._catch_revision_errors(
-                multiple_heads="Multiple heads are present; please specify a "
-                "single target revision"):
+            multiple_heads="Multiple heads are present; please specify a "
+            "single target revision"
+        ):
 
             heads = self.get_revisions(heads)
 
             # filter for lineage will resolve things like
             # branchname@base, version@base, etc.
             filtered_heads = self.revision_map.filter_for_lineage(
-                heads, revision, include_dependencies=True)
+                heads, revision, include_dependencies=True
+            )
 
             steps = []
 
@@ -371,11 +401,18 @@ class ScriptDirectory(object):
                 if dest is None:
                     # dest is 'base'.  Return a "delete branch" migration
                     # for all applicable heads.
-                    steps.extend([
-                        migration.StampStep(head.revision, None, False, True,
-                                            self.revision_map)
-                        for head in filtered_heads
-                    ])
+                    steps.extend(
+                        [
+                            migration.StampStep(
+                                head.revision,
+                                None,
+                                False,
+                                True,
+                                self.revision_map,
+                            )
+                            for head in filtered_heads
+                        ]
+                    )
                     continue
                 elif dest in filtered_heads:
                     # the dest is already in the version table, do nothing.
@@ -384,7 +421,8 @@ class ScriptDirectory(object):
                 # figure out if the dest is a descendant or an
                 # ancestor of the selected nodes
                 descendants = set(
-                    self.revision_map._get_descendant_nodes([dest]))
+                    self.revision_map._get_descendant_nodes([dest])
+                )
                 ancestors = set(self.revision_map._get_ancestor_nodes([dest]))
 
                 if descendants.intersection(filtered_heads):
@@ -393,8 +431,12 @@ class ScriptDirectory(object):
                     assert not ancestors.intersection(filtered_heads)
                     todo_heads = [head.revision for head in filtered_heads]
                     step = migration.StampStep(
-                        todo_heads, dest.revision, False, False,
-                        self.revision_map)
+                        todo_heads,
+                        dest.revision,
+                        False,
+                        False,
+                        self.revision_map,
+                    )
                     steps.append(step)
                     continue
                 elif ancestors.intersection(filtered_heads):
@@ -402,15 +444,20 @@ class ScriptDirectory(object):
                     # we can treat them as a "merge", single step.
                     todo_heads = [head.revision for head in filtered_heads]
                     step = migration.StampStep(
-                        todo_heads, dest.revision, True, False,
-                        self.revision_map)
+                        todo_heads,
+                        dest.revision,
+                        True,
+                        False,
+                        self.revision_map,
+                    )
                     steps.append(step)
                     continue
                 else:
                     # destination is in a branch not represented,
                     # treat it as new branch
-                    step = migration.StampStep((), dest.revision, True, True,
-                                               self.revision_map)
+                    step = migration.StampStep(
+                        (), dest.revision, True, True, self.revision_map
+                    )
                     steps.append(step)
                     continue
             return steps
@@ -424,32 +471,31 @@ class ScriptDirectory(object):
 
 
         """
-        util.load_python_file(self.dir, 'env.py')
+        util.load_python_file(self.dir, "env.py")
 
     @property
     def env_py_location(self):
         return os.path.abspath(os.path.join(self.dir, "env.py"))
 
     def _generate_template(self, src, dest, **kw):
-        util.status("Generating %s" % os.path.abspath(dest),
-                    util.template_to_file,
-                    src,
-                    dest,
-                    self.output_encoding,
-                    **kw
-                    )
+        util.status(
+            "Generating %s" % os.path.abspath(dest),
+            util.template_to_file,
+            src,
+            dest,
+            self.output_encoding,
+            **kw
+        )
 
     def _copy_file(self, src, dest):
-        util.status("Generating %s" % os.path.abspath(dest),
-                    shutil.copy,
-                    src, dest)
+        util.status(
+            "Generating %s" % os.path.abspath(dest), shutil.copy, src, dest
+        )
 
     def _ensure_directory(self, path):
         path = os.path.abspath(path)
         if not os.path.exists(path):
-            util.status(
-                "Creating directory %s" % path,
-                os.makedirs, path)
+            util.status("Creating directory %s" % path, os.makedirs, path)
 
     def _generate_create_date(self):
         if self.timezone is not None:
@@ -460,17 +506,29 @@ class ScriptDirectory(object):
                 tzinfo = tz.gettz(self.timezone.upper())
             if tzinfo is None:
                 raise util.CommandError(
-                    "Can't locate timezone: %s" % self.timezone)
-            create_date = datetime.datetime.utcnow().replace(
-                tzinfo=tz.tzutc()).astimezone(tzinfo)
+                    "Can't locate timezone: %s" % self.timezone
+                )
+            create_date = (
+                datetime.datetime.utcnow()
+                .replace(tzinfo=tz.tzutc())
+                .astimezone(tzinfo)
+            )
         else:
             create_date = datetime.datetime.now()
         return create_date
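
`_generate_create_date` builds a timezone-aware timestamp via python-dateutil when `timezone` is configured. A minimal sketch of the same call chain, assuming dateutil is installed (the zone name is illustrative):

    import datetime

    from dateutil import tz

    tzinfo = tz.gettz("UTC")
    create_date = (
        datetime.datetime.utcnow()
        .replace(tzinfo=tz.tzutc())
        .astimezone(tzinfo)
    )
    print(create_date.isoformat())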
 
     def generate_revision(
-            self, revid, message, head=None,
-            refresh=False, splice=False, branch_labels=None,
-            version_path=None, depends_on=None, **kw):
+        self,
+        revid,
+        message,
+        head=None,
+        refresh=False,
+        splice=False,
+        branch_labels=None,
+        version_path=None,
+        depends_on=None,
+        **kw
+    ):
         """Generate a new revision file.
 
         This runs the ``script.py.mako`` template, given
@@ -500,11 +558,13 @@ class ScriptDirectory(object):
         except revision.RevisionError as err:
             compat.raise_from_cause(util.CommandError(err.args[0]))
 
-        with self._catch_revision_errors(multiple_heads=(
-            "Multiple heads are present; please specify the head "
-            "revision on which the new revision should be based, "
-            "or perform a merge."
-        )):
+        with self._catch_revision_errors(
+            multiple_heads=(
+                "Multiple heads are present; please specify the head "
+                "revision on which the new revision should be based, "
+                "or perform a merge."
+            )
+        ):
             heads = self.revision_map.get_revisions(head)
 
         if len(set(heads)) != len(heads):
@@ -521,7 +581,8 @@ class ScriptDirectory(object):
                 else:
                     raise util.CommandError(
                         "Multiple version locations present, "
-                        "please specify --version-path")
+                        "please specify --version-path"
+                    )
             else:
                 version_path = self.versions
 
@@ -532,7 +593,8 @@ class ScriptDirectory(object):
         else:
             raise util.CommandError(
                 "Path %s is not represented in current "
-                "version locations" % version_path)
+                "version locations" % version_path
+            )
 
         if self.version_locations:
             self._ensure_directory(version_path)
@@ -545,7 +607,8 @@ class ScriptDirectory(object):
                     raise util.CommandError(
                         "Revision %s is not a head revision; please specify "
                         "--splice to create a new branch from this revision"
-                        % head.revision)
+                        % head.revision
+                    )
 
         if depends_on:
             with self._catch_revision_errors():
@@ -557,7 +620,6 @@ class ScriptDirectory(object):
                         (self.revision_map.get_revision(dep), dep)
                         for dep in util.to_list(depends_on)
                     ]
-
                 ]
 
         self._generate_template(
@@ -565,7 +627,8 @@ class ScriptDirectory(object):
             path,
             up_revision=str(revid),
             down_revision=revision.tuple_rev_as_scalar(
-                tuple(h.revision if h is not None else None for h in heads)),
+                tuple(h.revision if h is not None else None for h in heads)
+            ),
             branch_labels=util.to_tuple(branch_labels),
             depends_on=revision.tuple_rev_as_scalar(depends_on),
             create_date=create_date,
@@ -582,9 +645,9 @@ class ScriptDirectory(object):
                 "Version %s specified branch_labels %s, however the "
                 "migration file %s does not have them; have you upgraded "
                 "your script.py.mako to include the "
-                "'branch_labels' section?" % (
-                    script.revision, branch_labels, script.path
-                ))
+                "'branch_labels' section?"
+                % (script.revision, branch_labels, script.path)
+            )
 
         self.revision_map.add_revision(script)
         return script
@@ -592,17 +655,18 @@ class ScriptDirectory(object):
     def _rev_path(self, path, rev_id, message, create_date):
         slug = "_".join(_slug_re.findall(message or "")).lower()
         if len(slug) > self.truncate_slug_length:
-            slug = slug[:self.truncate_slug_length].rsplit('_', 1)[0] + '_'
+            slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_"
         filename = "%s.py" % (
-            self.file_template % {
-                'rev': rev_id,
-                'slug': slug,
-                'year': create_date.year,
-                'month': create_date.month,
-                'day': create_date.day,
-                'hour': create_date.hour,
-                'minute': create_date.minute,
-                'second': create_date.second
+            self.file_template
+            % {
+                "rev": rev_id,
+                "slug": slug,
+                "year": create_date.year,
+                "month": create_date.month,
+                "day": create_date.day,
+                "hour": create_date.hour,
+                "minute": create_date.minute,
+                "second": create_date.second,
             }
         )
         return os.path.join(path, filename)
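
`_rev_path` combines the slug derived from the message with the configured `file_template` (default `%(rev)s_%(slug)s`). Roughly, for a hypothetical revision id and message (the date keys accepted by custom templates are omitted here):

    import re

    _slug_re = re.compile(r"\w+")
    file_template = "%(rev)s_%(slug)s"

    message = "Add users table!"
    slug = "_".join(_slug_re.findall(message)).lower()
    filename = "%s.py" % (
        file_template % {"rev": "1975ea83b712", "slug": slug}
    )
    print(filename)  # 1975ea83b712_add_users_table.py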
@@ -624,9 +688,11 @@ class Script(revision.Revision):
             rev_id,
             module.down_revision,
             branch_labels=util.to_tuple(
-                getattr(module, 'branch_labels', None), default=()),
+                getattr(module, "branch_labels", None), default=()
+            ),
             dependencies=util.to_tuple(
-                getattr(module, 'depends_on', None), default=())
+                getattr(module, "depends_on", None), default=()
+            ),
         )
 
     module = None
@@ -664,32 +730,32 @@ class Script(revision.Revision):
             " (head)" if self.is_head else "",
             " (branchpoint)" if self.is_branch_point else "",
             " (mergepoint)" if self.is_merge_point else "",
-            " (current)" if self._db_current_indicator else ""
+            " (current)" if self._db_current_indicator else "",
         )
         if self.is_merge_point:
-            entry += "Merges: %s\n" % (self._format_down_revision(), )
+            entry += "Merges: %s\n" % (self._format_down_revision(),)
         else:
-            entry += "Parent: %s\n" % (self._format_down_revision(), )
+            entry += "Parent: %s\n" % (self._format_down_revision(),)
 
         if self.dependencies:
             entry += "Also depends on: %s\n" % (
-                util.format_as_comma(self.dependencies))
+                util.format_as_comma(self.dependencies)
+            )
 
         if self.is_branch_point:
             entry += "Branches into: %s\n" % (
-                util.format_as_comma(self.nextrev))
+                util.format_as_comma(self.nextrev)
+            )
 
         if self.branch_labels:
             entry += "Branch names: %s\n" % (
-                util.format_as_comma(self.branch_labels), )
+                util.format_as_comma(self.branch_labels),
+            )
 
         entry += "Path: %s\n" % (self.path,)
 
         entry += "\n%s\n" % (
-            "\n".join(
-                "    %s" % para
-                for para in self.longdoc.splitlines()
-            )
+            "\n".join("    %s" % para for para in self.longdoc.splitlines())
         )
         return entry
 
@@ -700,36 +766,41 @@ class Script(revision.Revision):
             " (head)" if self.is_head else "",
             " (branchpoint)" if self.is_branch_point else "",
             " (mergepoint)" if self.is_merge_point else "",
-            self.doc)
+            self.doc,
+        )
 
     def _head_only(
-            self, include_branches=False, include_doc=False,
-            include_parents=False, tree_indicators=True,
-            head_indicators=True):
+        self,
+        include_branches=False,
+        include_doc=False,
+        include_parents=False,
+        tree_indicators=True,
+        head_indicators=True,
+    ):
         text = self.revision
         if include_parents:
             if self.dependencies:
                 text = "%s (%s) -> %s" % (
                     self._format_down_revision(),
                     util.format_as_comma(self.dependencies),
-                    text
+                    text,
                 )
             else:
-                text = "%s -> %s" % (
-                    self._format_down_revision(), text)
+                text = "%s -> %s" % (self._format_down_revision(), text)
         if include_branches and self.branch_labels:
             text += " (%s)" % util.format_as_comma(self.branch_labels)
         if head_indicators or tree_indicators:
             text += "%s%s%s" % (
                 " (head)" if self._is_real_head else "",
-                " (effective head)" if self.is_head and
-                    not self._is_real_head else "",
-                " (current)" if self._db_current_indicator else ""
+                " (effective head)"
+                if self.is_head and not self._is_real_head
+                else "",
+                " (current)" if self._db_current_indicator else "",
             )
         if tree_indicators:
             text += "%s%s" % (
                 " (branchpoint)" if self.is_branch_point else "",
-                " (mergepoint)" if self.is_merge_point else ""
+                " (mergepoint)" if self.is_merge_point else "",
             )
         if include_doc:
             text += ", %s" % self.doc
@@ -737,15 +808,18 @@ class Script(revision.Revision):
 
     def cmd_format(
         self,
-            verbose,
-            include_branches=False, include_doc=False,
-            include_parents=False, tree_indicators=True):
+        verbose,
+        include_branches=False,
+        include_doc=False,
+        include_parents=False,
+        tree_indicators=True,
+    ):
         if verbose:
             return self.log_entry
         else:
             return self._head_only(
-                include_branches, include_doc,
-                include_parents, tree_indicators)
+                include_branches, include_doc, include_parents, tree_indicators
+            )
 
     def _format_down_revision(self):
         if not self.down_revision:
@@ -768,13 +842,13 @@ class Script(revision.Revision):
             names = set(fname.split(".")[0] for fname in paths)
 
             # look for __pycache__
-            if os.path.exists(os.path.join(path, '__pycache__')):
+            if os.path.exists(os.path.join(path, "__pycache__")):
                 # add all files from __pycache__ whose filename is not
                 # already in the names we got from the version directory.
                 # add as relative paths including __pycache__ token
                 paths.extend(
-                    os.path.join('__pycache__', pyc)
-                    for pyc in os.listdir(os.path.join(path, '__pycache__'))
+                    os.path.join("__pycache__", pyc)
+                    for pyc in os.listdir(os.path.join(path, "__pycache__"))
                     if pyc.split(".")[0] not in names
                 )
             return paths
@@ -794,8 +868,8 @@ class Script(revision.Revision):
         py_filename = py_match.group(1)
 
         if scriptdir.sourceless:
-            is_c = py_match.group(2) == 'c'
-            is_o = py_match.group(2) == 'o'
+            is_c = py_match.group(2) == "c"
+            is_o = py_match.group(2) == "o"
         else:
             is_c = is_o = False
 
@@ -821,7 +895,8 @@ class Script(revision.Revision):
                     "Be sure the 'revision' variable is "
                     "declared inside the script (please see 'Upgrading "
                     "from Alembic 0.1 to 0.2' in the documentation)."
-                    % filename)
+                    % filename
+                )
             else:
                 revision = m.group(1)
         else:
index 3d9a332de59684726fefab4e898ebcc15bf3b7ef..832cce1579538d292f09563aafe2176f2dc562e1 100644 (file)
@@ -5,8 +5,8 @@ from .. import util
 from sqlalchemy import util as sqlautil
 from ..util import compat
 
-_relative_destination = re.compile(r'(?:(.+?)@)?(\w+)?((?:\+|-)\d+)')
-_revision_illegal_chars = ['@', '-', '+']
+_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
+_revision_illegal_chars = ["@", "-", "+"]
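
`_relative_destination` parses relative migration targets such as `branchname@head-2`; the three groups are the optional branch label, the optional symbol, and the signed step count. For example:

    import re

    _relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")

    print(_relative_destination.match("mybranch@head-2").groups())
    # ('mybranch', 'head', '-2')
    print(_relative_destination.match("+3").groups())
    # (None, None, '+3')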
 
 
 class RevisionError(Exception):
@@ -18,8 +18,8 @@ class RangeNotAncestorError(RevisionError):
         self.lower = lower
         self.upper = upper
         super(RangeNotAncestorError, self).__init__(
-            "Revision %s is not an ancestor of revision %s" %
-            (lower or "base", upper or "base")
+            "Revision %s is not an ancestor of revision %s"
+            % (lower or "base", upper or "base")
         )
 
 
@@ -122,8 +122,9 @@ class RevisionMap(object):
         for revision in self._generator():
 
             if revision.revision in map_:
-                util.warn("Revision %s is present more than once" %
-                          revision.revision)
+                util.warn(
+                    "Revision %s is present more than once" % revision.revision
+                )
             map_[revision.revision] = revision
             if revision.branch_labels:
                 has_branch_labels.add(revision)
@@ -132,9 +133,9 @@ class RevisionMap(object):
             heads.add(revision.revision)
             _real_heads.add(revision.revision)
             if revision.is_base:
-                self.bases += (revision.revision, )
+                self.bases += (revision.revision,)
             if revision._is_real_base:
-                self._real_bases += (revision.revision, )
+                self._real_bases += (revision.revision,)
 
         # add the branch_labels to the map_.  We'll need these
         # to resolve the dependencies.
@@ -147,8 +148,10 @@ class RevisionMap(object):
         for rev in map_.values():
             for downrev in rev._all_down_revisions:
                 if downrev not in map_:
-                    util.warn("Revision %s referenced from %s is not present"
-                              % (downrev, rev))
+                    util.warn(
+                        "Revision %s referenced from %s is not present"
+                        % (downrev, rev)
+                    )
                 down_revision = map_[downrev]
                 down_revision.add_nextrev(rev)
                 if downrev in rev._versioned_down_revisions:
@@ -169,9 +172,12 @@ class RevisionMap(object):
                 if branch_label in map_:
                     raise RevisionError(
                         "Branch name '%s' in revision %s already "
-                        "used by revision %s" %
-                        (branch_label, revision.revision,
-                            map_[branch_label].revision)
+                        "used by revision %s"
+                        % (
+                            branch_label,
+                            revision.revision,
+                            map_[branch_label].revision,
+                        )
                     )
                 map_[branch_label] = revision
 
@@ -182,13 +188,16 @@ class RevisionMap(object):
         if revision.branch_labels:
             revision.branch_labels.update(revision.branch_labels)
             for node in self._get_descendant_nodes(
-                    [revision], map_, include_dependencies=False):
+                [revision], map_, include_dependencies=False
+            ):
                 node.branch_labels.update(revision.branch_labels)
 
             parent = node
-            while parent and \
-                    not parent._is_real_branch_point and \
-                    not parent.is_merge_point:
+            while (
+                parent
+                and not parent._is_real_branch_point
+                and not parent.is_merge_point
+            ):
 
                 parent.branch_labels.update(revision.branch_labels)
                 if parent.down_revision:
@@ -201,7 +210,6 @@ class RevisionMap(object):
             deps = [map_[dep] for dep in util.to_tuple(revision.dependencies)]
             revision._resolved_dependencies = tuple([d.revision for d in deps])
 
-
     def add_revision(self, revision, _replace=False):
         """add a single revision to an existing map.
 
@@ -211,8 +219,9 @@ class RevisionMap(object):
         """
         map_ = self._revision_map
         if not _replace and revision.revision in map_:
-            util.warn("Revision %s is present more than once" %
-                      revision.revision)
+            util.warn(
+                "Revision %s is present more than once" % revision.revision
+            )
         elif _replace and revision.revision not in map_:
             raise Exception("revision %s not in map" % revision.revision)
 
@@ -221,9 +230,9 @@ class RevisionMap(object):
         self._add_depends_on(revision, map_)
 
         if revision.is_base:
-            self.bases += (revision.revision, )
+            self.bases += (revision.revision,)
         if revision._is_real_base:
-            self._real_bases += (revision.revision, )
+            self._real_bases += (revision.revision,)
         for downrev in revision._all_down_revisions:
             if downrev not in map_:
                 util.warn(
@@ -233,15 +242,21 @@ class RevisionMap(object):
             map_[downrev].add_nextrev(revision)
         if revision._is_real_head:
             self._real_heads = tuple(
-                head for head in self._real_heads
-                if head not in
-                set(revision._all_down_revisions).union([revision.revision])
+                head
+                for head in self._real_heads
+                if head
+                not in set(revision._all_down_revisions).union(
+                    [revision.revision]
+                )
             ) + (revision.revision,)
         if revision.is_head:
             self.heads = tuple(
-                head for head in self.heads
-                if head not in
-                set(revision._versioned_down_revisions).union([revision.revision])
+                head
+                for head in self.heads
+                if head
+                not in set(revision._versioned_down_revisions).union(
+                    [revision.revision]
+                )
             ) + (revision.revision,)
 
     def get_current_head(self, branch_label=None):
@@ -264,11 +279,14 @@ class RevisionMap(object):
         """
         current_heads = self.heads
         if branch_label:
-            current_heads = self.filter_for_lineage(current_heads, branch_label)
+            current_heads = self.filter_for_lineage(
+                current_heads, branch_label
+            )
         if len(current_heads) > 1:
             raise MultipleHeads(
                 current_heads,
-                "%s@head" % branch_label if branch_label else "head")
+                "%s@head" % branch_label if branch_label else "head",
+            )
 
         if current_heads:
             return current_heads[0]
@@ -301,7 +319,8 @@ class RevisionMap(object):
             resolved_id, branch_label = self._resolve_revision_number(id_)
             return tuple(
                 self._revision_for_ident(rev_id, branch_label)
-                for rev_id in resolved_id)
+                for rev_id in resolved_id
+            )
 
     def get_revision(self, id_):
         """Return the :class:`.Revision` instance with the given rev id.
@@ -333,7 +352,8 @@ class RevisionMap(object):
                 nonbranch_rev = self._revision_for_ident(branch_label)
             except ResolutionError:
                 raise ResolutionError(
-                    "No such branch: '%s'" % branch_label, branch_label)
+                    "No such branch: '%s'" % branch_label, branch_label
+                )
             else:
                 return nonbranch_rev
         else:
@@ -352,30 +372,37 @@ class RevisionMap(object):
             revision = False
         if revision is False:
             # do a partial lookup
-            revs = [x for x in self._revision_map
-                    if x and x.startswith(resolved_id)]
+            revs = [
+                x
+                for x in self._revision_map
+                if x and x.startswith(resolved_id)
+            ]
             if branch_rev:
                 revs = self.filter_for_lineage(revs, check_branch)
             if not revs:
                 raise ResolutionError(
                     "No such revision or branch '%s'" % resolved_id,
-                    resolved_id)
+                    resolved_id,
+                )
             elif len(revs) > 1:
                 raise ResolutionError(
                     "Multiple revisions start "
-                    "with '%s': %s..." % (
-                        resolved_id,
-                        ", ".join("'%s'" % r for r in revs[0:3])
-                    ), resolved_id)
+                    "with '%s': %s..."
+                    % (resolved_id, ", ".join("'%s'" % r for r in revs[0:3])),
+                    resolved_id,
+                )
             else:
                 revision = self._revision_map[revs[0]]
 
         if check_branch and revision is not None:
             if not self._shares_lineage(
-                    revision.revision, branch_rev.revision):
+                revision.revision, branch_rev.revision
+            ):
                 raise ResolutionError(
-                    "Revision %s is not a member of branch '%s'" %
-                    (revision.revision, check_branch), resolved_id)
+                    "Revision %s is not a member of branch '%s'"
+                    % (revision.revision, check_branch),
+                    resolved_id,
+                )
         return revision
 
     def _filter_into_branch_heads(self, targets):
@@ -383,14 +410,14 @@ class RevisionMap(object):
 
         for rev in list(targets):
             if targets.intersection(
-                self._get_descendant_nodes(
-                    [rev], include_dependencies=False)).\
-                    difference([rev]):
+                self._get_descendant_nodes([rev], include_dependencies=False)
+            ).difference([rev]):
                 targets.discard(rev)
         return targets
 
     def filter_for_lineage(
-            self, targets, check_against, include_dependencies=False):
+        self, targets, check_against, include_dependencies=False
+    ):
         id_, branch_label = self._resolve_revision_number(check_against)
 
         shares = []
@@ -400,12 +427,16 @@ class RevisionMap(object):
             shares.extend(id_)
 
         return [
-            tg for tg in targets
+            tg
+            for tg in targets
             if self._shares_lineage(
-                tg, shares, include_dependencies=include_dependencies)]
+                tg, shares, include_dependencies=include_dependencies
+            )
+        ]
 
     def _shares_lineage(
-            self, target, test_against_revs, include_dependencies=False):
+        self, target, test_against_revs, include_dependencies=False
+    ):
         if not test_against_revs:
             return True
         if not isinstance(target, Revision):
@@ -415,46 +446,61 @@ class RevisionMap(object):
             self._revision_for_ident(test_against_rev)
             if not isinstance(test_against_rev, Revision)
             else test_against_rev
-            for test_against_rev
-            in util.to_tuple(test_against_revs, default=())
+            for test_against_rev in util.to_tuple(
+                test_against_revs, default=()
+            )
         ]
 
         return bool(
-            set(self._get_descendant_nodes([target],
-                include_dependencies=include_dependencies))
-            .union(self._get_ancestor_nodes([target],
-                   include_dependencies=include_dependencies))
+            set(
+                self._get_descendant_nodes(
+                    [target], include_dependencies=include_dependencies
+                )
+            )
+            .union(
+                self._get_ancestor_nodes(
+                    [target], include_dependencies=include_dependencies
+                )
+            )
             .intersection(test_against_revs)
         )
 
     def _resolve_revision_number(self, id_):
         if isinstance(id_, compat.string_types) and "@" in id_:
-            branch_label, id_ = id_.split('@', 1)
+            branch_label, id_ = id_.split("@", 1)
         else:
             branch_label = None
 
         # ensure map is loaded
         self._revision_map
-        if id_ == 'heads':
+        if id_ == "heads":
             if branch_label:
-                return self.filter_for_lineage(
-                    self.heads, branch_label), branch_label
+                return (
+                    self.filter_for_lineage(self.heads, branch_label),
+                    branch_label,
+                )
             else:
                 return self._real_heads, branch_label
-        elif id_ == 'head':
+        elif id_ == "head":
             current_head = self.get_current_head(branch_label)
             if current_head:
-                return (current_head, ), branch_label
+                return (current_head,), branch_label
             else:
                 return (), branch_label
-        elif id_ == 'base' or id_ is None:
+        elif id_ == "base" or id_ is None:
             return (), branch_label
         else:
             return util.to_tuple(id_, default=None), branch_label
 
     def _relative_iterate(
-            self, destination, source, is_upwards,
-            implicit_base, inclusive, assert_relative_length):
+        self,
+        destination,
+        source,
+        is_upwards,
+        implicit_base,
+        inclusive,
+        assert_relative_length,
+    ):
         if isinstance(destination, compat.string_types):
             match = _relative_destination.match(destination)
             if not match:
@@ -490,13 +536,15 @@ class RevisionMap(object):
 
         revs = list(
             self._iterate_revisions(
-                from_, to_,
-                inclusive=inclusive, implicit_base=implicit_base))
+                from_, to_, inclusive=inclusive, implicit_base=implicit_base
+            )
+        )
 
         if symbol:
             if branch_label:
                 symbol_rev = self.get_revision(
-                    "%s@%s" % (branch_label, symbol))
+                    "%s@%s" % (branch_label, symbol)
+                )
             else:
                 symbol_rev = self.get_revision(symbol)
             if symbol.startswith("head"):
@@ -513,25 +561,39 @@ class RevisionMap(object):
         else:
             index = 0
         if is_upwards:
-            revs = revs[index - relative - reldelta:]
-            if not index and assert_relative_length and \
-                    len(revs) < abs(relative - reldelta):
+            revs = revs[index - relative - reldelta :]
+            if (
+                not index
+                and assert_relative_length
+                and len(revs) < abs(relative - reldelta)
+            ):
                 raise RevisionError(
                     "Relative revision %s didn't "
-                    "produce %d migrations" % (destination, abs(relative)))
+                    "produce %d migrations" % (destination, abs(relative))
+                )
         else:
-            revs = revs[0:index - relative + reldelta]
-            if not index and assert_relative_length and \
-                    len(revs) != abs(relative) + reldelta:
+            revs = revs[0 : index - relative + reldelta]
+            if (
+                not index
+                and assert_relative_length
+                and len(revs) != abs(relative) + reldelta
+            ):
                 raise RevisionError(
                     "Relative revision %s didn't "
-                    "produce %d migrations" % (destination, abs(relative)))
+                    "produce %d migrations" % (destination, abs(relative))
+                )
 
         return iter(revs)
 
     def iterate_revisions(
-            self, upper, lower, implicit_base=False, inclusive=False,
-            assert_relative_length=True, select_for_downgrade=False):
+        self,
+        upper,
+        lower,
+        implicit_base=False,
+        inclusive=False,
+        assert_relative_length=True,
+        select_for_downgrade=False,
+    ):
         """Iterate through script revisions, starting at the given
         upper revision identifier and ending at the lower.
 
@@ -545,37 +607,59 @@ class RevisionMap(object):
         """
 
         relative_upper = self._relative_iterate(
-            upper, lower, True, implicit_base,
-            inclusive, assert_relative_length
+            upper,
+            lower,
+            True,
+            implicit_base,
+            inclusive,
+            assert_relative_length,
         )
         if relative_upper:
             return relative_upper
 
         relative_lower = self._relative_iterate(
-            lower, upper, False, implicit_base,
-            inclusive, assert_relative_length
+            lower,
+            upper,
+            False,
+            implicit_base,
+            inclusive,
+            assert_relative_length,
         )
         if relative_lower:
             return relative_lower
 
         return self._iterate_revisions(
-            upper, lower, inclusive=inclusive, implicit_base=implicit_base,
-            select_for_downgrade=select_for_downgrade)
+            upper,
+            lower,
+            inclusive=inclusive,
+            implicit_base=implicit_base,
+            select_for_downgrade=select_for_downgrade,
+        )
 
     def _get_descendant_nodes(
-            self, targets, map_=None, check=False,
-            omit_immediate_dependencies=False, include_dependencies=True):
+        self,
+        targets,
+        map_=None,
+        check=False,
+        omit_immediate_dependencies=False,
+        include_dependencies=True,
+    ):
 
         if omit_immediate_dependencies:
+
             def fn(rev):
                 if rev not in targets:
                     return rev._all_nextrev
                 else:
                     return rev.nextrev
+
         elif include_dependencies:
+
             def fn(rev):
                 return rev._all_nextrev
+
         else:
+
             def fn(rev):
                 return rev.nextrev
 
@@ -584,12 +668,16 @@ class RevisionMap(object):
         )
 
     def _get_ancestor_nodes(
-            self, targets, map_=None, check=False, include_dependencies=True):
+        self, targets, map_=None, check=False, include_dependencies=True
+    ):
 
         if include_dependencies:
+
             def fn(rev):
                 return rev._all_down_revisions
+
         else:
+
             def fn(rev):
                 return rev._versioned_down_revisions
 
@@ -617,24 +705,30 @@ class RevisionMap(object):
                 if rev in seen:
                     continue
                 seen.add(rev)
-                todo.extend(
-                    map_[rev_id] for rev_id in fn(rev))
+                todo.extend(map_[rev_id] for rev_id in fn(rev))
                 yield rev
             if check:
-                overlaps = per_target.intersection(targets).\
-                    difference([target])
+                overlaps = per_target.intersection(targets).difference(
+                    [target]
+                )
                 if overlaps:
                     raise RevisionError(
                         "Requested revision %s overlaps with "
-                        "other requested revisions %s" % (
+                        "other requested revisions %s"
+                        % (
                             target.revision,
-                            ", ".join(r.revision for r in overlaps)
+                            ", ".join(r.revision for r in overlaps),
                         )
                     )
 
     def _iterate_revisions(
-            self, upper, lower, inclusive=True, implicit_base=False,
-            select_for_downgrade=False):
+        self,
+        upper,
+        lower,
+        inclusive=True,
+        implicit_base=False,
+        select_for_downgrade=False,
+    ):
         """iterate revisions from upper to lower.
 
         The traversal is depth-first within branches, and breadth-first
@@ -650,8 +744,9 @@ class RevisionMap(object):
         # is specified using a branch identifier, then we limit operations
         # to just that branch.
 
-        limit_to_lower_branch = \
-            isinstance(lower, compat.string_types) and lower.endswith('@base')
+        limit_to_lower_branch = isinstance(
+            lower, compat.string_types
+        ) and lower.endswith("@base")
 
         uppers = util.dedupe_tuple(self.get_revisions(upper))
 
@@ -663,16 +758,14 @@ class RevisionMap(object):
         if limit_to_lower_branch:
             lowers = self.get_revisions(self._get_base_revisions(lower))
         elif implicit_base and requested_lowers:
-            lower_ancestors = set(
-                self._get_ancestor_nodes(requested_lowers)
-            )
+            lower_ancestors = set(self._get_ancestor_nodes(requested_lowers))
             lower_descendants = set(
                 self._get_descendant_nodes(requested_lowers)
             )
             base_lowers = set()
-            candidate_lowers = upper_ancestors.\
-                difference(lower_ancestors).\
-                difference(lower_descendants)
+            candidate_lowers = upper_ancestors.difference(
+                lower_ancestors
+            ).difference(lower_descendants)
             for rev in candidate_lowers:
                 for downrev in rev._all_down_revisions:
                     if self._revision_map[downrev] in candidate_lowers:
@@ -690,13 +783,15 @@ class RevisionMap(object):
 
         # represents all nodes we will produce
         total_space = set(
-            rev.revision for rev in upper_ancestors).intersection(
-            rev.revision for rev
-            in self._get_descendant_nodes(
-                lowers, check=True,
+            rev.revision for rev in upper_ancestors
+        ).intersection(
+            rev.revision
+            for rev in self._get_descendant_nodes(
+                lowers,
+                check=True,
                 omit_immediate_dependencies=(
                     select_for_downgrade and requested_lowers
-                )
+                ),
             )
         )
 
@@ -706,7 +801,8 @@ class RevisionMap(object):
             start_from = set(requested_lowers)
             start_from.update(
                 self._get_ancestor_nodes(
-                    list(start_from), include_dependencies=True)
+                    list(start_from), include_dependencies=True
+                )
             )
 
             # determine all the current branch points represented
@@ -725,19 +821,18 @@ class RevisionMap(object):
         # organize branch points to be consumed separately from
         # member nodes
         branch_todo = set(
-            rev for rev in
-            (self._revision_map[rev] for rev in total_space)
-            if rev._is_real_branch_point and
-            len(total_space.intersection(rev._all_nextrev)) > 1
+            rev
+            for rev in (self._revision_map[rev] for rev in total_space)
+            if rev._is_real_branch_point
+            and len(total_space.intersection(rev._all_nextrev)) > 1
         )
 
         # it's not possible for any "uppers" to be in branch_todo,
         # because the ._all_nextrev of those nodes is not in total_space
-        #assert not branch_todo.intersection(uppers)
+        # assert not branch_todo.intersection(uppers)
 
         todo = collections.deque(
-            r for r in uppers
-            if r.revision in total_space
+            r for r in uppers if r.revision in total_space
         )
 
         # iterate for total_space being emptied out
@@ -746,7 +841,8 @@ class RevisionMap(object):
 
             if not total_space_modified:
                 raise RevisionError(
-                    "Dependency resolution failed; iteration can't proceed")
+                    "Dependency resolution failed; iteration can't proceed"
+                )
             total_space_modified = False
             # when everything non-branch pending is consumed,
             # add to the todo any branch nodes that have no
@@ -755,12 +851,13 @@ class RevisionMap(object):
                 todo.extendleft(
                     sorted(
                         (
-                            rev for rev in branch_todo
+                            rev
+                            for rev in branch_todo
                             if not rev._all_nextrev.intersection(total_space)
                         ),
                         # favor "revisioned" branch points before
                         # dependent ones
-                        key=lambda rev: 0 if rev.is_branch_point else 1
+                        key=lambda rev: 0 if rev.is_branch_point else 1,
                     )
                 )
                 branch_todo.difference_update(todo)
@@ -772,11 +869,14 @@ class RevisionMap(object):
 
                 # do depth first for elements within branches,
                 # don't consume any actual branch nodes
-                todo.extendleft([
-                    self._revision_map[downrev]
-                    for downrev in reversed(rev._all_down_revisions)
-                    if self._revision_map[downrev] not in branch_todo
-                    and downrev in total_space])
+                todo.extendleft(
+                    [
+                        self._revision_map[downrev]
+                        for downrev in reversed(rev._all_down_revisions)
+                        if self._revision_map[downrev] not in branch_todo
+                        and downrev in total_space
+                    ]
+                )
 
                 if not inclusive and rev in requested_lowers:
                     continue
@@ -795,6 +895,7 @@ class Revision(object):
     to Python files in a version directory.
 
     """
+
     nextrev = frozenset()
     """following revisions, based on down_revision only."""
 
@@ -830,15 +931,13 @@ class Revision(object):
         illegal_chars = set(revision).intersection(_revision_illegal_chars)
         if illegal_chars:
             raise RevisionError(
-                "Character(s) '%s' not allowed in revision identifier '%s'" % (
-                    ", ".join(sorted(illegal_chars)),
-                    revision
-                )
+                "Character(s) '%s' not allowed in revision identifier '%s'"
+                % (", ".join(sorted(illegal_chars)), revision)
             )
 
     def __init__(
-            self, revision, down_revision,
-            dependencies=None, branch_labels=None):
+        self, revision, down_revision, dependencies=None, branch_labels=None
+    ):
         self.verify_rev_id(revision)
         self.revision = revision
         self.down_revision = tuple_rev_as_scalar(down_revision)
@@ -848,18 +947,12 @@ class Revision(object):
         self.branch_labels = set(self._orig_branch_labels)
 
     def __repr__(self):
-        args = [
-            repr(self.revision),
-            repr(self.down_revision)
-        ]
+        args = [repr(self.revision), repr(self.down_revision)]
         if self.dependencies:
             args.append("dependencies=%r" % (self.dependencies,))
         if self.branch_labels:
             args.append("branch_labels=%r" % (self.branch_labels,))
-        return "%s(%s)" % (
-            self.__class__.__name__,
-            ", ".join(args)
-        )
+        return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
 
     def add_nextrev(self, revision):
         self._all_nextrev = self._all_nextrev.union([revision.revision])
@@ -868,8 +961,10 @@ class Revision(object):
 
     @property
     def _all_down_revisions(self):
-        return util.to_tuple(self.down_revision, default=()) + \
-            self._resolved_dependencies
+        return (
+            util.to_tuple(self.down_revision, default=())
+            + self._resolved_dependencies
+        )
 
     @property
     def _versioned_down_revisions(self):
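
A minimal sketch of the traversal that the reformatted hunks above implement (illustrative only: iterate_down, down_of and the toy graph are hypothetical, not Alembic API; the branch_todo deferral that keeps branch points ordered is omitted here):

    import collections

    def iterate_down(heads, down_of):
        # walk from the given heads toward the base, depth first;
        # down_of maps each revision id to a tuple of down revisions
        total_space = set(down_of)
        todo = collections.deque(h for h in heads if h in total_space)
        while todo:
            rev = todo.popleft()
            if rev not in total_space:
                continue  # already yielded via another path
            total_space.discard(rev)
            todo.extendleft(
                d for d in reversed(down_of[rev]) if d in total_space
            )
            yield rev
        if total_space:
            # mirrors the "iteration can't proceed" guard above
            raise RuntimeError("dependency resolution failed")

    # a linear chain c -> b -> a yields newest-first, as a downgrade would
    graph = {"a": (), "b": ("a",), "c": ("b",)}
    assert list(iterate_down(["c"], graph)) == ["c", "b", "a"]
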
index 058378b9d037de152ec2ce6287b9cc6ea1eabe59..f3df9520c978e3bfd69a6b6be57c54f20e968858 100644 (file)
@@ -37,7 +37,8 @@ def run_migrations_offline():
     """
     url = config.get_main_option("sqlalchemy.url")
     context.configure(
-        url=url, target_metadata=target_metadata, literal_binds=True)
+        url=url, target_metadata=target_metadata, literal_binds=True
+    )
 
     with context.begin_transaction():
         context.run_migrations()
@@ -52,18 +53,19 @@ def run_migrations_online():
     """
     connectable = engine_from_config(
         config.get_section(config.config_ini_section),
-        prefix='sqlalchemy.',
-        poolclass=pool.NullPool)
+        prefix="sqlalchemy.",
+        poolclass=pool.NullPool,
+    )
 
     with connectable.connect() as connection:
         context.configure(
-            connection=connection,
-            target_metadata=target_metadata
+            connection=connection, target_metadata=target_metadata
         )
 
         with context.begin_transaction():
             context.run_migrations()
 
+
 if context.is_offline_mode():
     run_migrations_offline()
 else:
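
For orientation, the two branches of this template are reached as follows; a hedged sketch, assuming an alembic.ini that points at this env.py:

    from alembic.config import Config
    from alembic import command

    cfg = Config("alembic.ini")

    # offline mode: run_migrations_offline() renders SQL using
    # literal_binds rather than executing against a connection
    command.upgrade(cfg, "head", sql=True)

    # online mode: run_migrations_online() connects and executes
    command.upgrade(cfg, "head")
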
index db24173f993037a9c9cd39a53dbcc6139b208911..f5ad3d48e77664c2a5967f803e89588ca9490448 100644 (file)
@@ -14,12 +14,12 @@ config = context.config
 # Interpret the config file for Python logging.
 # This line sets up loggers basically.
 fileConfig(config.config_file_name)
-logger = logging.getLogger('alembic.env')
+logger = logging.getLogger("alembic.env")
 
 # gather section names referring to different
 # databases.  These are named "engine1", "engine2"
 # in the sample .ini file.
-db_names = config.get_main_option('databases')
+db_names = config.get_main_option("databases")
 
 # add your model's MetaData objects here
 # for 'autogenerate' support.  These must be set
@@ -56,19 +56,21 @@ def run_migrations_offline():
     # individual files.
 
     engines = {}
-    for name in re.split(r',\s*', db_names):
+    for name in re.split(r",\s*", db_names):
         engines[name] = rec = {}
-        rec['url'] = context.config.get_section_option(name,
-                                                       "sqlalchemy.url")
+        rec["url"] = context.config.get_section_option(name, "sqlalchemy.url")
 
     for name, rec in engines.items():
         logger.info("Migrating database %s" % name)
         file_ = "%s.sql" % name
         logger.info("Writing output to %s" % file_)
-        with open(file_, 'w') as buffer:
-            context.configure(url=rec['url'], output_buffer=buffer,
-                              target_metadata=target_metadata.get(name),
-                              literal_binds=True)
+        with open(file_, "w") as buffer:
+            context.configure(
+                url=rec["url"],
+                output_buffer=buffer,
+                target_metadata=target_metadata.get(name),
+                literal_binds=True,
+            )
             with context.begin_transaction():
                 context.run_migrations(engine_name=name)
 
@@ -85,46 +87,47 @@ def run_migrations_online():
     # engines, then run all migrations, then commit all transactions.
 
     engines = {}
-    for name in re.split(r',\s*', db_names):
+    for name in re.split(r",\s*", db_names):
         engines[name] = rec = {}
-        rec['engine'] = engine_from_config(
+        rec["engine"] = engine_from_config(
             context.config.get_section(name),
-            prefix='sqlalchemy.',
-            poolclass=pool.NullPool)
+            prefix="sqlalchemy.",
+            poolclass=pool.NullPool,
+        )
 
     for name, rec in engines.items():
-        engine = rec['engine']
-        rec['connection'] = conn = engine.connect()
+        engine = rec["engine"]
+        rec["connection"] = conn = engine.connect()
 
         if USE_TWOPHASE:
-            rec['transaction'] = conn.begin_twophase()
+            rec["transaction"] = conn.begin_twophase()
         else:
-            rec['transaction'] = conn.begin()
+            rec["transaction"] = conn.begin()
 
     try:
         for name, rec in engines.items():
             logger.info("Migrating database %s" % name)
             context.configure(
-                connection=rec['connection'],
+                connection=rec["connection"],
                 upgrade_token="%s_upgrades" % name,
                 downgrade_token="%s_downgrades" % name,
-                target_metadata=target_metadata.get(name)
+                target_metadata=target_metadata.get(name),
             )
             context.run_migrations(engine_name=name)
 
         if USE_TWOPHASE:
             for rec in engines.values():
-                rec['transaction'].prepare()
+                rec["transaction"].prepare()
 
         for rec in engines.values():
-            rec['transaction'].commit()
+            rec["transaction"].commit()
     except:
         for rec in engines.values():
-            rec['transaction'].rollback()
+            rec["transaction"].rollback()
         raise
     finally:
         for rec in engines.values():
-            rec['connection'].close()
+            rec["connection"].close()
 
 
 if context.is_offline_mode():
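
The prepare/commit choreography above is plain SQLAlchemy two-phase commit; a standalone sketch, assuming placeholder URLs and a DBAPI with two-phase support:

    from sqlalchemy import create_engine

    engines = {
        "engine1": create_engine("postgresql://localhost/db1"),
        "engine2": create_engine("postgresql://localhost/db2"),
    }
    conns = {name: eng.connect() for name, eng in engines.items()}
    txns = {name: conn.begin_twophase() for name, conn in conns.items()}
    try:
        # ... run each database's migrations on its connection ...
        for txn in txns.values():
            txn.prepare()  # phase one: every database votes
        for txn in txns.values():
            txn.commit()  # phase two: commit everywhere
    except Exception:
        for txn in txns.values():
            txn.rollback()
        raise
    finally:
        for conn in conns.values():
            conn.close()
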
index 5ad9fd5958626a6c413522211822fee96b688c75..8c06cdc00bbe0261ed783b081870a72b5803937f 100644 (file)
@@ -13,18 +13,21 @@ from sqlalchemy.engine.base import Engine
 try:
     # if pylons app already in, don't create a new app
     from pylons import config as pylons_config
-    pylons_config['__file__']
+
+    pylons_config["__file__"]
 except:
     config = context.config
     # can use config['__file__'] here, i.e. the Pylons
     # ini file, instead of alembic.ini
-    config_file = config.get_main_option('pylons_config_file')
+    config_file = config.get_main_option("pylons_config_file")
     fileConfig(config_file)
-    wsgi_app = loadapp('config:%s' % config_file, relative_to='.')
+    wsgi_app = loadapp("config:%s" % config_file, relative_to=".")
 
 
 # customize this section for non-standard engine configurations.
-meta = __import__("%s.model.meta" % wsgi_app.config['pylons.package']).model.meta
+meta = __import__(
+    "%s.model.meta" % wsgi_app.config["pylons.package"]
+).model.meta
 
 # add your model's MetaData object here
 # for 'autogenerate' support
@@ -46,8 +49,10 @@ def run_migrations_offline():
 
     """
     context.configure(
-        url=meta.engine.url, target_metadata=target_metadata,
-        literal_binds=True)
+        url=meta.engine.url,
+        target_metadata=target_metadata,
+        literal_binds=True,
+    )
     with context.begin_transaction():
         context.run_migrations()
 
@@ -65,13 +70,13 @@ def run_migrations_online():
 
     with engine.connect() as connection:
         context.configure(
-            connection=connection,
-            target_metadata=target_metadata
+            connection=connection, target_metadata=target_metadata
         )
 
         with context.begin_transaction():
             context.run_migrations()
 
+
 if context.is_offline_mode():
     run_migrations_offline()
 else:
index 553f501b901a2adeb766ba751db378ce05cc7750..70c28e0324377eb6c98cad05e5d82e06f0006539 100644 (file)
@@ -1,6 +1,13 @@
 from .fixtures import TestBase
-from .assertions import eq_, ne_, is_, is_not_, assert_raises_message, \
-    eq_ignore_whitespace, assert_raises
+from .assertions import (
+    eq_,
+    ne_,
+    is_,
+    is_not_,
+    assert_raises_message,
+    eq_ignore_whitespace,
+    assert_raises,
+)
 
 from .util import provide_metadata
 
index 2c7382c2997498dd102cc98cbee0738316ced55a..c25b4449079f8ded9960ca94c60cdac4e8d3bd30 100644 (file)
@@ -15,6 +15,7 @@ from . import config
 
 
 if not util.sqla_094:
+
     def eq_(a, b, msg=None):
         """Assert a == b, with repr messaging on failure."""
         assert a == b, msg or "%r != %r" % (a, b)
@@ -46,27 +47,36 @@ if not util.sqla_094:
             callable_(*args, **kwargs)
             assert False, "Callable did not raise an exception"
         except except_cls as e:
-            assert re.search(
-                msg, text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
-            print(text_type(e).encode('utf-8'))
+            assert re.search(msg, text_type(e), re.UNICODE), "%r !~ %s" % (
+                msg,
+                e,
+            )
+            print(text_type(e).encode("utf-8"))
+
 
 else:
-    from sqlalchemy.testing.assertions import eq_, ne_, is_, is_not_, \
-        assert_raises_message, assert_raises
+    from sqlalchemy.testing.assertions import (
+        eq_,
+        ne_,
+        is_,
+        is_not_,
+        assert_raises_message,
+        assert_raises,
+    )
 
 
 def eq_ignore_whitespace(a, b, msg=None):
-    a = re.sub(r'^\s+?|\n', "", a)
-    a = re.sub(r' {2,}', " ", a)
-    b = re.sub(r'^\s+?|\n', "", b)
-    b = re.sub(r' {2,}', " ", b)
+    a = re.sub(r"^\s+?|\n", "", a)
+    a = re.sub(r" {2,}", " ", a)
+    b = re.sub(r"^\s+?|\n", "", b)
+    b = re.sub(r" {2,}", " ", b)
 
     # convert for unicode string rendering,
     # using special escape character "!U"
     if py3k:
-        b = re.sub(r'!U', '', b)
+        b = re.sub(r"!U", "", b)
     else:
-        b = re.sub(r'!U', 'u', b)
+        b = re.sub(r"!U", "u", b)
 
     assert a == b, msg or "%r != %r" % (a, b)
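
Concretely, the substitutions above drop newlines and collapse runs of spaces, so a wrapped rendering compares equal to its single-line form (illustrative values):

    eq_ignore_whitespace(
        """sa.Column('x',
               sa.Integer(),
               nullable=False)""",
        "sa.Column('x', sa.Integer(), nullable=False)",
    )
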
 
@@ -74,9 +84,10 @@ def eq_ignore_whitespace(a, b, msg=None):
 def assert_compiled(element, assert_string, dialect=None):
     dialect = _get_dialect(dialect)
     eq_(
-        text_type(element.compile(dialect=dialect)).
-        replace("\n", "").replace("\t", ""),
-        assert_string.replace("\n", "").replace("\t", "")
+        text_type(element.compile(dialect=dialect))
+        .replace("\n", "")
+        .replace("\t", ""),
+        assert_string.replace("\n", "").replace("\t", ""),
     )
 
 
@@ -84,19 +95,20 @@ _dialect_mods = {}
 
 
 def _get_dialect(name):
-    if name is None or name == 'default':
+    if name is None or name == "default":
         return default.DefaultDialect()
     else:
         try:
             dialect_mod = _dialect_mods[name]
         except KeyError:
             dialect_mod = getattr(
-                __import__('sqlalchemy.dialects.%s' % name).dialects, name)
+                __import__("sqlalchemy.dialects.%s" % name).dialects, name
+            )
             _dialect_mods[name] = dialect_mod
         d = dialect_mod.dialect()
-        if name == 'postgresql':
+        if name == "postgresql":
             d.implicit_returning = True
-        elif name == 'mssql':
+        elif name == "mssql":
             d.legacy_schema_aliasing = False
         return d
 
@@ -161,6 +173,7 @@ def emits_warning_on(db, *messages):
     were in fact seen.
 
     """
+
     @decorator
     def decorate(fn, *args, **kw):
         with expect_warnings_on(db, *messages):
@@ -189,8 +202,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True):
             return
 
         for filter_ in filters:
-            if (regex and filter_.match(msg)) or \
-                    (not regex and filter_ == msg):
+            if (regex and filter_.match(msg)) or (
+                not regex and filter_ == msg
+            ):
                 seen.discard(filter_)
                 break
         else:
@@ -203,6 +217,6 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True):
         yield
 
     if assert_:
-        assert not seen, "Warnings were not seen: %s" % \
-            ", ".join("%r" % (s.pattern if regex else s) for s in seen)
-
+        assert not seen, "Warnings were not seen: %s" % ", ".join(
+            "%r" % (s.pattern if regex else s) for s in seen
+        )
index e0af6a2cc2db324bd8f0cadd3be38860ff853ea4..9fbd50f6fb41bc26577f6e25280920b53ceb419b 100644 (file)
@@ -1,13 +1,12 @@
 def get_url_driver_name(url):
-    if '+' not in url.drivername:
+    if "+" not in url.drivername:
         return url.get_dialect().driver
     else:
-        return url.drivername.split('+')[1]
+        return url.drivername.split("+")[1]
 
 
 def get_url_backend_name(url):
-    if '+' not in url.drivername:
+    if "+" not in url.drivername:
         return url.drivername
     else:
-        return url.drivername.split('+')[0]
-
+        return url.drivername.split("+")[0]
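
A quick illustration of the two helpers above, via SQLAlchemy's standard URL parser (the last assertion assumes psycopg2 is the postgresql dialect's default DBAPI):

    from sqlalchemy.engine.url import make_url

    url = make_url("postgresql+psycopg2://scott:tiger@localhost/test")
    assert get_url_backend_name(url) == "postgresql"
    assert get_url_driver_name(url) == "psycopg2"

    # with no explicit "+driver", the default DBAPI is reported
    url = make_url("postgresql://scott:tiger@localhost/test")
    assert get_url_driver_name(url) == "psycopg2"
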
index ca28c6be20aef31cccf48ee518937bbf265ea3da..7d7009e726808afa7aaf255f71cb689cd90a4d27 100644 (file)
@@ -66,7 +66,8 @@ class Config(object):
         assert _current, "Can't push without a default Config set up"
         cls.push(
             Config(
-                db, _current.db_opts, _current.options, _current.file_config)
+                db, _current.db_opts, _current.options, _current.file_config
+            )
         )
 
     @classmethod
@@ -88,4 +89,3 @@ class Config(object):
     def all_dbs(cls):
         for cfg in cls.all_configs():
             yield cfg.db
-
index dadabc8a157ea6c490b688a783e4cc81ebc3f602..68d00687ce9cb770259bb7e48afe162ea88e96b2 100644 (file)
@@ -25,4 +25,3 @@ def testing_engine(url=None, options=None):
     engine = create_engine(url, **options)
 
     return engine
-
index 0318703f518d9926f1bfeb243220b2ee340c3faa..51483d1945c706280f946351c08998607bf83ece 100644 (file)
@@ -15,21 +15,22 @@ def _get_staging_directory():
     if provision.FOLLOWER_IDENT:
         return "scratch_%s" % provision.FOLLOWER_IDENT
     else:
-        return 'scratch'
+        return "scratch"
 
 
 def staging_env(create=True, template="generic", sourceless=False):
     from alembic import command, script
+
     cfg = _testing_config()
     if create:
-        path = os.path.join(_get_staging_directory(), 'scripts')
+        path = os.path.join(_get_staging_directory(), "scripts")
         if os.path.exists(path):
             shutil.rmtree(path)
         command.init(cfg, path, template=template)
         if sourceless:
             try:
                 # do an import so that a .pyc/.pyo is generated.
-                util.load_python_file(path, 'env.py')
+                util.load_python_file(path, "env.py")
             except AttributeError:
                 # we don't have the migration context set up yet
                 # so running the .env py throws this exception.
@@ -38,10 +39,13 @@ def staging_env(create=True, template="generic", sourceless=False):
                 # worth it.
                 pass
             assert sourceless in (
-                "pep3147_envonly", "simple", "pep3147_everything"), sourceless
+                "pep3147_envonly",
+                "simple",
+                "pep3147_everything",
+            ), sourceless
             make_sourceless(
                 os.path.join(path, "env.py"),
-                "pep3147" if "pep3147" in sourceless else "simple"
+                "pep3147" if "pep3147" in sourceless else "simple",
             )
 
     sc = script.ScriptDirectory.from_config(cfg)
@@ -53,40 +57,44 @@ def clear_staging_env():
 
 
 def script_file_fixture(txt):
-    dir_ = os.path.join(_get_staging_directory(), 'scripts')
+    dir_ = os.path.join(_get_staging_directory(), "scripts")
     path = os.path.join(dir_, "script.py.mako")
-    with open(path, 'w') as f:
+    with open(path, "w") as f:
         f.write(txt)
 
 
 def env_file_fixture(txt):
-    dir_ = os.path.join(_get_staging_directory(), 'scripts')
-    txt = """
+    dir_ = os.path.join(_get_staging_directory(), "scripts")
+    txt = (
+        """
 from alembic import context
 
 config = context.config
-""" + txt
+"""
+        + txt
+    )
 
     path = os.path.join(dir_, "env.py")
     pyc_path = util.pyc_file_from_path(path)
     if pyc_path:
         os.unlink(pyc_path)
 
-    with open(path, 'w') as f:
+    with open(path, "w") as f:
         f.write(txt)
 
 
 def _sqlite_file_db(tempname="foo.db"):
-    dir_ = os.path.join(_get_staging_directory(), 'scripts')
+    dir_ = os.path.join(_get_staging_directory(), "scripts")
     url = "sqlite:///%s/%s" % (dir_, tempname)
     return engines.testing_engine(url=url)
 
 
 def _sqlite_testing_config(sourceless=False):
-    dir_ = os.path.join(_get_staging_directory(), 'scripts')
+    dir_ = os.path.join(_get_staging_directory(), "scripts")
     url = "sqlite:///%s/foo.db" % dir_
 
-    return _write_config_file("""
+    return _write_config_file(
+        """
 [alembic]
 script_location = %s
 sqlalchemy.url = %s
@@ -115,14 +123,17 @@ keys = generic
 [formatter_generic]
 format = %%(levelname)-5.5s [%%(name)s] %%(message)s
 datefmt = %%H:%%M:%%S
-    """ % (dir_, url, "true" if sourceless else "false"))
+    """
+        % (dir_, url, "true" if sourceless else "false")
+    )
 
 
-def _multi_dir_testing_config(sourceless=False, extra_version_location=''):
-    dir_ = os.path.join(_get_staging_directory(), 'scripts')
+def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
+    dir_ = os.path.join(_get_staging_directory(), "scripts")
     url = "sqlite:///%s/foo.db" % dir_
 
-    return _write_config_file("""
+    return _write_config_file(
+        """
 [alembic]
 script_location = %s
 sqlalchemy.url = %s
@@ -152,15 +163,22 @@ keys = generic
 [formatter_generic]
 format = %%(levelname)-5.5s [%%(name)s] %%(message)s
 datefmt = %%H:%%M:%%S
-    """ % (dir_, url, "true" if sourceless else "false",
-           extra_version_location))
+    """
+        % (
+            dir_,
+            url,
+            "true" if sourceless else "false",
+            extra_version_location,
+        )
+    )
 
 
 def _no_sql_testing_config(dialect="postgresql", directives=""):
     """use a postgresql url with no host so that
     connections guaranteed to fail"""
-    dir_ = os.path.join(_get_staging_directory(), 'scripts')
-    return _write_config_file("""
+    dir_ = os.path.join(_get_staging_directory(), "scripts")
+    return _write_config_file(
+        """
 [alembic]
 script_location = %s
 sqlalchemy.url = %s://
@@ -190,32 +208,36 @@ keys = generic
 format = %%(levelname)-5.5s [%%(name)s] %%(message)s
 datefmt = %%H:%%M:%%S
 
-""" % (dir_, dialect, directives))
+"""
+        % (dir_, dialect, directives)
+    )
 
 
 def _write_config_file(text):
     cfg = _testing_config()
-    with open(cfg.config_file_name, 'w') as f:
+    with open(cfg.config_file_name, "w") as f:
         f.write(text)
     return cfg
 
 
 def _testing_config():
     from alembic.config import Config
+
     if not os.access(_get_staging_directory(), os.F_OK):
         os.mkdir(_get_staging_directory())
-    return Config(os.path.join(_get_staging_directory(), 'test_alembic.ini'))
+    return Config(os.path.join(_get_staging_directory(), "test_alembic.ini"))
 
 
 def write_script(
-        scriptdir, rev_id, content, encoding='ascii', sourceless=False):
+    scriptdir, rev_id, content, encoding="ascii", sourceless=False
+):
     old = scriptdir.revision_map.get_revision(rev_id)
     path = old.path
 
     content = textwrap.dedent(content)
     if encoding:
         content = content.encode(encoding)
-    with open(path, 'wb') as fp:
+    with open(path, "wb") as fp:
         fp.write(content)
     pyc_path = util.pyc_file_from_path(path)
     if pyc_path:
@@ -223,20 +245,21 @@ def write_script(
     script = Script._from_path(scriptdir, path)
     old = scriptdir.revision_map.get_revision(script.revision)
     if old.down_revision != script.down_revision:
-        raise Exception("Can't change down_revision "
-                        "on a refresh operation.")
+        raise Exception(
+            "Can't change down_revision " "on a refresh operation."
+        )
     scriptdir.revision_map.add_revision(script, _replace=True)
 
     if sourceless:
         make_sourceless(
-            path,
-            "pep3147" if sourceless == "pep3147_everything" else "simple"
+            path, "pep3147" if sourceless == "pep3147_everything" else "simple"
         )
 
 
 def make_sourceless(path, style):
 
     import py_compile
+
     py_compile.compile(path)
 
     if style == "simple" and has_pep3147():
@@ -264,7 +287,10 @@ def three_rev_fixture(cfg):
 
     script = ScriptDirectory.from_config(cfg)
     script.generate_revision(a, "revision a", refresh=True)
-    write_script(script, a, """\
+    write_script(
+        script,
+        a,
+        """\
 "Rev A"
 revision = '%s'
 down_revision = None
@@ -279,10 +305,16 @@ def upgrade():
 def downgrade():
     op.execute("DROP STEP 1")
 
-""" % a)
+"""
+        % a,
+    )
 
     script.generate_revision(b, "revision b", refresh=True)
-    write_script(script, b, u("""# coding: utf-8
+    write_script(
+        script,
+        b,
+        u(
+            """# coding: utf-8
 "Rev B, méil, %3"
 revision = '{}'
 down_revision = '{}'
@@ -297,10 +329,16 @@ def upgrade():
 def downgrade():
     op.execute("DROP STEP 2")
 
-""").format(b, a), encoding="utf-8")
+"""
+        ).format(b, a),
+        encoding="utf-8",
+    )
 
     script.generate_revision(c, "revision c", refresh=True)
-    write_script(script, c, """\
+    write_script(
+        script,
+        c,
+        """\
 "Rev C"
 revision = '%s'
 down_revision = '%s'
@@ -315,7 +353,9 @@ def upgrade():
 def downgrade():
     op.execute("DROP STEP 3")
 
-""" % (c, b))
+"""
+        % (c, b),
+    )
     return a, b, c
 
 
@@ -328,8 +368,12 @@ def multi_heads_fixture(cfg, a, b, c):
 
     script = ScriptDirectory.from_config(cfg)
     script.generate_revision(
-        d, "revision d from b", head=b, splice=True, refresh=True)
-    write_script(script, d, """\
+        d, "revision d from b", head=b, splice=True, refresh=True
+    )
+    write_script(
+        script,
+        d,
+        """\
 "Rev D"
 revision = '%s'
 down_revision = '%s'
@@ -344,11 +388,17 @@ def upgrade():
 def downgrade():
     op.execute("DROP STEP 4")
 
-""" % (d, b))
+"""
+        % (d, b),
+    )
 
     script.generate_revision(
-        e, "revision e from d", head=d, splice=True, refresh=True)
-    write_script(script, e, """\
+        e, "revision e from d", head=d, splice=True, refresh=True
+    )
+    write_script(
+        script,
+        e,
+        """\
 "Rev E"
 revision = '%s'
 down_revision = '%s'
@@ -363,11 +413,17 @@ def upgrade():
 def downgrade():
     op.execute("DROP STEP 5")
 
-""" % (e, d))
+"""
+        % (e, d),
+    )
 
     script.generate_revision(
-        f, "revision f from b", head=b, splice=True, refresh=True)
-    write_script(script, f, """\
+        f, "revision f from b", head=b, splice=True, refresh=True
+    )
+    write_script(
+        script,
+        f,
+        """\
 "Rev F"
 revision = '%s'
 down_revision = '%s'
@@ -382,7 +438,9 @@ def upgrade():
 def downgrade():
     op.execute("DROP STEP 6")
 
-""" % (f, b))
+"""
+        % (f, b),
+    )
 
     return d, e, f
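
For orientation, the graph these two fixtures build: three_rev_fixture() creates the linear chain and multi_heads_fixture() splices three more revisions onto b, leaving heads c, e and f:

    #   a -> b -> c
    #        |
    #        +--> d -> e
    #        |
    #        +--> f
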
 
@@ -390,18 +448,16 @@ def downgrade():
 def _multidb_testing_config(engines):
     """alembic.ini fixture to work exactly with the 'multidb' template"""
 
-    dir_ = os.path.join(_get_staging_directory(), 'scripts')
+    dir_ = os.path.join(_get_staging_directory(), "scripts")
 
-    databases = ", ".join(
-        engines.keys()
-    )
+    databases = ", ".join(engines.keys())
     engines = "\n\n".join(
-        "[%s]\n"
-        "sqlalchemy.url = %s" % (key, value.url)
+        "[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
         for key, value in engines.items()
     )
 
-    return _write_config_file("""
+    return _write_config_file(
+        """
 [alembic]
 script_location = %s
 sourceless = false
@@ -432,5 +488,6 @@ keys = generic
 [formatter_generic]
 format = %%(levelname)-5.5s [%%(name)s] %%(message)s
 datefmt = %%H:%%M:%%S
-    """ % (dir_, databases, engines)
+    """
+        % (dir_, databases, engines)
     )
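
A hedged sketch of how these fixtures compose in a test (the upgrade call is illustrative, not part of this file):

    from alembic import command

    env = staging_env()              # scratch/scripts from the template
    cfg = _sqlite_testing_config()   # alembic.ini aimed at scratch/foo.db
    a, b, c = three_rev_fixture(cfg)
    command.upgrade(cfg, c)          # runs the three scripted revisions
    clear_staging_env()
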
index 7d33a5b3f20e25fecb1b3a9d6bcda91307a3d0fc..41ed547a033f8d17c4a0123104c07a1abf58b63a 100644 (file)
@@ -74,15 +74,15 @@ class compound(object):
 
     def matching_config_reasons(self, config):
         return [
-            predicate._as_string(config) for predicate
-            in self.skips.union(self.fails)
+            predicate._as_string(config)
+            for predicate in self.skips.union(self.fails)
             if predicate(config)
         ]
 
     def include_test(self, include_tags, exclude_tags):
         return bool(
-            not self.tags.intersection(exclude_tags) and
-            (not include_tags or self.tags.intersection(include_tags))
+            not self.tags.intersection(exclude_tags)
+            and (not include_tags or self.tags.intersection(include_tags))
         )
 
     def _extend(self, other):
@@ -91,13 +91,14 @@ class compound(object):
         self.tags.update(other.tags)
 
     def __call__(self, fn):
-        if hasattr(fn, '_sa_exclusion_extend'):
+        if hasattr(fn, "_sa_exclusion_extend"):
             fn._sa_exclusion_extend._extend(self)
             return fn
 
         @decorator
         def decorate(fn, *args, **kw):
             return self._do(config._current, fn, *args, **kw)
+
         decorated = decorate(fn)
         decorated._sa_exclusion_extend = self
         return decorated
@@ -117,10 +118,7 @@ class compound(object):
     def _do(self, config, fn, *args, **kw):
         for skip in self.skips:
             if skip(config):
-                msg = "'%s' : %s" % (
-                    fn.__name__,
-                    skip._as_string(config)
-                )
+                msg = "'%s' : %s" % (fn.__name__, skip._as_string(config))
                 raise SkipTest(msg)
 
         try:
@@ -131,16 +129,20 @@ class compound(object):
             self._expect_success(config, name=fn.__name__)
             return return_value
 
-    def _expect_failure(self, config, ex, name='block'):
+    def _expect_failure(self, config, ex, name="block"):
         for fail in self.fails:
             if fail(config):
-                print(("%s failed as expected (%s): %s " % (
-                    name, fail._as_string(config), str(ex))))
+                print(
+                    (
+                        "%s failed as expected (%s): %s "
+                        % (name, fail._as_string(config), str(ex))
+                    )
+                )
                 break
         else:
             compat.raise_from_cause(ex)
 
-    def _expect_success(self, config, name='block'):
+    def _expect_success(self, config, name="block"):
         if not self.fails:
             return
         for fail in self.fails:
@@ -148,13 +150,12 @@ class compound(object):
                 break
         else:
             raise AssertionError(
-                "Unexpected success for '%s' (%s)" %
-                (
+                "Unexpected success for '%s' (%s)"
+                (
                     name,
                     " and ".join(
-                        fail._as_string(config)
-                        for fail in self.fails
-                    )
+                        fail._as_string(config) for fail in self.fails
+                    ),
                 )
             )
 
@@ -191,8 +192,8 @@ class Predicate(object):
             return predicate
         elif isinstance(predicate, (list, set)):
             return OrPredicate(
-                [cls.as_predicate(pred) for pred in predicate],
-                description)
+                [cls.as_predicate(pred) for pred in predicate], description
+            )
         elif isinstance(predicate, tuple):
             return SpecPredicate(*predicate)
         elif isinstance(predicate, compat.string_types):
@@ -217,7 +218,7 @@ class Predicate(object):
             "driver": get_url_driver_name(config.db.url),
             "database": get_url_backend_name(config.db.url),
             "doesnt_support": "doesn't support" if bool_ else "does support",
-            "does_support": "does support" if bool_ else "doesn't support"
+            "does_support": "does support" if bool_ else "doesn't support",
         }
 
     def _as_string(self, config=None, negate=False):
@@ -244,21 +245,21 @@ class SpecPredicate(Predicate):
         self.description = description
 
     _ops = {
-        '<': operator.lt,
-        '>': operator.gt,
-        '==': operator.eq,
-        '!=': operator.ne,
-        '<=': operator.le,
-        '>=': operator.ge,
-        'in': operator.contains,
-        'between': lambda val, pair: val >= pair[0] and val <= pair[1],
+        "<": operator.lt,
+        ">": operator.gt,
+        "==": operator.eq,
+        "!=": operator.ne,
+        "<=": operator.le,
+        ">=": operator.ge,
+        "in": operator.contains,
+        "between": lambda val, pair: val >= pair[0] and val <= pair[1],
     }
 
     def __call__(self, config):
         engine = config.db
 
         if "+" in self.db:
-            dialect, driver = self.db.split('+')
+            dialect, driver = self.db.split("+")
         else:
             dialect, driver = self.db, None
 
@@ -271,8 +272,9 @@ class SpecPredicate(Predicate):
             assert driver is None, "DBAPI version specs not supported yet"
 
             version = _server_version(engine)
-            oper = hasattr(self.op, '__call__') and self.op \
-                or self._ops[self.op]
+            oper = (
+                hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+            )
             return oper(version, self.spec)
         else:
             return True
@@ -287,17 +289,9 @@ class SpecPredicate(Predicate):
                 return "%s" % self.db
         else:
             if negate:
-                return "not %s %s %s" % (
-                    self.db,
-                    self.op,
-                    self.spec
-                )
+                return "not %s %s %s" % (self.db, self.op, self.spec)
             else:
-                return "%s %s %s" % (
-                    self.db,
-                    self.op,
-                    self.spec
-                )
+                return "%s %s %s" % (self.db, self.op, self.spec)
 
 
 class LambdaPredicate(Predicate):
@@ -354,8 +348,9 @@ class OrPredicate(Predicate):
             conjunction = " and "
         else:
             conjunction = " or "
-        return conjunction.join(p._as_string(config, negate=negate)
-                                for p in self.predicates)
+        return conjunction.join(
+            p._as_string(config, negate=negate) for p in self.predicates
+        )
 
     def _negation_str(self, config):
         if self.description is not None:
@@ -385,15 +380,13 @@ def _server_version(engine):
 
     # force metadata to be retrieved
     conn = engine.connect()
-    version = getattr(engine.dialect, 'server_version_info', ())
+    version = getattr(engine.dialect, "server_version_info", ())
     conn.close()
     return version
 
 
 def db_spec(*dbs):
-    return OrPredicate(
-        [Predicate.as_predicate(db) for db in dbs]
-    )
+    return OrPredicate([Predicate.as_predicate(db) for db in dbs])
 
 
 def open():
@@ -418,11 +411,7 @@ def fails_on(db, reason=None):
 
 
 def fails_on_everything_except(*dbs):
-    return succeeds_if(
-        OrPredicate([
-            Predicate.as_predicate(db) for db in dbs
-        ])
-    )
+    return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
 
 
 def skip(db, reason=None):
@@ -441,7 +430,6 @@ def exclude(db, op, spec, reason=None):
 
 def against(config, *queries):
     assert queries, "no queries sent!"
-    return OrPredicate([
-        Predicate.as_predicate(query)
-        for query in queries
-    ])(config)
+    return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+        config
+    )
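
Reading the reformatted _ops table above: a spec tuple names a backend, an operator key and a server version to compare against (values here are illustrative):

    pred = SpecPredicate("mysql", ">=", (5, 6, 0))
    # pred(config) is True when config.db is MySQL and
    # _server_version(config.db) >= (5, 6, 0)

    # the same tuple form flows through Predicate.as_predicate, e.g.:
    exclude("mysql", "<", (5, 6, 0), "needs MySQL 5.6+")
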
index 86d40a29adf0ebe457d9117bf86d67712f0cf979..b812476d1cc41b78fafe5237449a6a18f836594e 100644 (file)
@@ -17,10 +17,11 @@ from .assertions import _get_dialect, eq_
 from . import mock
 
 testing_config = configparser.ConfigParser()
-testing_config.read(['test.cfg'])
+testing_config.read(["test.cfg"])
 
 
 if not util.sqla_094:
+
     class TestBase(object):
         # A sequence of database names to always run, regardless of the
         # constraints below.
@@ -51,6 +52,8 @@ if not util.sqla_094:
         def teardown(self):
             if hasattr(self, "tearDown"):
                 self.tearDown()
+
+
 else:
     from sqlalchemy.testing.fixtures import TestBase
 
@@ -60,23 +63,22 @@ def capture_db():
 
     def dump(sql, *multiparams, **params):
         buf.append(str(sql.compile(dialect=engine.dialect)))
+
     engine = create_engine("postgresql://", strategy="mock", executor=dump)
     return engine, buf
 
+
 _engs = {}
 
 
 @contextmanager
 def capture_context_buffer(**kw):
-    if kw.pop('bytes_io', False):
+    if kw.pop("bytes_io", False):
         buf = io.BytesIO()
     else:
         buf = io.StringIO()
 
-    kw.update({
-        'dialect_name': "sqlite",
-        'output_buffer': buf
-    })
+    kw.update({"dialect_name": "sqlite", "output_buffer": buf})
     conf = EnvironmentContext.configure
 
     def configure(*arg, **opt):
@@ -88,17 +90,20 @@ def capture_context_buffer(**kw):
 
 
 def op_fixture(
-        dialect='default', as_sql=False,
-        naming_convention=None, literal_binds=False,
-        native_boolean=None):
+    dialect="default",
+    as_sql=False,
+    naming_convention=None,
+    literal_binds=False,
+    native_boolean=None,
+):
 
     opts = {}
     if naming_convention:
         if not util.sqla_092:
             raise SkipTest(
-                "naming_convention feature requires "
-                "sqla 0.9.2 or greater")
-        opts['target_metadata'] = MetaData(naming_convention=naming_convention)
+                "naming_convention feature requires " "sqla 0.9.2 or greater"
+            )
+        opts["target_metadata"] = MetaData(naming_convention=naming_convention)
 
     class buffer_(object):
         def __init__(self):
@@ -106,12 +111,12 @@ def op_fixture(
 
         def write(self, msg):
             msg = msg.strip()
-            msg = re.sub(r'[\n\t]', '', msg)
+            msg = re.sub(r"[\n\t]", "", msg)
             if as_sql:
                 # the impl produces soft tabs,
                 # so search for blocks of 4 spaces
-                msg = re.sub(r'    ', '', msg)
-                msg = re.sub(r'\;\n*$', '', msg)
+                msg = re.sub(r"    ", "", msg)
+                msg = re.sub(r"\;\n*$", "", msg)
 
             self.lines.append(msg)
 
@@ -136,13 +141,13 @@ def op_fixture(
             else:
                 assert False, "Could not locate fragment %r in %r" % (
                     sql,
-                    buf.lines
+                    buf.lines,
                 )
 
     if as_sql:
-        opts['as_sql'] = as_sql
+        opts["as_sql"] = as_sql
     if literal_binds:
-        opts['literal_binds'] = literal_binds
+        opts["literal_binds"] = literal_binds
     ctx_dialect = _get_dialect(dialect)
     if native_boolean is not None:
         ctx_dialect.supports_native_boolean = native_boolean
@@ -150,6 +155,7 @@ def op_fixture(
         # which breaks assumptions in the alembic test suite
         ctx_dialect.non_native_boolean_check_constraint = True
     if not as_sql:
+
         def execute(stmt, *multiparam, **param):
             if isinstance(stmt, string_types):
                 stmt = text(stmt)
@@ -160,12 +166,9 @@ def op_fixture(
 
         connection = mock.Mock(dialect=ctx_dialect, execute=execute)
     else:
-        opts['output_buffer'] = buf
+        opts["output_buffer"] = buf
         connection = None
-    context = ctx(
-        ctx_dialect,
-        connection,
-        opts)
+    context = ctx(ctx_dialect, connection, opts)
 
     alembic.op._proxy = Operations(context)
     return context
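
A usage sketch for op_fixture, following the pattern in Alembic's own tests (table and column names are made up):

    from sqlalchemy import Column, Integer
    from alembic import op

    context = op_fixture("postgresql")
    op.add_column("t1", Column("c1", Integer))
    context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER")
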
index 08a756cbc27e1fab3cda7021d4cbb7b54f3f0187..1d5256d48f1121bdc49fa9a435c3915db479f96b 100644 (file)
@@ -22,4 +22,5 @@ else:
     except ImportError:
         raise ImportError(
             "SQLAlchemy's test suite requires the "
-            "'mock' library as of 0.8.2.")
+            "'mock' library as of 0.8.2."
+        )
index 9f42fd2191a662eb1c9c1a0ffe96f2d2bd86d012..4bd415d193ed7f82d8e45d55f5db022e8960a24d 100644 (file)
@@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0.
 import os
 import sys
 
-bootstrap_file = locals()['bootstrap_file']
-to_bootstrap = locals()['to_bootstrap']
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
 
 
 def load_file_as_module(name):
     path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
     if sys.version_info.major >= 3:
         from importlib import machinery
+
         mod = machinery.SourceFileLoader(name, path).load_module()
     else:
         import imp
+
         mod = imp.load_source(name, path)
     return mod
 
+
 if to_bootstrap == "pytest":
     sys.modules["alembic_plugin_base"] = load_file_as_module("plugin_base")
     sys.modules["alembic_pytestplugin"] = load_file_as_module("pytestplugin")
index f8894d66689d945bd87ce50d6091b8295cd92c30..fafb9e1ccfc85f031fe8ace58dbcb07fa39b50a0 100644 (file)
@@ -25,6 +25,7 @@ import os
 import sys
 
 from nose.plugins import Plugin
+
 fixtures = None
 
 py3k = sys.version_info.major >= 3
@@ -33,7 +34,7 @@ py3k = sys.version_info.major >= 3
 class NoseSQLAlchemy(Plugin):
     enabled = True
 
-    name = 'sqla_testing'
+    name = "sqla_testing"
     score = 100
 
     def options(self, parser, env=os.environ):
@@ -43,8 +44,10 @@ class NoseSQLAlchemy(Plugin):
         def make_option(name, **kw):
             callback_ = kw.pop("callback", None)
             if callback_:
+
                 def wrap_(option, opt_str, value, parser):
                     callback_(opt_str, value, parser)
+
                 kw["callback"] = wrap_
             opt(name, **kw)
 
@@ -71,7 +74,7 @@ class NoseSQLAlchemy(Plugin):
 
     def wantMethod(self, fn):
         if py3k:
-            if not hasattr(fn.__self__, 'cls'):
+            if not hasattr(fn.__self__, "cls"):
                 return False
             cls = fn.__self__.cls
         else:
@@ -85,19 +88,19 @@ class NoseSQLAlchemy(Plugin):
         plugin_base.before_test(
             test,
             test.test.cls.__module__,
-            test.test.cls, test.test.method.__name__)
+            test.test.cls,
+            test.test.method.__name__,
+        )
 
     def afterTest(self, test):
         plugin_base.after_test(test)
 
     def startContext(self, ctx):
-        if not isinstance(ctx, type) \
-                or not issubclass(ctx, fixtures.TestBase):
+        if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
             return
         plugin_base.start_test_class(ctx)
 
     def stopContext(self, ctx):
-        if not isinstance(ctx, type) \
-                or not issubclass(ctx, fixtures.TestBase):
+        if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
             return
         plugin_base.stop_test_class(ctx)
index 141e82f7a8d7ca203b99d7e43728e71d5efea8cb..9acffb55e547927da34c83ac3ddfe3e9948627b6 100644 (file)
@@ -17,12 +17,14 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0
 """
 
 from __future__ import absolute_import
+
 try:
     # unittest has a SkipTest also but pytest doesn't
     # honor it unless nose is imported too...
     from nose import SkipTest
 except ImportError:
     from pytest import skip
+
     SkipTest = skip.Exception
 
 import sys
@@ -55,54 +57,118 @@ options = None
 
 
 def setup_options(make_option):
-    make_option("--log-info", action="callback", type="string", callback=_log,
-                help="turn on info logging for <LOG> (multiple OK)")
-    make_option("--log-debug", action="callback",
-                type="string", callback=_log,
-                help="turn on debug logging for <LOG> (multiple OK)")
-    make_option("--db", action="append", type="string", dest="db",
-                help="Use prefab database uri. Multiple OK, "
-                "first one is run by default.")
-    make_option('--dbs', action='callback', zeroarg_callback=_list_dbs,
-                help="List available prefab dbs")
-    make_option("--dburi", action="append", type="string", dest="dburi",
-                help="Database uri.  Multiple OK, "
-                "first one is run by default.")
-    make_option("--dropfirst", action="store_true", dest="dropfirst",
-                help="Drop all tables in the target database first")
-    make_option("--backend-only", action="store_true", dest="backend_only",
-                help="Run only tests marked with __backend__")
-    make_option("--postgresql-templatedb", type="string",
-                help="name of template database to use for Postgresql "
-                     "CREATE DATABASE (defaults to current database)")
-    make_option("--low-connections", action="store_true",
-                dest="low_connections",
-                help="Use a low number of distinct connections - "
-                "i.e. for Oracle TNS")
-    make_option("--write-idents", type="string", dest="write_idents",
-                help="write out generated follower idents to <file>, "
-                "when -n<num> is used")
-    make_option("--reversetop", action="store_true",
-                dest="reversetop", default=False,
-                help="Use a random-ordering set implementation in the ORM "
-                "(helps reveal dependency issues)")
-    make_option("--requirements", action="callback", type="string",
-                callback=_requirements_opt,
-                help="requirements class for testing, overrides setup.cfg")
-    make_option("--with-cdecimal", action="store_true",
-                dest="cdecimal", default=False,
-                help="Monkeypatch the cdecimal library into Python 'decimal' "
-                "for all tests")
-    make_option("--include-tag", action="callback", callback=_include_tag,
-                type="string",
-                help="Include tests with tag <tag>")
-    make_option("--exclude-tag", action="callback", callback=_exclude_tag,
-                type="string",
-                help="Exclude tests with tag <tag>")
-    make_option("--mysql-engine", action="store",
-                dest="mysql_engine", default=None,
-                help="Use the specified MySQL storage engine for all tables, "
-                "default is a db-default/InnoDB combo.")
+    make_option(
+        "--log-info",
+        action="callback",
+        type="string",
+        callback=_log,
+        help="turn on info logging for <LOG> (multiple OK)",
+    )
+    make_option(
+        "--log-debug",
+        action="callback",
+        type="string",
+        callback=_log,
+        help="turn on debug logging for <LOG> (multiple OK)",
+    )
+    make_option(
+        "--db",
+        action="append",
+        type="string",
+        dest="db",
+        help="Use prefab database uri. Multiple OK, "
+        "first one is run by default.",
+    )
+    make_option(
+        "--dbs",
+        action="callback",
+        zeroarg_callback=_list_dbs,
+        help="List available prefab dbs",
+    )
+    make_option(
+        "--dburi",
+        action="append",
+        type="string",
+        dest="dburi",
+        help="Database uri.  Multiple OK, " "first one is run by default.",
+    )
+    make_option(
+        "--dropfirst",
+        action="store_true",
+        dest="dropfirst",
+        help="Drop all tables in the target database first",
+    )
+    make_option(
+        "--backend-only",
+        action="store_true",
+        dest="backend_only",
+        help="Run only tests marked with __backend__",
+    )
+    make_option(
+        "--postgresql-templatedb",
+        type="string",
+        help="name of template database to use for Postgresql "
+        "CREATE DATABASE (defaults to current database)",
+    )
+    make_option(
+        "--low-connections",
+        action="store_true",
+        dest="low_connections",
+        help="Use a low number of distinct connections - "
+        "i.e. for Oracle TNS",
+    )
+    make_option(
+        "--write-idents",
+        type="string",
+        dest="write_idents",
+        help="write out generated follower idents to <file>, "
+        "when -n<num> is used",
+    )
+    make_option(
+        "--reversetop",
+        action="store_true",
+        dest="reversetop",
+        default=False,
+        help="Use a random-ordering set implementation in the ORM "
+        "(helps reveal dependency issues)",
+    )
+    make_option(
+        "--requirements",
+        action="callback",
+        type="string",
+        callback=_requirements_opt,
+        help="requirements class for testing, overrides setup.cfg",
+    )
+    make_option(
+        "--with-cdecimal",
+        action="store_true",
+        dest="cdecimal",
+        default=False,
+        help="Monkeypatch the cdecimal library into Python 'decimal' "
+        "for all tests",
+    )
+    make_option(
+        "--include-tag",
+        action="callback",
+        callback=_include_tag,
+        type="string",
+        help="Include tests with tag <tag>",
+    )
+    make_option(
+        "--exclude-tag",
+        action="callback",
+        callback=_exclude_tag,
+        type="string",
+        help="Exclude tests with tag <tag>",
+    )
+    make_option(
+        "--mysql-engine",
+        action="store",
+        dest="mysql_engine",
+        default=None,
+        help="Use the specified MySQL storage engine for all tables, "
+        "default is a db-default/InnoDB combo.",
+    )
 
 
 def configure_follower(follower_ident):
@@ -113,6 +179,7 @@ def configure_follower(follower_ident):
 
     """
     from alembic.testing import provision
+
     provision.FOLLOWER_IDENT = follower_ident
 
 
@@ -126,9 +193,9 @@ def memoize_important_follower_config(dict_):
     callables, so we have to just copy all of that over.
 
     """
-    dict_['memoized_config'] = {
-        'include_tags': include_tags,
-        'exclude_tags': exclude_tags
+    dict_["memoized_config"] = {
+        "include_tags": include_tags,
+        "exclude_tags": exclude_tags,
     }
 
 
@@ -138,14 +205,14 @@ def restore_important_follower_config(dict_):
     This invokes in the follower process.
 
     """
-    include_tags.update(dict_['memoized_config']['include_tags'])
-    exclude_tags.update(dict_['memoized_config']['exclude_tags'])
+    include_tags.update(dict_["memoized_config"]["include_tags"])
+    exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
 
 
 def read_config():
     global file_config
     file_config = configparser.ConfigParser()
-    file_config.read(['setup.cfg', 'test.cfg'])
+    file_config.read(["setup.cfg", "test.cfg"])
 
 
 def pre_begin(opt):
@@ -169,12 +236,11 @@ def post_begin():
 
     # late imports, has to happen after config as well
     # as nose plugins like coverage
-    global util, fixtures, engines, exclusions, \
-        assertions, warnings, profiling,\
-        config, testing
+    global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing
     from alembic.testing import config, warnings, exclusions  # noqa
     from alembic.testing import engines, fixtures  # noqa
     from sqlalchemy import util  # noqa
+
     warnings.setup_filters()
 
 
@@ -182,18 +248,19 @@ def _log(opt_str, value, parser):
     global logging
     if not logging:
         import logging
+
         logging.basicConfig()
 
-    if opt_str.endswith('-info'):
+    if opt_str.endswith("-info"):
         logging.getLogger(value).setLevel(logging.INFO)
-    elif opt_str.endswith('-debug'):
+    elif opt_str.endswith("-debug"):
         logging.getLogger(value).setLevel(logging.DEBUG)
 
 
 def _list_dbs(*args):
     print("Available --db options (use --dburi to override)")
-    for macro in sorted(file_config.options('db')):
-        print("%20s\t%s" % (macro, file_config.get('db', macro)))
+    for macro in sorted(file_config.options("db")):
+        print("%20s\t%s" % (macro, file_config.get("db", macro)))
     sys.exit(0)
 
 
@@ -202,11 +269,12 @@ def _requirements_opt(opt_str, value, parser):
 
 
 def _exclude_tag(opt_str, value, parser):
-    exclude_tags.add(value.replace('-', '_'))
+    exclude_tags.add(value.replace("-", "_"))
 
 
 def _include_tag(opt_str, value, parser):
-    include_tags.add(value.replace('-', '_'))
+    include_tags.add(value.replace("-", "_"))
+
 
 pre_configure = []
 post_configure = []
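
The @pre/@post decorators used throughout this module append callbacks to the two lists above; their definitions fall outside this hunk, but the pattern is simply (a hedged reconstruction):

    def pre(fn):
        pre_configure.append(fn)
        return fn

    def post(fn):
        post_configure.append(fn)
        return fn

    # pre_begin()/post_begin() later invoke each registered hook in order
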
@@ -228,12 +296,12 @@ def _setup_options(opt, file_config):
     options = opt
 
 
-
 @pre
 def _monkeypatch_cdecimal(options, file_config):
     if options.cdecimal:
         import cdecimal
-        sys.modules['decimal'] = cdecimal
+
+        sys.modules["decimal"] = cdecimal
 
 
 @post
@@ -248,26 +316,27 @@ def _engine_uri(options, file_config):
 
     if options.db:
         for db_token in options.db:
-            for db in re.split(r'[,\s]+', db_token):
-                if db not in file_config.options('db'):
+            for db in re.split(r"[,\s]+", db_token):
+                if db not in file_config.options("db"):
                     raise RuntimeError(
                         "Unknown URI specifier '%s'.  "
-                        "Specify --dbs for known uris."
-                        % db)
+                        "Specify --dbs for known uris." % db
+                    )
                 else:
-                    db_urls.append(file_config.get('db', db))
+                    db_urls.append(file_config.get("db", db))
 
     if not db_urls:
-        db_urls.append(file_config.get('db', 'default'))
+        db_urls.append(file_config.get("db", "default"))
 
     for db_url in db_urls:
 
-        if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
+        if options.write_idents and provision.FOLLOWER_IDENT:  # != 'master':
             with open(options.write_idents, "a") as file_:
                 file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
 
         cfg = provision.setup_config(
-            db_url, options, file_config, provision.FOLLOWER_IDENT)
+            db_url, options, file_config, provision.FOLLOWER_IDENT
+        )
 
         if not config._current:
             cfg.set_as_current(cfg)
@@ -276,7 +345,7 @@ def _engine_uri(options, file_config):
 @post
 def _requirements(options, file_config):
 
-    requirement_cls = file_config.get('sqla_testing', "requirement_cls")
+    requirement_cls = file_config.get("sqla_testing", "requirement_cls")
     _setup_requirements(requirement_cls)
 
 
@@ -317,56 +386,75 @@ def _prep_testing_database(options, file_config):
                 pass
             else:
                 for vname in view_names:
-                    e.execute(schema._DropView(
-                        schema.Table(vname, schema.MetaData())
-                    ))
+                    e.execute(
+                        schema._DropView(
+                            schema.Table(vname, schema.MetaData())
+                        )
+                    )
 
             if config.requirements.schemas.enabled_for_config(cfg):
                 try:
-                    view_names = inspector.get_view_names(
-                        schema="test_schema")
+                    view_names = inspector.get_view_names(schema="test_schema")
                 except NotImplementedError:
                     pass
                 else:
                     for vname in view_names:
-                        e.execute(schema._DropView(
-                            schema.Table(vname, schema.MetaData(),
-                                         schema="test_schema")
-                        ))
-
-            for tname in reversed(inspector.get_table_names(
-                    order_by="foreign_key")):
-                e.execute(schema.DropTable(
-                    schema.Table(tname, schema.MetaData())
-                ))
+                        e.execute(
+                            schema._DropView(
+                                schema.Table(
+                                    vname,
+                                    schema.MetaData(),
+                                    schema="test_schema",
+                                )
+                            )
+                        )
+
+            for tname in reversed(
+                inspector.get_table_names(order_by="foreign_key")
+            ):
+                e.execute(
+                    schema.DropTable(schema.Table(tname, schema.MetaData()))
+                )
 
             if config.requirements.schemas.enabled_for_config(cfg):
-                for tname in reversed(inspector.get_table_names(
-                        order_by="foreign_key", schema="test_schema")):
-                    e.execute(schema.DropTable(
-                        schema.Table(tname, schema.MetaData(),
-                                     schema="test_schema")
-                    ))
+                for tname in reversed(
+                    inspector.get_table_names(
+                        order_by="foreign_key", schema="test_schema"
+                    )
+                ):
+                    e.execute(
+                        schema.DropTable(
+                            schema.Table(
+                                tname, schema.MetaData(), schema="test_schema"
+                            )
+                        )
+                    )
 
             if against(cfg, "postgresql") and util.sqla_100:
                 from sqlalchemy.dialects import postgresql
+
                 for enum in inspector.get_enums("*"):
-                    e.execute(postgresql.DropEnumType(
-                        postgresql.ENUM(
-                            name=enum['name'],
-                            schema=enum['schema'])))
+                    e.execute(
+                        postgresql.DropEnumType(
+                            postgresql.ENUM(
+                                name=enum["name"], schema=enum["schema"]
+                            )
+                        )
+                    )
 
 
 @post
 def _reverse_topological(options, file_config):
     if options.reversetop:
         from sqlalchemy.orm.util import randomize_unitofwork
+
         randomize_unitofwork()
 
 
 @post
 def _post_setup_options(opt, file_config):
     from alembic.testing import config
+
     config.options = options
     config.file_config = file_config
 
@@ -374,10 +462,11 @@ def _post_setup_options(opt, file_config):
 def want_class(cls):
     if not issubclass(cls, fixtures.TestBase):
         return False
-    elif cls.__name__.startswith('_'):
+    elif cls.__name__.startswith("_"):
         return False
-    elif config.options.backend_only and not getattr(cls, '__backend__',
-                                                     False):
+    elif config.options.backend_only and not getattr(
+        cls, "__backend__", False
+    ):
         return False
     else:
         return True
@@ -390,25 +479,28 @@ def want_method(cls, fn):
         return False
     elif include_tags:
         return (
-            hasattr(cls, '__tags__') and
-            exclusions.tags(cls.__tags__).include_test(
-                include_tags, exclude_tags)
+            hasattr(cls, "__tags__")
+            and exclusions.tags(cls.__tags__).include_test(
+                include_tags, exclude_tags
+            )
         ) or (
-            hasattr(fn, '_sa_exclusion_extend') and
-            fn._sa_exclusion_extend.include_test(
-                include_tags, exclude_tags)
+            hasattr(fn, "_sa_exclusion_extend")
+            and fn._sa_exclusion_extend.include_test(
+                include_tags, exclude_tags
+            )
         )
-    elif exclude_tags and hasattr(cls, '__tags__'):
+    elif exclude_tags and hasattr(cls, "__tags__"):
         return exclusions.tags(cls.__tags__).include_test(
-            include_tags, exclude_tags)
-    elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
+            include_tags, exclude_tags
+        )
+    elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
         return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
     else:
         return True
 
 
 def generate_sub_tests(cls, module):
-    if getattr(cls, '__backend__', False):
+    if getattr(cls, "__backend__", False):
         for cfg in _possible_configs_for_cls(cls):
             orig_name = cls.__name__
 
@@ -416,17 +508,14 @@ def generate_sub_tests(cls, module):
             # pytest junit plugin, which is tripped up by the brackets
             # and periods, so sanitize
 
-            alpha_name = re.sub(r'[_\[\]\.]+', '_', cfg.name)
-            alpha_name = re.sub('_+$', '', alpha_name)
+            alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
+            alpha_name = re.sub("_+$", "", alpha_name)
             name = "%s_%s" % (cls.__name__, alpha_name)
 
             subcls = type(
                 name,
-                (cls, ),
-                {
-                    "_sa_orig_cls_name": orig_name,
-                    "__only_on_config__": cfg
-                }
+                (cls,),
+                {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
             )
             setattr(module, name, subcls)
             yield subcls
@@ -440,8 +529,8 @@ def start_test_class(cls):
 
 
 def stop_test_class(cls):
-    #from sqlalchemy import inspect
-    #assert not inspect(testing.db).get_table_names()
+    # from sqlalchemy import inspect
+    # assert not inspect(testing.db).get_table_names()
     _restore_engine()
 
 
@@ -450,7 +539,7 @@ def _restore_engine():
 
 
 def _setup_engine(cls):
-    if getattr(cls, '__engine_options__', None):
+    if getattr(cls, "__engine_options__", None):
         eng = engines.testing_engine(options=cls.__engine_options__)
         config._current.push_engine(eng)
 
@@ -472,16 +561,16 @@ def _possible_configs_for_cls(cls, reasons=None):
             if spec(config_obj):
                 all_configs.remove(config_obj)
 
-    if getattr(cls, '__only_on__', None):
+    if getattr(cls, "__only_on__", None):
         spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
         for config_obj in list(all_configs):
             if not spec(config_obj):
                 all_configs.remove(config_obj)
 
-    if getattr(cls, '__only_on_config__', None):
+    if getattr(cls, "__only_on_config__", None):
         all_configs.intersection_update([cls.__only_on_config__])
 
-    if hasattr(cls, '__requires__'):
+    if hasattr(cls, "__requires__"):
         requirements = config.requirements
         for config_obj in list(all_configs):
             for requirement in cls.__requires__:
@@ -494,7 +583,7 @@ def _possible_configs_for_cls(cls, reasons=None):
                         reasons.extend(skip_reasons)
                     break
 
-    if hasattr(cls, '__prefer_requires__'):
+    if hasattr(cls, "__prefer_requires__"):
         non_preferred = set()
         requirements = config.requirements
         for config_obj in list(all_configs):
@@ -513,30 +602,32 @@ def _do_skips(cls):
     reasons = []
     all_configs = _possible_configs_for_cls(cls, reasons)
 
-    if getattr(cls, '__skip_if__', False):
-        for c in getattr(cls, '__skip_if__'):
+    if getattr(cls, "__skip_if__", False):
+        for c in getattr(cls, "__skip_if__"):
             if c():
-                raise SkipTest("'%s' skipped by %s" % (
-                    cls.__name__, c.__name__)
+                raise SkipTest(
+                    "'%s' skipped by %s" % (cls.__name__, c.__name__)
                 )
 
     if not all_configs:
         msg = "'%s' unsupported on any DB implementation %s%s" % (
             cls.__name__,
             ", ".join(
-                "'%s(%s)+%s'" % (
+                "'%s(%s)+%s'"
+                % (
                     config_obj.db.name,
                     ".".join(
-                        str(dig) for dig in
-                        config_obj.db.dialect.server_version_info),
-                    config_obj.db.driver
+                        str(dig)
+                        for dig in config_obj.db.dialect.server_version_info
+                    ),
+                    config_obj.db.driver,
                 )
-              for config_obj in config.Config.all_configs()
+                for config_obj in config.Config.all_configs()
             ),
-            ", ".join(reasons)
+            ", ".join(reasons),
         )
         raise SkipTest(msg)
-    elif hasattr(cls, '__prefer_backends__'):
+    elif hasattr(cls, "__prefer_backends__"):
         non_preferred = set()
         spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
         for config_obj in all_configs:
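A note on the generate_sub_tests() hunk above: each backend config yields a dynamically created test subclass whose name is sanitized so the pytest junit plugin can report it. A minimal, runnable sketch of that pattern, with a hypothetical FakeConfig standing in for the real config objects:

    import re


    class FakeConfig(object):
        def __init__(self, name):
            self.name = name


    class MigrationTest(object):
        __backend__ = True


    module_namespace = {}
    for cfg in [FakeConfig("sqlite[file].pysqlite"), FakeConfig("postgresql")]:
        # sanitize brackets/periods that trip up the pytest junit plugin
        alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
        alpha_name = re.sub("_+$", "", alpha_name)
        name = "%s_%s" % (MigrationTest.__name__, alpha_name)
        module_namespace[name] = type(
            name,
            (MigrationTest,),
            {
                "_sa_orig_cls_name": MigrationTest.__name__,
                "__only_on_config__": cfg,
            },
        )

    print(sorted(module_namespace))
    # ['MigrationTest_postgresql', 'MigrationTest_sqlite_file_pysqlite']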
index 4d0f340d5ec635cf069e224b87516d29e8e18613..cc5b69ff6bf30b1e5934b8fb7a49e89a207edb3b 100644 (file)
--- a/alembic/testing/plugin/pytestplugin.py
+++ b/alembic/testing/plugin/pytestplugin.py
@@ -21,6 +21,7 @@ import os
 
 try:
     import xdist  # noqa
+
     has_xdist = True
 except ImportError:
     has_xdist = False
@@ -32,30 +33,42 @@ def pytest_addoption(parser):
     def make_option(name, **kw):
         callback_ = kw.pop("callback", None)
         if callback_:
+
             class CallableAction(argparse.Action):
-                def __call__(self, parser, namespace,
-                             values, option_string=None):
+                def __call__(
+                    self, parser, namespace, values, option_string=None
+                ):
                     callback_(option_string, values, parser)
+
             kw["action"] = CallableAction
 
         zeroarg_callback = kw.pop("zeroarg_callback", None)
         if zeroarg_callback:
+
             class CallableAction(argparse.Action):
-                def __init__(self, option_strings,
-                             dest, default=False,
-                             required=False, help=None):
-                        super(CallableAction, self).__init__(
-                            option_strings=option_strings,
-                            dest=dest,
-                            nargs=0,
-                            const=True,
-                            default=default,
-                            required=required,
-                            help=help)
-
-                def __call__(self, parser, namespace,
-                             values, option_string=None):
+                def __init__(
+                    self,
+                    option_strings,
+                    dest,
+                    default=False,
+                    required=False,
+                    help=None,
+                ):
+                    super(CallableAction, self).__init__(
+                        option_strings=option_strings,
+                        dest=dest,
+                        nargs=0,
+                        const=True,
+                        default=default,
+                        required=required,
+                        help=help,
+                    )
+
+                def __call__(
+                    self, parser, namespace, values, option_string=None
+                ):
                     zeroarg_callback(option_string, values, parser)
+
             kw["action"] = CallableAction
 
         group.addoption(name, **kw)
@@ -67,23 +80,24 @@ def pytest_addoption(parser):
 def pytest_configure(config):
     if hasattr(config, "slaveinput"):
         plugin_base.restore_important_follower_config(config.slaveinput)
-        plugin_base.configure_follower(
-            config.slaveinput["follower_ident"]
-        )
+        plugin_base.configure_follower(config.slaveinput["follower_ident"])
     else:
-        if config.option.write_idents and \
-                os.path.exists(config.option.write_idents):
+        if config.option.write_idents and os.path.exists(
+            config.option.write_idents
+        ):
             os.remove(config.option.write_idents)
 
     plugin_base.pre_begin(config.option)
 
-    plugin_base.set_coverage_flag(bool(getattr(config.option,
-                                               "cov_source", False)))
+    plugin_base.set_coverage_flag(
+        bool(getattr(config.option, "cov_source", False))
+    )
 
 
 def pytest_sessionstart(session):
     plugin_base.post_begin()
 
+
 if has_xdist:
     import uuid
 
@@ -95,10 +109,12 @@ if has_xdist:
 
         node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
         from alembic.testing import provision
+
         provision.create_follower_db(node.slaveinput["follower_ident"])
 
     def pytest_testnodedown(node, error):
         from alembic.testing import provision
+
         provision.drop_follower_db(node.slaveinput["follower_ident"])
 
 
@@ -115,18 +131,19 @@ def pytest_collection_modifyitems(session, config, items):
 
     rebuilt_items = collections.defaultdict(list)
     items[:] = [
-        item for item in
-        items if isinstance(item.parent, pytest.Instance)]
+        item for item in items if isinstance(item.parent, pytest.Instance)
+    ]
     test_classes = set(item.parent for item in items)
     for test_class in test_classes:
         for sub_cls in plugin_base.generate_sub_tests(
-                test_class.cls, test_class.parent.module):
+            test_class.cls, test_class.parent.module
+        ):
             if sub_cls is not test_class.cls:
                 list_ = rebuilt_items[test_class.cls]
 
                 for inst in pytest.Class(
-                        sub_cls.__name__,
-                        parent=test_class.parent.parent).collect():
+                    sub_cls.__name__, parent=test_class.parent.parent
+                ).collect():
                     list_.extend(inst.collect())
 
     newitems = []
@@ -139,23 +156,29 @@ def pytest_collection_modifyitems(session, config, items):
 
     # seems like the functions attached to a test class aren't sorted already?
     # is that true and why's that? (when using unittest, they're sorted)
-    items[:] = sorted(newitems, key=lambda item: (
-        item.parent.parent.parent.name,
-        item.parent.parent.name,
-        item.name
-    ))
+    items[:] = sorted(
+        newitems,
+        key=lambda item: (
+            item.parent.parent.parent.name,
+            item.parent.parent.name,
+            item.name,
+        ),
+    )
 
 
 def pytest_pycollect_makeitem(collector, name, obj):
     if inspect.isclass(obj) and plugin_base.want_class(obj):
         return pytest.Class(name, parent=collector)
-    elif inspect.isfunction(obj) and \
-            isinstance(collector, pytest.Instance) and \
-            plugin_base.want_method(collector.cls, obj):
+    elif (
+        inspect.isfunction(obj)
+        and isinstance(collector, pytest.Instance)
+        and plugin_base.want_method(collector.cls, obj)
+    ):
         return pytest.Function(name, parent=collector)
     else:
         return []
 
+
 _current_class = None
 
 
@@ -180,6 +203,7 @@ def pytest_runtest_setup(item):
             global _current_class
             class_teardown(item.parent.parent)
             _current_class = None
+
         item.parent.parent.addfinalizer(finalize)
 
     test_setup(item)
@@ -194,8 +218,9 @@ def pytest_runtest_teardown(item):
 
 
 def test_setup(item):
-    plugin_base.before_test(item, item.parent.module.__name__,
-                            item.parent.cls, item.name)
+    plugin_base.before_test(
+        item, item.parent.module.__name__, item.parent.cls, item.name
+    )
 
 
 def test_teardown(item):
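The CallableAction classes reformatted above adapt nose-style option callbacks to argparse, which has no callback parameter of its own. A minimal sketch of the same adaptation, using a hypothetical --db option and callback:

    import argparse


    def _set_db(option_string, value, parser):
        print("would configure database: %s" % value)


    def make_option(parser, name, **kw):
        callback_ = kw.pop("callback", None)
        if callback_:

            # a one-off Action subclass that forwards to the callback
            class CallableAction(argparse.Action):
                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    callback_(option_string, values, parser)

            kw["action"] = CallableAction
        parser.add_argument(name, **kw)


    parser = argparse.ArgumentParser()
    make_option(parser, "--db", callback=_set_db)
    parser.parse_args(["--db", "sqlite"])
    # would configure database: sqlite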
index 05a21d371718fed7b11c3f0291064bf47c84e776..a5ce53c1b380fc6001e031dac4e2cb0b6d4b0ba2 100644 (file)
--- a/alembic/testing/provision.py
+++ b/alembic/testing/provision.py
@@ -30,6 +30,7 @@ class register(object):
         def decorate(fn):
             self.fns[dbname] = fn
             return self
+
         return decorate
 
     def __call__(self, cfg, *arg):
@@ -43,7 +44,7 @@ class register(object):
         if backend in self.fns:
             return self.fns[backend](cfg, *arg)
         else:
-            return self.fns['*'](cfg, *arg)
+            return self.fns["*"](cfg, *arg)
 
 
 def create_follower_db(follower_ident):
@@ -86,9 +87,7 @@ def _configs_for_db_operation():
     for cfg in config.Config.all_configs():
         url = cfg.db.url
         backend = get_url_backend_name(url)
-        host_conf = (
-            backend,
-            url.username, url.host, url.database)
+        host_conf = (backend, url.username, url.host, url.database)
 
         if host_conf not in hosts:
             yield cfg
@@ -132,13 +131,13 @@ def _follower_url_from_main(url, ident):
 
 @_update_db_opts.for_db("mssql")
 def _mssql_update_db_opts(db_url, db_opts):
-    db_opts['legacy_schema_aliasing'] = False
+    db_opts["legacy_schema_aliasing"] = False
 
 
 @_follower_url_from_main.for_db("sqlite")
 def _sqlite_follower_url_from_main(url, ident):
     url = sa_url.make_url(url)
-    if not url.database or url.database == ':memory:':
+    if not url.database or url.database == ":memory:":
         return url
     else:
         return sa_url.make_url("sqlite:///%s.db" % ident)
@@ -154,19 +153,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident):
         # as an attached
         if not follower_ident:
             dbapi_connection.execute(
-                'ATTACH DATABASE "test_schema.db" AS test_schema')
+                'ATTACH DATABASE "test_schema.db" AS test_schema'
+            )
         else:
             dbapi_connection.execute(
                 'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
-                % follower_ident)
+                % follower_ident
+            )
 
 
 @_create_db.for_db("postgresql")
 def _pg_create_db(cfg, eng, ident):
     template_db = cfg.options.postgresql_templatedb
 
-    with eng.connect().execution_options(
-            isolation_level="AUTOCOMMIT") as conn:
+    with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
         try:
             _pg_drop_db(cfg, conn, ident)
         except Exception:
@@ -222,14 +222,15 @@ def _sqlite_create_db(cfg, eng, ident):
 
 @_drop_db.for_db("postgresql")
 def _pg_drop_db(cfg, eng, ident):
-    with eng.connect().execution_options(
-            isolation_level="AUTOCOMMIT") as conn:
+    with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
         conn.execute(
             text(
                 "select pg_terminate_backend(pid) from pg_stat_activity "
                 "where usename=current_user and pid != pg_backend_pid() "
                 "and datname=:dname"
-            ), dname=ident)
+            ),
+            dname=ident,
+        )
         conn.execute("DROP DATABASE %s" % ident)
 
 
@@ -258,7 +259,7 @@ def _oracle_create_db(cfg, eng, ident):
         conn.execute("create user %s identified by xe" % ident)
         conn.execute("create user %s_ts1 identified by xe" % ident)
         conn.execute("create user %s_ts2 identified by xe" % ident)
-        conn.execute("grant dba to %s" % (ident, ))
+        conn.execute("grant dba to %s" % (ident,))
         conn.execute("grant unlimited tablespace to %s" % ident)
         conn.execute("grant unlimited tablespace to %s_ts1" % ident)
         conn.execute("grant unlimited tablespace to %s_ts2" % ident)
@@ -316,8 +317,9 @@ def reap_oracle_dbs(idents_file):
             to_reap = conn.execute(
                 "select u.username from all_users u where username "
                 "like 'TEST_%' and not exists (select username "
-                "from v$session where username=u.username)")
-            all_names = set(username.lower() for (username, ) in to_reap)
+                "from v$session where username=u.username)"
+            )
+            all_names = set(username.lower() for (username,) in to_reap)
             to_drop = set()
             for name in all_names:
                 if name.endswith("_ts1") or name.endswith("_ts2"):
@@ -334,15 +336,13 @@ def reap_oracle_dbs(idents_file):
                 if _ora_drop_ignore(conn, username):
                     dropped += 1
             log.info(
-                "Dropped %d out of %d stale databases detected",
-                dropped, total)
+                "Dropped %d out of %d stale databases detected", dropped, total
+            )
 
 
 @_follower_url_from_main.for_db("oracle")
 def _oracle_follower_url_from_main(url, ident):
     url = sa_url.make_url(url)
     url.username = ident
-    url.password = 'xe'
+    url.password = "xe"
     return url
-
-
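The register() class at the top of this file's diff is the dispatcher behind all of the @_create_db.for_db(...) and @_drop_db.for_db(...) hooks that follow. A condensed usage sketch; the real __call__ derives the backend name from the config's database URL, simplified here to a plain string:

    class register(object):
        def __init__(self):
            self.fns = {}

        def for_db(self, dbname):
            def decorate(fn):
                self.fns[dbname] = fn
                # as in the original, the decorator returns the dispatcher
                return self

            return decorate

        def __call__(self, backend, *arg):
            if backend in self.fns:
                return self.fns[backend](*arg)
            else:
                return self.fns["*"](*arg)


    _create_db = register()


    @_create_db.for_db("*")
    def _default_create_db(ident):
        return "no create-db support for %s" % ident


    @_create_db.for_db("postgresql")
    def _pg_create_db(ident):
        return "CREATE DATABASE %s" % ident


    print(_create_db("postgresql", "test_xyz"))  # CREATE DATABASE test_xyz
    print(_create_db("sqlite", "test_xyz"))  # no create-db support for ...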
index 400642f66ff66e90b0b26e4a79615165b8a71ac8..f25f5d706d1eed54f5252cc195fc90f25b1a331b 100644 (file)
--- a/alembic/testing/requirements.py
+++ b/alembic/testing/requirements.py
@@ -5,6 +5,7 @@ from . import exclusions
 if util.sqla_094:
     from sqlalchemy.testing.requirements import Requirements
 else:
+
     class Requirements(object):
         pass
 
@@ -28,7 +29,7 @@ class SuiteRequirements(Requirements):
 
             insp = inspect(config.db)
             try:
-                insp.get_unique_constraints('x')
+                insp.get_unique_constraints("x")
             except NotImplementedError:
                 return True
             except TypeError:
@@ -62,83 +63,80 @@ class SuiteRequirements(Requirements):
     def fail_before_sqla_100(self):
         return exclusions.fails_if(
             lambda config: not util.sqla_100,
-            "SQLAlchemy 1.0.0 or greater required"
+            "SQLAlchemy 1.0.0 or greater required",
         )
 
     @property
     def fail_before_sqla_1010(self):
         return exclusions.fails_if(
             lambda config: not util.sqla_1010,
-            "SQLAlchemy 1.0.10 or greater required"
+            "SQLAlchemy 1.0.10 or greater required",
         )
 
     @property
     def fail_before_sqla_099(self):
         return exclusions.fails_if(
             lambda config: not util.sqla_099,
-            "SQLAlchemy 0.9.9 or greater required"
+            "SQLAlchemy 0.9.9 or greater required",
         )
 
     @property
     def fail_before_sqla_110(self):
         return exclusions.fails_if(
             lambda config: not util.sqla_110,
-            "SQLAlchemy 1.1.0 or greater required"
+            "SQLAlchemy 1.1.0 or greater required",
         )
 
     @property
     def sqlalchemy_092(self):
         return exclusions.skip_if(
             lambda config: not util.sqla_092,
-            "SQLAlchemy 0.9.2 or greater required"
+            "SQLAlchemy 0.9.2 or greater required",
         )
 
     @property
     def sqlalchemy_094(self):
         return exclusions.skip_if(
             lambda config: not util.sqla_094,
-            "SQLAlchemy 0.9.4 or greater required"
+            "SQLAlchemy 0.9.4 or greater required",
         )
 
     @property
     def sqlalchemy_099(self):
         return exclusions.skip_if(
             lambda config: not util.sqla_099,
-            "SQLAlchemy 0.9.9 or greater required"
+            "SQLAlchemy 0.9.9 or greater required",
         )
 
     @property
     def sqlalchemy_100(self):
         return exclusions.skip_if(
             lambda config: not util.sqla_100,
-            "SQLAlchemy 1.0.0 or greater required"
+            "SQLAlchemy 1.0.0 or greater required",
         )
 
     @property
     def sqlalchemy_1014(self):
         return exclusions.skip_if(
             lambda config: not util.sqla_1014,
-            "SQLAlchemy 1.0.14 or greater required"
+            "SQLAlchemy 1.0.14 or greater required",
         )
 
     @property
     def sqlalchemy_1115(self):
         return exclusions.skip_if(
             lambda config: not util.sqla_1115,
-            "SQLAlchemy 1.1.15 or greater required"
+            "SQLAlchemy 1.1.15 or greater required",
         )
 
     @property
     def sqlalchemy_110(self):
         return exclusions.skip_if(
             lambda config: not util.sqla_110,
-            "SQLAlchemy 1.1.0 or greater required"
+            "SQLAlchemy 1.1.0 or greater required",
         )
 
     @property
     def pep3147(self):
 
-        return exclusions.only_if(
-            lambda config: util.compat.has_pep3147()
-        )
-
+        return exclusions.only_if(lambda config: util.compat.has_pep3147())
index d4adbcf85593c5198de0e465eebba08108599f06..46236a04bb637daaaa50a6ed3dc072958b3ef843 100644 (file)
--- a/alembic/testing/runner.py
+++ b/alembic/testing/runner.py
@@ -45,4 +45,4 @@ def setup_py_test():
     to nose.
 
     """
-    nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner'])
+    nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"])
index 466dea301077bfc9b540c3bab8e4dd0a568d4dbe..b2b3476431713de886dd9c02ad5d55a2d031eab3 100644 (file)
--- a/alembic/testing/util.py
+++ b/alembic/testing/util.py
@@ -10,7 +10,7 @@ def provide_metadata(fn, *args, **kw):
 
     metadata = schema.MetaData(config.db)
     self = args[0]
-    prev_meta = getattr(self, 'metadata', None)
+    prev_meta = getattr(self, "metadata", None)
     self.metadata = metadata
     try:
         return fn(*args, **kw)
index de91778593f626c9548a2b872e862310cbfc5290..cb59a64ab42a2803a8e7ad60acb6be5d15e4b8da 100644 (file)
--- a/alembic/testing/warnings.py
+++ b/alembic/testing/warnings.py
@@ -17,11 +17,12 @@ import re
 
 def setup_filters():
     """Set global warning behavior for the test suite."""
-    warnings.filterwarnings('ignore',
-                            category=sa_exc.SAPendingDeprecationWarning)
-    warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
-    warnings.filterwarnings('error', category=sa_exc.SAWarning)
-    warnings.filterwarnings('error', category=DeprecationWarning)
+    warnings.filterwarnings(
+        "ignore", category=sa_exc.SAPendingDeprecationWarning
+    )
+    warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+    warnings.filterwarnings("error", category=sa_exc.SAWarning)
+    warnings.filterwarnings("error", category=DeprecationWarning)
 
 
 def assert_warnings(fn, warning_msgs, regex=False):
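setup_filters() above escalates SQLAlchemy deprecation warnings and DeprecationWarning to test errors. The mechanism is just the stdlib warning filters; a quick illustration:

    import warnings

    warnings.filterwarnings("error", category=DeprecationWarning)
    try:
        warnings.warn("old api", DeprecationWarning)
    except DeprecationWarning as err:
        print("raised: %s" % err)  # raised: old api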
index 1e6c64561b9fc037cb29de6711926e7538e44822..e28f71528062b463caa8946b8471c23a37e6ff74 100644 (file)
--- a/alembic/util/__init__.py
+++ b/alembic/util/__init__.py
@@ -1,17 +1,45 @@
 from .langhelpers import (  # noqa
-    asbool, rev_id, to_tuple, to_list, memoized_property, dedupe_tuple,
-    immutabledict, _with_legacy_names, Dispatcher, ModuleClsProxy)
+    asbool,
+    rev_id,
+    to_tuple,
+    to_list,
+    memoized_property,
+    dedupe_tuple,
+    immutabledict,
+    _with_legacy_names,
+    Dispatcher,
+    ModuleClsProxy,
+)
 from .messaging import (  # noqa
-    write_outstream, status, err, obfuscate_url_pw, warn, msg, format_as_comma)
+    write_outstream,
+    status,
+    err,
+    obfuscate_url_pw,
+    warn,
+    msg,
+    format_as_comma,
+)
 from .pyfiles import (  # noqa
-    template_to_file, coerce_resource_to_filename,
-    pyc_file_from_path, load_python_file, edit)
+    template_to_file,
+    coerce_resource_to_filename,
+    pyc_file_from_path,
+    load_python_file,
+    edit,
+)
 from .sqla_compat import (  # noqa
-    sqla_09, sqla_092, sqla_094, sqla_099, sqla_100, sqla_105, sqla_110, sqla_1010,
-    sqla_1014, sqla_1115)
+    sqla_09,
+    sqla_092,
+    sqla_094,
+    sqla_099,
+    sqla_100,
+    sqla_105,
+    sqla_110,
+    sqla_1010,
+    sqla_1014,
+    sqla_1115,
+)
 from .exc import CommandError
 
 
 if not sqla_09:
-    raise CommandError(
-        "SQLAlchemy 0.9.0 or greater is required. ")
+    raise CommandError("SQLAlchemy 0.9.0 or greater is required. ")
index dec2ca8fecbaeb1f2015e9b7c55aecd8e38ac25a..7e07ed4795b51131a44bacfee1c46fe93f642831 100644 (file)
--- a/alembic/util/compat.py
+++ b/alembic/util/compat.py
@@ -19,12 +19,13 @@ else:
 
 if py3k:
     import builtins as compat_builtins
-    string_types = str,
+
+    string_types = (str,)
     binary_type = bytes
     text_type = str
 
     def callable(fn):
-        return hasattr(fn, '__call__')
+        return hasattr(fn, "__call__")
 
     def u(s):
         return s
@@ -35,7 +36,8 @@ if py3k:
     range = range
 else:
     import __builtin__ as compat_builtins
-    string_types = basestring,
+
+    string_types = (basestring,)
     binary_type = str
     text_type = unicode
     callable = callable
@@ -55,16 +57,17 @@ else:
 
 if py3k:
     import collections
+
     ArgSpec = collections.namedtuple(
-        "ArgSpec",
-        ["args", "varargs", "keywords", "defaults"])
+        "ArgSpec", ["args", "varargs", "keywords", "defaults"]
+    )
 
     from inspect import getfullargspec as inspect_getfullargspec
 
     def inspect_getargspec(func):
-        return ArgSpec(
-            *inspect_getfullargspec(func)[0:4]
-        )
+        return ArgSpec(*inspect_getfullargspec(func)[0:4])
+
+
 else:
     from inspect import getargspec as inspect_getargspec  # noqa
 
@@ -72,14 +75,20 @@ if py35:
     from inspect import formatannotation
 
     def inspect_formatargspec(
-            args, varargs=None, varkw=None, defaults=None,
-            kwonlyargs=(), kwonlydefaults={}, annotations={},
-            formatarg=str,
-            formatvarargs=lambda name: '*' + name,
-            formatvarkw=lambda name: '**' + name,
-            formatvalue=lambda value: '=' + repr(value),
-            formatreturns=lambda text: ' -> ' + text,
-            formatannotation=formatannotation):
+        args,
+        varargs=None,
+        varkw=None,
+        defaults=None,
+        kwonlyargs=(),
+        kwonlydefaults={},
+        annotations={},
+        formatarg=str,
+        formatvarargs=lambda name: "*" + name,
+        formatvarkw=lambda name: "**" + name,
+        formatvalue=lambda value: "=" + repr(value),
+        formatreturns=lambda text: " -> " + text,
+        formatannotation=formatannotation,
+    ):
         """Copy formatargspec from python 3.7 standard library.
 
         Python 3 has deprecated formatargspec and requested that Signature
@@ -93,8 +102,9 @@ if py35:
         def formatargandannotation(arg):
             result = formatarg(arg)
             if arg in annotations:
-                result += ': ' + formatannotation(annotations[arg])
+                result += ": " + formatannotation(annotations[arg])
             return result
+
         specs = []
         if defaults:
             firstdefault = len(args) - len(defaults)
@@ -107,7 +117,7 @@ if py35:
             specs.append(formatvarargs(formatargandannotation(varargs)))
         else:
             if kwonlyargs:
-                specs.append('*')
+                specs.append("*")
         if kwonlyargs:
             for kwonlyarg in kwonlyargs:
                 spec = formatargandannotation(kwonlyarg)
@@ -116,11 +126,12 @@ if py35:
                 specs.append(spec)
         if varkw is not None:
             specs.append(formatvarkw(formatargandannotation(varkw)))
-        result = '(' + ', '.join(specs) + ')'
-        if 'return' in annotations:
-            result += formatreturns(formatannotation(annotations['return']))
+        result = "(" + ", ".join(specs) + ")"
+        if "return" in annotations:
+            result += formatreturns(formatannotation(annotations["return"]))
         return result
 
+
 else:
     from inspect import formatargspec as inspect_formatargspec
 
@@ -151,22 +162,27 @@ if py35:
         spec.loader.exec_module(module)
         return module
 
+
 elif py3k:
     import importlib.machinery
 
     def load_module_py(module_id, path):
         module = importlib.machinery.SourceFileLoader(
-            module_id, path).load_module(module_id)
+            module_id, path
+        ).load_module(module_id)
         del sys.modules[module_id]
         return module
 
     def load_module_pyc(module_id, path):
         module = importlib.machinery.SourcelessFileLoader(
-            module_id, path).load_module(module_id)
+            module_id, path
+        ).load_module(module_id)
         del sys.modules[module_id]
         return module
 
+
 if py3k:
+
     def get_bytecode_suffixes():
         try:
             return importlib.machinery.BYTECODE_SUFFIXES
@@ -188,13 +204,15 @@ if py3k:
         # http://www.python.org/dev/peps/pep-3147/#detecting-pep-3147-availability
 
         import imp
-        return hasattr(imp, 'get_tag')
+
+        return hasattr(imp, "get_tag")
+
 
 else:
     import imp
 
     def load_module_py(module_id, path):  # noqa
-        with open(path, 'rb') as fp:
+        with open(path, "rb") as fp:
             mod = imp.load_source(module_id, path, fp)
             if py2k:
                 source_encoding = parse_encoding(fp)
@@ -204,7 +222,7 @@ else:
             return mod
 
     def load_module_pyc(module_id, path):  # noqa
-        with open(path, 'rb') as fp:
+        with open(path, "rb") as fp:
             mod = imp.load_compiled(module_id, path, fp)
             # no source encoding here
             del sys.modules[module_id]
@@ -219,12 +237,14 @@ else:
     def has_pep3147():
         return False
 
+
 try:
-    exec_ = getattr(compat_builtins, 'exec')
+    exec_ = getattr(compat_builtins, "exec")
 except AttributeError:
     # Python 2
     def exec_(func_text, globals_, lcl):
-        exec('exec func_text in globals_, lcl')
+        exec("exec func_text in globals_, lcl")
+
 
 ################################################
 # cross-compatible metaclass implementation
@@ -234,9 +254,12 @@ except AttributeError:
 def with_metaclass(meta, base=object):
     """Create a base class with a metaclass."""
     return meta("%sBase" % meta.__name__, (base,), {})
+
+
 ################################################
 
 if py3k:
+
     def reraise(tp, value, tb=None, cause=None):
         if cause is not None:
             value.__cause__ = cause
@@ -249,9 +272,13 @@ if py3k:
             exc_info = sys.exc_info()
         exc_type, exc_value, exc_tb = exc_info
         reraise(type(exception), exception, tb=exc_tb, cause=exc_value)
+
+
 else:
-    exec("def reraise(tp, value, tb=None, cause=None):\n"
-         "    raise tp, value, tb\n")
+    exec(
+        "def reraise(tp, value, tb=None, cause=None):\n"
+        "    raise tp, value, tb\n"
+    )
 
     def raise_from_cause(exception, exc_info=None):
         # not as nice as that of Py3K, but at least preserves
@@ -261,14 +288,15 @@ else:
         exc_type, exc_value, exc_tb = exc_info
         reraise(type(exception), exception, tb=exc_tb)
 
+
 # produce a wrapper that allows encoded text to stream
 # into a given buffer, but doesn't close it.
 # not sure of a more idiomatic approach to this.
 class EncodedIO(io.TextIOWrapper):
-
     def close(self):
         pass
 
+
 if py2k:
     # in Py2K, the io.* package is awkward because it does not
     # easily wrap the file type (e.g. sys.stdout) and I can't
@@ -303,7 +331,7 @@ if py2k:
             return self.file_.flush()
 
     class EncodedIO(EncodedIO):
-
         def __init__(self, file_, encoding):
             super(EncodedIO, self).__init__(
-                ActLikePy3kIO(file_), encoding=encoding)
+                ActLikePy3kIO(file_), encoding=encoding
+            )
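The ArgSpec shim near the top of this file's diff keeps the legacy four-field argspec available on Python 3, which the proxy generation in langhelpers relies on. A runnable check of just that shim (sample() is an arbitrary function for illustration):

    import collections
    from inspect import getfullargspec as inspect_getfullargspec

    ArgSpec = collections.namedtuple(
        "ArgSpec", ["args", "varargs", "keywords", "defaults"]
    )


    def inspect_getargspec(func):
        # truncate the py3 FullArgSpec to the four legacy fields
        return ArgSpec(*inspect_getfullargspec(func)[0:4])


    def sample(x, y=3, *args, **kw):
        pass


    print(inspect_getargspec(sample))
    # ArgSpec(args=['x', 'y'], varargs='args', keywords='kw', defaults=(3,))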
index 832332cbe76647aeeb01b09a55d922b30ca6fe4f..a298cc07e8ba82ffc935e3bfa09ef307e14b1978 100644 (file)
--- a/alembic/util/langhelpers.py
+++ b/alembic/util/langhelpers.py
@@ -37,23 +37,21 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
     def _install_proxy(self):
         attr_names, modules = self._setups[self.__class__]
         for globals_, locals_ in modules:
-            globals_['_proxy'] = self
+            globals_["_proxy"] = self
             for attr_name in attr_names:
                 globals_[attr_name] = getattr(self, attr_name)
 
     def _remove_proxy(self):
         attr_names, modules = self._setups[self.__class__]
         for globals_, locals_ in modules:
-            globals_['_proxy'] = None
+            globals_["_proxy"] = None
             for attr_name in attr_names:
                 del globals_[attr_name]
 
     @classmethod
     def create_module_class_proxy(cls, globals_, locals_):
         attr_names, modules = cls._setups[cls]
-        modules.append(
-            (globals_, locals_)
-        )
+        modules.append((globals_, locals_))
         cls._setup_proxy(globals_, locals_, attr_names)
 
     @classmethod
@@ -63,11 +61,12 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
 
     @classmethod
     def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names):
-        if not methname.startswith('_'):
+        if not methname.startswith("_"):
             meth = getattr(cls, methname)
             if callable(meth):
                 locals_[methname] = cls._create_method_proxy(
-                    methname, globals_, locals_)
+                    methname, globals_, locals_
+                )
             else:
                 attr_names.add(methname)
 
@@ -75,7 +74,7 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
     def _create_method_proxy(cls, name, globals_, locals_):
         fn = getattr(cls, name)
         spec = inspect_getargspec(fn)
-        if spec[0] and spec[0][0] == 'self':
+        if spec[0] and spec[0][0] == "self":
             spec[0].pop(0)
         args = inspect_formatargspec(*spec)
         num_defaults = 0
@@ -83,24 +82,28 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
             num_defaults += len(spec[3])
         name_args = spec[0]
         if num_defaults:
-            defaulted_vals = name_args[0 - num_defaults:]
+            defaulted_vals = name_args[0 - num_defaults :]
         else:
             defaulted_vals = ()
 
         apply_kw = inspect_formatargspec(
-            name_args, spec[1], spec[2],
+            name_args,
+            spec[1],
+            spec[2],
             defaulted_vals,
-            formatvalue=lambda x: '=' + x)
+            formatvalue=lambda x: "=" + x,
+        )
 
         def _name_error(name):
             raise NameError(
                 "Can't invoke function '%s', as the proxy object has "
                 "not yet been "
                 "established for the Alembic '%s' class.  "
-                "Try placing this code inside a callable." % (
-                    name, cls.__name__
-                ))
-        globals_['_name_error'] = _name_error
+                "Try placing this code inside a callable."
+                % (name, cls.__name__)
+            )
+
+        globals_["_name_error"] = _name_error
 
         translations = getattr(fn, "_legacy_translations", [])
         if translations:
@@ -108,7 +111,7 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
             translate_str = "args, kw = _translate(%r, %r, %r, args, kw)" % (
                 fn.__name__,
                 tuple(spec),
-                translations
+                translations,
             )
 
             def translate(fn_name, spec, translations, args, kw):
@@ -119,15 +122,14 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
                     if oldname in kw:
                         warnings.warn(
                             "Argument %r is now named %r "
-                            "for method %s()." % (
-                                oldname, newname, fn_name
-                            ))
+                            "for method %s()." % (oldname, newname, fn_name)
+                        )
                         return_kw[newname] = kw.pop(oldname)
                 return_kw.update(kw)
 
                 args = list(args)
                 if spec[3]:
-                    pos_only = spec[0][:-len(spec[3])]
+                    pos_only = spec[0][: -len(spec[3])]
                 else:
                     pos_only = spec[0]
                 for arg in pos_only:
@@ -137,17 +139,20 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
                         except IndexError:
                             raise TypeError(
                                 "missing required positional argument: %s"
-                                % arg)
+                                % arg
+                            )
                 return_args.extend(args)
 
                 return return_args, return_kw
-            globals_['_translate'] = translate
+
+            globals_["_translate"] = translate
         else:
             outer_args = args[1:-1]
             inner_args = apply_kw[1:-1]
             translate_str = ""
 
-        func_text = textwrap.dedent("""\
+        func_text = textwrap.dedent(
+            """\
         def %(name)s(%(args)s):
             %(doc)r
             %(translate)s
@@ -157,13 +162,15 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
                 _name_error('%(name)s')
             return _proxy.%(name)s(%(apply_kw)s)
             e
-        """ % {
-            'name': name,
-            'translate': translate_str,
-            'args': outer_args,
-            'apply_kw': inner_args,
-            'doc': fn.__doc__,
-        })
+        """
+            % {
+                "name": name,
+                "translate": translate_str,
+                "args": outer_args,
+                "apply_kw": inner_args,
+                "doc": fn.__doc__,
+            }
+        )
         lcl = {}
         exec_(func_text, globals_, lcl)
         return lcl[name]
@@ -178,8 +185,7 @@ def _with_legacy_names(translations):
 
 
 def asbool(value):
-    return value is not None and \
-        value.lower() == 'true'
+    return value is not None and value.lower() == "true"
 
 
 def rev_id():
@@ -201,31 +207,30 @@ def to_tuple(x, default=None):
     if x is None:
         return default
     elif isinstance(x, string_types):
-        return (x, )
+        return (x,)
     elif isinstance(x, collections_abc.Iterable):
         return tuple(x)
     else:
-        return (x, )
+        return (x,)
 
 
 def unique_list(seq, hashfunc=None):
     seen = set()
     seen_add = seen.add
     if not hashfunc:
-        return [x for x in seq
-                if x not in seen
-                and not seen_add(x)]
+        return [x for x in seq if x not in seen and not seen_add(x)]
     else:
-        return [x for x in seq
-                if hashfunc(x) not in seen
-                and not seen_add(hashfunc(x))]
+        return [
+            x
+            for x in seq
+            if hashfunc(x) not in seen and not seen_add(hashfunc(x))
+        ]
 
 
 def dedupe_tuple(tup):
     return tuple(unique_list(tup))
 
 
-
 class memoized_property(object):
 
     """A read-only @property that is only evaluated once."""
@@ -243,13 +248,12 @@ class memoized_property(object):
 
 
 class immutabledict(dict):
-
     def _immutable(self, *arg, **kw):
         raise TypeError("%s object is immutable" % self.__class__.__name__)
 
-    __delitem__ = __setitem__ = __setattr__ = \
-        clear = pop = popitem = setdefault = \
-        update = _immutable
+    __delitem__ = (
+        __setitem__
+    ) = __setattr__ = clear = pop = popitem = setdefault = update = _immutable
 
     def __new__(cls, *args):
         new = dict.__new__(cls)
@@ -260,7 +264,7 @@ class immutabledict(dict):
         pass
 
     def __reduce__(self):
-        return immutabledict, (dict(self), )
+        return immutabledict, (dict(self),)
 
     def union(self, d):
         if not self:
@@ -279,7 +283,7 @@ class Dispatcher(object):
         self._registry = {}
         self.uselist = uselist
 
-    def dispatch_for(self, target, qualifier='default'):
+    def dispatch_for(self, target, qualifier="default"):
         def decorate(fn):
             if self.uselist:
                 self._registry.setdefault((target, qualifier), []).append(fn)
@@ -287,9 +291,10 @@ class Dispatcher(object):
                 assert (target, qualifier) not in self._registry
                 self._registry[(target, qualifier)] = fn
             return fn
+
         return decorate
 
-    def dispatch(self, obj, qualifier='default'):
+    def dispatch(self, obj, qualifier="default"):
 
         if isinstance(obj, string_types):
             targets = [obj]
@@ -299,20 +304,20 @@ class Dispatcher(object):
             targets = type(obj).__mro__
 
         for spcls in targets:
-            if qualifier != 'default' and (spcls, qualifier) in self._registry:
-                return self._fn_or_list(
-                    self._registry[(spcls, qualifier)])
-            elif (spcls, 'default') in self._registry:
-                return self._fn_or_list(
-                    self._registry[(spcls, 'default')])
+            if qualifier != "default" and (spcls, qualifier) in self._registry:
+                return self._fn_or_list(self._registry[(spcls, qualifier)])
+            elif (spcls, "default") in self._registry:
+                return self._fn_or_list(self._registry[(spcls, "default")])
         else:
             raise ValueError("no dispatch function for object: %s" % obj)
 
     def _fn_or_list(self, fn_or_list):
         if self.uselist:
+
             def go(*arg, **kw):
                 for fn in fn_or_list:
                     fn(*arg, **kw)
+
             return go
         else:
             return fn_or_list
@@ -324,8 +329,7 @@ class Dispatcher(object):
         d = Dispatcher()
         if self.uselist:
             d._registry.update(
-                (k, [fn for fn in self._registry[k]])
-                for k in self._registry
+                (k, [fn for fn in self._registry[k]]) for k in self._registry
             )
         else:
             d._registry.update(self._registry)
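The Dispatcher class above routes objects to handler functions by walking the type's MRO, which is how autogenerate render functions are looked up per operation type. A condensed, single-handler sketch; uselist and the non-default qualifier fallback are omitted, and AddColumnOp is a stand-in defined locally for illustration:

    class Dispatcher(object):
        def __init__(self):
            self._registry = {}

        def dispatch_for(self, target, qualifier="default"):
            def decorate(fn):
                self._registry[(target, qualifier)] = fn
                return fn

            return decorate

        def dispatch(self, obj, qualifier="default"):
            # walk the MRO so base-class handlers serve subclasses too
            for spcls in type(obj).__mro__:
                if (spcls, qualifier) in self._registry:
                    return self._registry[(spcls, qualifier)]
            raise ValueError("no dispatch function for object: %s" % obj)


    class AddColumnOp(object):
        pass


    renderers = Dispatcher()


    @renderers.dispatch_for(AddColumnOp)
    def _render_add_column(op):
        return "op.add_column(...)  # for %s" % type(op).__name__


    op = AddColumnOp()
    print(renderers.dispatch(op)(op))
    # op.add_column(...)  # for AddColumnOp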
index 872345b52f6aec432758934927a02084a0a81b3e..44eacbfdf08aa24b24742be5889ee0328846fdc3 100644 (file)
--- a/alembic/util/messaging.py
+++ b/alembic/util/messaging.py
@@ -11,16 +11,16 @@ log = logging.getLogger(__name__)
 
 if py27:
     # disable "no handler found" errors
-    logging.getLogger('alembic').addHandler(logging.NullHandler())
+    logging.getLogger("alembic").addHandler(logging.NullHandler())
 
 
 try:
     import fcntl
     import termios
     import struct
-    ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ,
-                        struct.pack('HHHH', 0, 0, 0, 0))
-    _h, TERMWIDTH, _hp, _wp = struct.unpack('HHHH', ioctl)
+
+    ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ, struct.pack("HHHH", 0, 0, 0, 0))
+    _h, TERMWIDTH, _hp, _wp = struct.unpack("HHHH", ioctl)
     if TERMWIDTH <= 0:  # can occur if running in emacs pseudo-tty
         TERMWIDTH = None
 except (ImportError, IOError):
@@ -28,10 +28,10 @@ except (ImportError, IOError):
 
 
 def write_outstream(stream, *text):
-    encoding = getattr(stream, 'encoding', 'ascii') or 'ascii'
+    encoding = getattr(stream, "encoding", "ascii") or "ascii"
     for t in text:
         if not isinstance(t, binary_type):
-            t = t.encode(encoding, 'replace')
+            t = t.encode(encoding, "replace")
         t = t.decode(encoding)
         try:
             stream.write(t)
@@ -62,7 +62,7 @@ def err(message):
 def obfuscate_url_pw(u):
     u = url.make_url(u)
     if u.password:
-        u.password = 'XXXXX'
+        u.password = "XXXXX"
     return str(u)
 
 
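write_outstream() above pushes text through the stream's encoding with "replace" so that un-encodable characters degrade instead of raising. A condensed version of the function showing the round-trip; the "ascii" fallback matches the original:

    import io


    def write_outstream(stream, *text):
        # "replace" degrades rather than raises on characters the
        # stream encoding cannot represent
        encoding = getattr(stream, "encoding", "ascii") or "ascii"
        for t in text:
            if not isinstance(t, bytes):
                t = t.encode(encoding, "replace")
            t = t.decode(encoding)
            stream.write(t)


    buf = io.StringIO()  # reports no usable encoding, so "ascii" is used
    write_outstream(buf, "naïve")
    print(buf.getvalue())  # na?ve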
index 0e5213356e374ad8685389c04ee014b79c34f2a0..4093b89ff8248f11242c385cf06623978ac14d8b 100644 (file)
--- a/alembic/util/pyfiles.py
+++ b/alembic/util/pyfiles.py
@@ -1,8 +1,12 @@
 import sys
 import os
 import re
-from .compat import load_module_py, load_module_pyc, \
-    get_current_bytecode_suffixes, has_pep3147
+from .compat import (
+    load_module_py,
+    load_module_pyc,
+    get_current_bytecode_suffixes,
+    has_pep3147,
+)
 from mako.template import Template
 from mako import exceptions
 import tempfile
@@ -14,16 +18,19 @@ def template_to_file(template_file, dest, output_encoding, **kw):
     try:
         output = template.render_unicode(**kw).encode(output_encoding)
     except:
-        with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as ntf:
+        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as ntf:
             ntf.write(
-                exceptions.text_error_template().
-                render_unicode().encode(output_encoding))
+                exceptions.text_error_template()
+                .render_unicode()
+                .encode(output_encoding)
+            )
             fname = ntf.name
         raise CommandError(
             "Template rendering failed; see %s for a "
-            "template-oriented traceback." % fname)
+            "template-oriented traceback." % fname
+        )
     else:
-        with open(dest, 'wb') as f:
+        with open(dest, "wb") as f:
             f.write(output)
 
 
@@ -37,7 +44,8 @@ def coerce_resource_to_filename(fname):
     """
     if not os.path.isabs(fname) and ":" in fname:
         import pkg_resources
-        fname = pkg_resources.resource_filename(*fname.split(':'))
+
+        fname = pkg_resources.resource_filename(*fname.split(":"))
     return fname
 
 
@@ -48,6 +56,7 @@ def pyc_file_from_path(path):
 
     if has_pep3147():
         import imp
+
         candidate = imp.cache_from_source(path)
         if os.path.exists(candidate):
             return candidate
@@ -64,16 +73,17 @@ def edit(path):
     """Given a source path, run the EDITOR for it"""
 
     import editor
+
     try:
         editor.edit(path)
     except Exception as exc:
-        raise CommandError('Error executing editor (%s)' % (exc,))
+        raise CommandError("Error executing editor (%s)" % (exc,))
 
 
 def load_python_file(dir_, filename):
     """Load a file from the given path as a Python module."""
 
-    module_id = re.sub(r'\W', "_", filename)
+    module_id = re.sub(r"\W", "_", filename)
     path = os.path.join(dir_, filename)
     _, ext = os.path.splitext(filename)
     if ext == ".py":
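One detail worth calling out in load_python_file() above: migration filenames are not valid module names, so module_id replaces every non-word character before the file is loaded under that id. For example, with an illustrative filename:

    import re

    print(re.sub(r"\W", "_", "2019_01_06_add_users.py"))
    # 2019_01_06_add_users_py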
index 05561244930bce34c28c7c4223b7206c75446f29..63c979865b1bfaaa8134cd6611eb85a55e4c8dc8 100644 (file)
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -15,8 +15,11 @@ def _safe_int(value):
         return int(value)
     except:
         return value
+
+
 _vers = tuple(
-    [_safe_int(x) for x in re.findall(r'(\d+|[abc]\d)', __version__)])
+    [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
+)
 sqla_09 = _vers >= (0, 9, 0)
 sqla_092 = _vers >= (0, 9, 2)
 sqla_094 = _vers >= (0, 9, 4)
@@ -31,7 +34,7 @@ sqla_1115 = _vers >= (1, 1, 15)
 
 
 if sqla_110:
-    AUTOINCREMENT_DEFAULT = 'auto'
+    AUTOINCREMENT_DEFAULT = "auto"
 else:
     AUTOINCREMENT_DEFAULT = True
 
@@ -55,10 +58,12 @@ def _columns_for_constraint(constraint):
 def _fk_spec(constraint):
     if sqla_100:
         source_columns = [
-            constraint.columns[key].name for key in constraint.column_keys]
+            constraint.columns[key].name for key in constraint.column_keys
+        ]
     else:
         source_columns = [
-            element.parent.name for element in constraint.elements]
+            element.parent.name for element in constraint.elements
+        ]
 
     source_table = constraint.parent.name
     source_schema = constraint.parent.schema
@@ -70,9 +75,17 @@ def _fk_spec(constraint):
     deferrable = constraint.deferrable
     initially = constraint.initially
     return (
-        source_schema, source_table,
-        source_columns, target_schema, target_table, target_columns,
-        onupdate, ondelete, deferrable, initially)
+        source_schema,
+        source_table,
+        source_columns,
+        target_schema,
+        target_table,
+        target_columns,
+        onupdate,
+        ondelete,
+        deferrable,
+        initially,
+    )
 
 
 def _fk_is_self_referential(constraint):
@@ -91,11 +104,9 @@ def _is_type_bound(constraint):
         return constraint._type_bound
     else:
         # old way, look at what we know Boolean/Enum to use
-        return (
-            constraint._create_rule is not None and
-            isinstance(
-                getattr(constraint._create_rule, "target", None),
-                sqltypes.SchemaType)
+        return constraint._create_rule is not None and isinstance(
+            getattr(constraint._create_rule, "target", None),
+            sqltypes.SchemaType,
         )
 
 
@@ -103,7 +114,7 @@ def _find_columns(clause):
     """locate Column objects within the given expression."""
 
     cols = set()
-    traverse(clause, {}, {'column': cols.add})
+    traverse(clause, {}, {"column": cols.add})
     return cols
 
 
@@ -143,7 +154,8 @@ class _textual_index_element(sql.ColumnElement):
     See SQLAlchemy issue 3174.
 
     """
-    __visit_name__ = '_textual_idx_element'
+
+    __visit_name__ = "_textual_idx_element"
 
     def __init__(self, table, text):
         self.table = table
@@ -198,7 +210,7 @@ def _get_index_final_name(dialect, idx):
 
 
 def _is_mariadb(mysql_dialect):
-    return 'MariaDB' in mysql_dialect.server_version_info
+    return "MariaDB" in mysql_dialect.server_version_info
 
 
 def _mariadb_normalized_version_info(mysql_dialect):
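The _vers tuple built above feeds every sqla_* flag in this file: _safe_int() keeps pre-release segments such as "b2" as strings, and plain tuple comparison does the rest. A quick check of the same parsing, wrapped in a helper for illustration:

    import re


    def _safe_int(value):
        # the original uses a bare except; ValueError is the case that matters
        try:
            return int(value)
        except ValueError:
            return value


    def vers(version):
        return tuple(
            [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", version)]
        )


    print(vers("1.0.14"))  # (1, 0, 14)
    print(vers("1.1.0b2"))  # (1, 1, 0, 'b2')
    print(vers("1.1.0b2") >= (1, 1, 0))  # True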
index 945d94a709b15f103707fef5a1b2c831a54915da..50d02c10855d27c4cf32bd29847da623de31d34a 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -15,6 +15,21 @@ identity = C4DAFEE1
 with-sqla_testing = true
 where = tests
 
+[flake8]
+show-source = true
+enable-extensions = G
+# E203 is due to https://github.com/PyCQA/pycodestyle/issues/373
+ignore =
+    A003,
+    D,
+    E203,E305,E711,E712,E721,E722,E741,
+    N801,N802,N806,
+    RST304,RST303,RST299,RST399,
+    W503,W504
+exclude = .venv,.git,.tox,dist,doc,*egg,build
+import-order-style = google
+application-import-names = alembic,tests
+
 
 [sqla_testing]
 requirement_cls=tests.requirements:DefaultRequirements
index e962f896a3eff8bbd7ceaab4f4c7acd164141a5b..70da388e1a1241774c2525b4553203a593c3b74d 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -5,23 +5,23 @@ import re
 import sys
 
 
-v = open(os.path.join(os.path.dirname(__file__), 'alembic', '__init__.py'))
-VERSION = re.compile(r".*__version__ = '(.*?)'", re.S).match(v.read()).group(1)
+v = open(os.path.join(os.path.dirname(__file__), "alembic", "__init__.py"))
+VERSION = (
+    re.compile(r""".*__version__ = ["'](.*?)["']""", re.S)
+    .match(v.read())
+    .group(1)
+)
 v.close()
 
 
-readme = os.path.join(os.path.dirname(__file__), 'README.rst')
+readme = os.path.join(os.path.dirname(__file__), "README.rst")
 
 requires = [
-    'SQLAlchemy>=0.9.0',
-    'Mako',
-    'python-editor>=0.3',
-    'python-dateutil'
+    "SQLAlchemy>=0.9.0",
+    "Mako",
+    "python-editor>=0.3",
+    "python-dateutil",
 ]
 
 
 class PyTest(TestCommand):
-    user_options = [('pytest-args=', 'a', "Arguments to pass to py.test")]
+    user_options = [("pytest-args=", "a", "Arguments to pass to py.test")]
 
     def initialize_options(self):
         TestCommand.initialize_options(self)
@@ -35,42 +35,42 @@ class PyTest(TestCommand):
     def run_tests(self):
         # import here, cause outside the eggs aren't loaded
         import pytest
+
         errno = pytest.main(self.pytest_args)
         sys.exit(errno)
 
 
-setup(name='alembic',
-      version=VERSION,
-      description="A database migration tool for SQLAlchemy.",
-      long_description=open(readme).read(),
-      python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*',
-      classifiers=[
-          'Development Status :: 5 - Production/Stable',
-          'Environment :: Console',
-          'Intended Audience :: Developers',
-          'Programming Language :: Python',
-          'Programming Language :: Python :: 2',
-          'Programming Language :: Python :: 2.7',
-          'Programming Language :: Python :: 3',
-          'Programming Language :: Python :: 3.4',
-          'Programming Language :: Python :: 3.5',
-          'Programming Language :: Python :: 3.6',
-          'Programming Language :: Python :: Implementation :: CPython',
-          'Programming Language :: Python :: Implementation :: PyPy',
-          'Topic :: Database :: Front-Ends',
-      ],
-      keywords='SQLAlchemy migrations',
-      author='Mike Bayer',
-      author_email='mike@zzzcomputing.com',
-      url='https://alembic.sqlalchemy.org',
-      license='MIT',
-      packages=find_packages('.', exclude=['examples*', 'test*']),
-      include_package_data=True,
-      tests_require=['pytest!=3.9.1,!=3.9.2', 'mock', 'Mako'],
-      cmdclass={'test': PyTest},
-      zip_safe=False,
-      install_requires=requires,
-      entry_points={
-          'console_scripts': ['alembic = alembic.config:main'],
-      }
-      )
+setup(
+    name="alembic",
+    version=VERSION,
+    description="A database migration tool for SQLAlchemy.",
+    long_description=open(readme).read(),
+    python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
+    classifiers=[
+        "Development Status :: 5 - Production/Stable",
+        "Environment :: Console",
+        "Intended Audience :: Developers",
+        "Programming Language :: Python",
+        "Programming Language :: Python :: 2",
+        "Programming Language :: Python :: 2.7",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.4",
+        "Programming Language :: Python :: 3.5",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: Implementation :: CPython",
+        "Programming Language :: Python :: Implementation :: PyPy",
+        "Topic :: Database :: Front-Ends",
+    ],
+    keywords="SQLAlchemy migrations",
+    author="Mike Bayer",
+    author_email="mike@zzzcomputing.com",
+    url="https://alembic.sqlalchemy.org",
+    license="MIT",
+    packages=find_packages(".", exclude=["examples*", "test*"]),
+    include_package_data=True,
+    tests_require=["pytest!=3.9.1,!=3.9.2", "mock", "Mako"],
+    cmdclass={"test": PyTest},
+    zip_safe=False,
+    install_requires=requires,
+    entry_points={"console_scripts": ["alembic = alembic.config:main"]},
+)
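The VERSION regex change in this setup.py diff is a knock-on effect of the black run: the package __init__ now uses double-quoted strings, so the pattern accepts either quote style. A quick check:

    import re

    pattern = re.compile(r""".*__version__ = ["'](.*?)["']""", re.S)
    print(pattern.match('__version__ = "1.0.6"\n').group(1))  # 1.0.6
    print(pattern.match("__version__ = '1.0.6'\n").group(1))  # 1.0.6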
index 94c6866a6d009aa4072ac3576cd69c894e8f8316..4bda756091e3233f2588858f9d467757c3b671be 100644 (file)
--- a/tests/_autogen_fixtures.py
+++ b/tests/_autogen_fixtures.py
@@ -1,5 +1,18 @@
-from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
-    Numeric, CHAR, ForeignKey, Index, UniqueConstraint, CheckConstraint, text
+from sqlalchemy import (
+    MetaData,
+    Column,
+    Table,
+    Integer,
+    String,
+    Text,
+    Numeric,
+    CHAR,
+    ForeignKey,
+    Index,
+    UniqueConstraint,
+    CheckConstraint,
+    text,
+)
 from sqlalchemy.engine.reflection import Inspector
 
 from alembic.operations import ops
@@ -27,11 +40,12 @@ def _default_include_object(obj, name, type_, reflected, compare_to):
     else:
         return True
 
+
 _default_object_filters = _default_include_object
 
 
 class ModelOne(object):
-    __requires__ = ('unique_constraint_reflection', )
+    __requires__ = ("unique_constraint_reflection",)
 
     schema = None
 
@@ -41,30 +55,42 @@ class ModelOne(object):
 
         m = MetaData(schema=schema)
 
-        Table('user', m,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50)),
-              Column('a1', Text),
-              Column("pw", String(50)),
-              Index('pw_idx', 'pw')
-              )
-
-        Table('address', m,
-              Column('id', Integer, primary_key=True),
-              Column('email_address', String(100), nullable=False),
-              )
-
-        Table('order', m,
-              Column('order_id', Integer, primary_key=True),
-              Column("amount", Numeric(8, 2), nullable=False,
-                     server_default=text("0")),
-              CheckConstraint('amount >= 0', name='ck_order_amount')
-              )
-
-        Table('extra', m,
-              Column("x", CHAR),
-              Column('uid', Integer, ForeignKey('user.id'))
-              )
+        Table(
+            "user",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+            Column("a1", Text),
+            Column("pw", String(50)),
+            Index("pw_idx", "pw"),
+        )
+
+        Table(
+            "address",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("email_address", String(100), nullable=False),
+        )
+
+        Table(
+            "order",
+            m,
+            Column("order_id", Integer, primary_key=True),
+            Column(
+                "amount",
+                Numeric(8, 2),
+                nullable=False,
+                server_default=text("0"),
+            ),
+            CheckConstraint("amount >= 0", name="ck_order_amount"),
+        )
+
+        Table(
+            "extra",
+            m,
+            Column("x", CHAR),
+            Column("uid", Integer, ForeignKey("user.id")),
+        )
 
         return m
 
@@ -74,50 +100,80 @@ class ModelOne(object):
 
         m = MetaData(schema=schema)
 
-        Table('user', m,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', Text, server_default="x")
-              )
-
-        Table('address', m,
-              Column('id', Integer, primary_key=True),
-              Column('email_address', String(100), nullable=False),
-              Column('street', String(50)),
-              UniqueConstraint('email_address', name="uq_email")
-              )
-
-        Table('order', m,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True,
-                     server_default=text("0")),
-              Column('user_id', Integer, ForeignKey('user.id')),
-              CheckConstraint('amount > -1', name='ck_order_amount'),
-              )
-
-        Table('item', m,
-              Column('id', Integer, primary_key=True),
-              Column('description', String(100)),
-              Column('order_id', Integer, ForeignKey('order.order_id')),
-              CheckConstraint('len(description) > 5')
-              )
+        Table(
+            "user",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", Text, server_default="x"),
+        )
+
+        Table(
+            "address",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("email_address", String(100), nullable=False),
+            Column("street", String(50)),
+            UniqueConstraint("email_address", name="uq_email"),
+        )
+
+        Table(
+            "order",
+            m,
+            Column("order_id", Integer, primary_key=True),
+            Column(
+                "amount",
+                Numeric(10, 2),
+                nullable=True,
+                server_default=text("0"),
+            ),
+            Column("user_id", Integer, ForeignKey("user.id")),
+            CheckConstraint("amount > -1", name="ck_order_amount"),
+        )
+
+        Table(
+            "item",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("description", String(100)),
+            Column("order_id", Integer, ForeignKey("order.order_id")),
+            CheckConstraint("len(description) > 5"),
+        )
         return m
 
 
 class _ComparesFKs(object):
     def _assert_fk_diff(
-            self, diff, type_, source_table, source_columns,
-            target_table, target_columns, name=None, conditional_name=None,
-            source_schema=None, onupdate=None, ondelete=None,
-            initially=None, deferrable=None):
+        self,
+        diff,
+        type_,
+        source_table,
+        source_columns,
+        target_table,
+        target_columns,
+        name=None,
+        conditional_name=None,
+        source_schema=None,
+        onupdate=None,
+        ondelete=None,
+        initially=None,
+        deferrable=None,
+    ):
         # the public API for ForeignKeyConstraint was not very rich
         # in 0.7, 0.8, so here we use the well-known but slightly
         # private API to get at its elements
-        (fk_source_schema, fk_source_table,
-         fk_source_columns, fk_target_schema, fk_target_table,
-         fk_target_columns,
-         fk_onupdate, fk_ondelete, fk_deferrable, fk_initially
-         ) = _fk_spec(diff[1])
+        (
+            fk_source_schema,
+            fk_source_table,
+            fk_source_columns,
+            fk_target_schema,
+            fk_target_table,
+            fk_target_columns,
+            fk_onupdate,
+            fk_ondelete,
+            fk_deferrable,
+            fk_initially,
+        ) = _fk_spec(diff[1])
 
         eq_(diff[0], type_)
         eq_(fk_source_table, source_table)
@@ -129,15 +185,15 @@ class _ComparesFKs(object):
         eq_(fk_initially, initially)
         eq_(fk_deferrable, deferrable)
 
-        eq_([elem.column.name for elem in diff[1].elements],
-            target_columns)
+        eq_([elem.column.name for elem in diff[1].elements], target_columns)
         if conditional_name is not None:
             if config.requirements.no_fk_names.enabled:
                 eq_(diff[1].name, None)
-            elif conditional_name == 'servergenerated':
-                fks = Inspector.from_engine(self.bind).\
-                    get_foreign_keys(source_table)
-                server_fk_name = fks[0]['name']
+            elif conditional_name == "servergenerated":
+                fks = Inspector.from_engine(self.bind).get_foreign_keys(
+                    source_table
+                )
+                server_fk_name = fks[0]["name"]
                 eq_(diff[1].name, server_fk_name)
             else:
                 eq_(diff[1].name, conditional_name)
@@ -146,7 +202,6 @@ class _ComparesFKs(object):
 
 
 class AutogenTest(_ComparesFKs):
-
     def _flatten_diffs(self, diffs):
         for d in diffs:
             if isinstance(d, list):
@@ -177,20 +232,19 @@ class AutogenTest(_ComparesFKs):
     def setUp(self):
         self.conn = conn = self.bind.connect()
         ctx_opts = {
-            'compare_type': True,
-            'compare_server_default': True,
-            'target_metadata': self.m2,
-            'upgrade_token': "upgrades",
-            'downgrade_token': "downgrades",
-            'alembic_module_prefix': 'op.',
-            'sqlalchemy_module_prefix': 'sa.',
-            'include_object': _default_object_filters
+            "compare_type": True,
+            "compare_server_default": True,
+            "target_metadata": self.m2,
+            "upgrade_token": "upgrades",
+            "downgrade_token": "downgrades",
+            "alembic_module_prefix": "op.",
+            "sqlalchemy_module_prefix": "sa.",
+            "include_object": _default_object_filters,
         }
         if self.configure_opts:
             ctx_opts.update(self.configure_opts)
         self.context = context = MigrationContext.configure(
-            connection=conn,
-            opts=ctx_opts
+            connection=conn, opts=ctx_opts
         )
 
         self.autogen_context = api.AutogenContext(context, self.m2)
@@ -200,46 +254,47 @@ class AutogenTest(_ComparesFKs):
 
     def _update_context(self, object_filters=None, include_schemas=None):
         if include_schemas is not None:
-            self.autogen_context.opts['include_schemas'] = include_schemas
+            self.autogen_context.opts["include_schemas"] = include_schemas
         if object_filters is not None:
             self.autogen_context._object_filters = [object_filters]
         return self.autogen_context
 
 
 class AutogenFixtureTest(_ComparesFKs):
-
     def _fixture(
-            self, m1, m2, include_schemas=False,
-            opts=None, object_filters=_default_object_filters,
-            return_ops=False):
+        self,
+        m1,
+        m2,
+        include_schemas=False,
+        opts=None,
+        object_filters=_default_object_filters,
+        return_ops=False,
+    ):
         self.metadata, model_metadata = m1, m2
         for m in util.to_list(self.metadata):
             m.create_all(self.bind)
 
         with self.bind.connect() as conn:
             ctx_opts = {
-                'compare_type': True,
-                'compare_server_default': True,
-                'target_metadata': model_metadata,
-                'upgrade_token': "upgrades",
-                'downgrade_token': "downgrades",
-                'alembic_module_prefix': 'op.',
-                'sqlalchemy_module_prefix': 'sa.',
-                'include_object': object_filters,
-                'include_schemas': include_schemas
+                "compare_type": True,
+                "compare_server_default": True,
+                "target_metadata": model_metadata,
+                "upgrade_token": "upgrades",
+                "downgrade_token": "downgrades",
+                "alembic_module_prefix": "op.",
+                "sqlalchemy_module_prefix": "sa.",
+                "include_object": object_filters,
+                "include_schemas": include_schemas,
             }
             if opts:
                 ctx_opts.update(opts)
             self.context = context = MigrationContext.configure(
-                connection=conn,
-                opts=ctx_opts
+                connection=conn, opts=ctx_opts
             )
 
             autogen_context = api.AutogenContext(context, model_metadata)
             uo = ops.UpgradeOps(ops=[])
-            autogenerate._produce_net_changes(
-                autogen_context, uo
-            )
+            autogenerate._produce_net_changes(autogen_context, uo)
 
             if return_ops:
                 return uo
@@ -253,8 +308,7 @@ class AutogenFixtureTest(_ComparesFKs):
         self.bind = config.db
 
     def tearDown(self):
-        if hasattr(self, 'metadata'):
+        if hasattr(self, "metadata"):
             for m in util.to_list(self.metadata):
                 m.drop_all(self.bind)
         clear_staging_env()
-
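
The fixtures above configure a MigrationContext by hand and feed it to AutogenContext; the public entry point wrapping that same flow is compare_metadata(), whose options mirror the ctx_opts dictionaries here. A self-contained sketch against in-memory SQLite (the "user" table and "name" column are illustrative, not taken from the suite):

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
    from alembic.migration import MigrationContext
    from alembic.autogenerate import compare_metadata

    engine = create_engine("sqlite://")

    # database state: "user" with only an id column
    db_m = MetaData()
    Table("user", db_m, Column("id", Integer, primary_key=True))
    db_m.create_all(engine)

    # model state: the same table plus a "name" column
    model_m = MetaData()
    Table(
        "user",
        model_m,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    with engine.connect() as conn:
        ctx = MigrationContext.configure(
            connection=conn,
            opts={"compare_type": True, "compare_server_default": True},
        )
        for diff in compare_metadata(ctx, model_m):
            print(diff)  # e.g. ('add_column', None, 'user', Column('name', ...))
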
index bc7133fdf1a515fad18ed1b2c34b3d4d3e46d32e..13ac41fade64202aa61b9514acaa68b60fae5f6d 100644 (file)
@@ -1,152 +1,148 @@
 from alembic.script.revision import RevisionMap, Revision
 
 data = [
-    Revision('3fc8a578bc0a', ('4878cb1cb7f6', '454a0529f84e'), ),
-    Revision('69285b0faaa', ('36c31e4e1c37', '3a3b24a31b57'), ),
-    Revision('3b0452c64639', '2f1a0f3667f3', ),
-    Revision('2d9d787a496', '135b5fd31062', ),
-    Revision('184f65ed83af', '3b0452c64639', ),
-    Revision('430074f99c29', '54f871bfe0b0', ),
-    Revision('3ffb59981d9a', '519c9f3ce294', ),
-    Revision('454a0529f84e', ('40f6508e4373', '38a936c6ab11'), ),
-    Revision('24c2620b2e3f', ('430074f99c29', '1f5ceb1ec255'), ),
-    Revision('169a948471a9', '247ad6880f93', ),
-    Revision('2f1a0f3667f3', '17dd0f165262', ),
-    Revision('27227dc4fda8', '2a66d7c4d8a1', ),
-    Revision('4b2ad1ffe2e7', ('3b409f268da4', '4f8a9b79a063'), ),
-    Revision('124ef6a17781', '2529684536da', ),
-    Revision('4789d9c82ca7', '593b8076fb2c', ),
-    Revision('64ed798bcc3', ('44ed1bf512a0', '169a948471a9'), ),
-    Revision('2588a3c36a0f', '50c7b21c9089', ),
-    Revision('359329c2ebb', ('5810e9eff996', '339faa12616'), ),
-    Revision('540bc5634bd', '3a5db5f31209', ),
-    Revision('20fe477817d2', '53d5ff905573', ),
-    Revision('4f8a9b79a063', ('3cf34fcd6473', '300209d8594'), ),
-    Revision('6918589deaf', '3314c17f6e35', ),
-    Revision('1755e3b1481c', ('17b66754be21', '31b1d4b7fc95'), ),
-    Revision('58c988e1aa4e', ('219240032b88', 'f067f0b825c'), ),
-    Revision('593b8076fb2c', '1d94175d221b', ),
-    Revision('38d069994064', ('46b70a57edc0', '3ed56beabfb7'), ),
-    Revision('3e2f6c6d1182', '7f96a01461b', ),
-    Revision('1f6969597fe7', '1811bdae9e63', ),
-    Revision('17dd0f165262', '3cf02a593a68', ),
-    Revision('3cf02a593a68', '25a7ef58d293', ),
-    Revision('34dfac7edb2d', '28f4dd53ad3a', ),
-    Revision('4009c533e05d', '42ded7355da2', ),
-    Revision('5a0003c3b09c', ('3ed56beabfb7', '2028d94d3863'), ),
-    Revision('38a936c6ab11', '2588a3c36a0f', ),
-    Revision('59223c5b7b36', '2f93dd880bae', ),
-    Revision('4121bd6e99e9', '540bc5634bd', ),
-    Revision('260714a3f2de', '6918589deaf', ),
-    Revision('ae77a2ed69b', '274fd2642933', ),
-    Revision('18ff1ab3b4c4', '430133b6d46c', ),
-    Revision('2b9a327527a9', ('359329c2ebb', '593b8076fb2c'), ),
-    Revision('4e6167c75ed0', '325b273d61bd', ),
-    Revision('21ab11a7c5c4', ('3da31f3323ec', '22f26011d635'), ),
-    Revision('3b93e98481b1', '4e28e2f4fe2f', ),
-    Revision('145d8f1e334d', 'b4143d129e', ),
-    Revision('135b5fd31062', '1d94175d221b', ),
-    Revision('300209d8594', ('52804033910e', '593b8076fb2c'), ),
-    Revision('8dca95cce28', 'f034666cd80', ),
-    Revision('46b70a57edc0', ('145d8f1e334d', '4cc2960cbe19'), ),
-    Revision('4d45e479fbb9', '2d9d787a496', ),
-    Revision('22f085bf8bbd', '540bc5634bd', ),
-    Revision('263e91fd17d8', '2b9a327527a9', ),
-    Revision('219240032b88', ('300209d8594', '2b9a327527a9'), ),
-    Revision('325b273d61bd', '4b2ad1ffe2e7', ),
-    Revision('199943ccc774', '1aa674ccfa4e', ),
-    Revision('247ad6880f93', '1f6969597fe7', ),
-    Revision('4878cb1cb7f6', '28f4dd53ad3a', ),
-    Revision('2a66d7c4d8a1', '23f1ccb18d6d', ),
-    Revision('42b079245b55', '593b8076fb2c', ),
-    Revision('1cccf82219cb', ('20fe477817d2', '915c67915c2'), ),
-    Revision('b4143d129e', ('159331d6f484', '504d5168afe1'), ),
-    Revision('53d5ff905573', '3013877bf5bd', ),
-    Revision('1f5ceb1ec255', '3ffb59981d9a', ),
-    Revision('ef1c1c1531f', '4738812e6ece', ),
-    Revision('1f6963d1ae02', '247ad6880f93', ),
-    Revision('44d58f1d31f0', '18ff1ab3b4c4', ),
-    Revision('c3ebe64dfb5', ('3409c57b0da', '31f352e77045'), ),
-    Revision('f067f0b825c', '359329c2ebb', ),
-    Revision('52ab2d3b57ce', '96d590bd82e', ),
-    Revision('3b409f268da4', ('20e90eb3eeb6', '263e91fd17d8'), ),
-    Revision('5a4ca8889674', '4e6167c75ed0', ),
-    Revision('5810e9eff996', ('2d30d79c4093', '52804033910e'), ),
-    Revision('40f6508e4373', '4ed16fad67a7', ),
-    Revision('1811bdae9e63', '260714a3f2de', ),
-    Revision('3013877bf5bd', ('8dca95cce28', '3fc8a578bc0a'), ),
-    Revision('16426dbea880', '28f4dd53ad3a', ),
-    Revision('22f26011d635', ('4c93d063d2ba', '3b93e98481b1'), ),
-    Revision('3409c57b0da', '17b66754be21', ),
-    Revision('44373001000f', ('42b079245b55', '219240032b88'), ),
-    Revision('28f4dd53ad3a', '2e71fd90eb9d', ),
-    Revision('4cc2960cbe19', '504d5168afe1', ),
-    Revision('31f352e77045', ('17b66754be21', '22f085bf8bbd'), ),
-    Revision('4ed16fad67a7', 'f034666cd80', ),
-    Revision('3da31f3323ec', '4c93d063d2ba', ),
-    Revision('31b1d4b7fc95', '1cc4459fd115', ),
-    Revision('11bc0ff42f87', '28f4dd53ad3a', ),
-    Revision('3a5db5f31209', '59742a546b84', ),
-    Revision('20e90eb3eeb6', ('58c988e1aa4e', '44373001000f'), ),
-    Revision('23f1ccb18d6d', '52ab2d3b57ce', ),
-    Revision('1d94175d221b', '21ab11a7c5c4', ),
-    Revision('36f1a410ed', '54f871bfe0b0', ),
-    Revision('181a149173e', '2ee35cac4c62', ),
-    Revision('171ad2f0c672', '4a4e0838e206', ),
-    Revision('2f93dd880bae', '540bc5634bd', ),
-    Revision('25a7ef58d293', None, ),
-    Revision('7f96a01461b', '184f65ed83af', ),
-    Revision('b21f22233f', '3e2f6c6d1182', ),
-    Revision('52804033910e', '1d94175d221b', ),
-    Revision('1e6240aba5b3', ('4121bd6e99e9', '2c50d8bab6ee'), ),
-    Revision('1cc4459fd115', '1e6240aba5b3', ),
-    Revision('274fd2642933', '4009c533e05d', ),
-    Revision('1aa674ccfa4e', ('59223c5b7b36', '42050bf030fd'), ),
-    Revision('4e28e2f4fe2f', '596d7b9e11', ),
-    Revision('49ddec8c7a5e', ('124ef6a17781', '47578179e766'), ),
-    Revision('3e9bb349cc46', 'ef1c1c1531f', ),
-    Revision('2028d94d3863', '504d5168afe1', ),
-    Revision('159331d6f484', '34dfac7edb2d', ),
-    Revision('596d7b9e11', '171ad2f0c672', ),
-    Revision('3b96bcc8da76', 'f034666cd80', ),
-    Revision('4738812e6ece', '78982bf5499', ),
-    Revision('3314c17f6e35', '27227dc4fda8', ),
-    Revision('30931c545bf', '2e71fd90eb9d', ),
-    Revision('2e71fd90eb9d', ('c3ebe64dfb5', '1755e3b1481c'), ),
-    Revision('3ed56beabfb7', ('11bc0ff42f87', '69285b0faaa'), ),
-    Revision('96d590bd82e', '3e9bb349cc46', ),
-    Revision('339faa12616', '4d45e479fbb9', ),
-    Revision('47578179e766', '2529684536da', ),
-    Revision('2ee35cac4c62', 'b21f22233f', ),
-    Revision('50c7b21c9089', ('4ed16fad67a7', '3b96bcc8da76'), ),
-    Revision('78982bf5499', 'ae77a2ed69b', ),
-    Revision('519c9f3ce294', '2c50d8bab6ee', ),
-    Revision('2720fc75e5fd', '1cccf82219cb', ),
-    Revision('21638ec787ba', '44d58f1d31f0', ),
-    Revision('59742a546b84', '49ddec8c7a5e', ),
-    Revision('2d30d79c4093', '135b5fd31062', ),
-    Revision('f034666cd80', ('5a0003c3b09c', '38d069994064'), ),
-    Revision('430133b6d46c', '181a149173e', ),
-    Revision('3a3b24a31b57', ('16426dbea880', '4cc2960cbe19'), ),
-    Revision('2529684536da', ('64ed798bcc3', '1f6963d1ae02'), ),
-    Revision('17b66754be21', ('19e0db9d806a', '24c2620b2e3f'), ),
-    Revision('3cf34fcd6473', ('52804033910e', '4789d9c82ca7'), ),
-    Revision('36c31e4e1c37', '504d5168afe1', ),
-    Revision('54f871bfe0b0', '519c9f3ce294', ),
-    Revision('4a4e0838e206', '2a7f37cf7770', ),
-    Revision('19e0db9d806a', ('430074f99c29', '36f1a410ed'), ),
-    Revision('44ed1bf512a0', '247ad6880f93', ),
-    Revision('42050bf030fd', '2f93dd880bae', ),
-    Revision('2c50d8bab6ee', '199943ccc774', ),
-    Revision('504d5168afe1', ('28f4dd53ad3a', '30931c545bf'), ),
-    Revision('915c67915c2', '3fc8a578bc0a', ),
-    Revision('2a7f37cf7770', '2720fc75e5fd', ),
-    Revision('4c93d063d2ba', '4e28e2f4fe2f', ),
-    Revision('42ded7355da2', '21638ec787ba', ),
+    Revision("3fc8a578bc0a", ("4878cb1cb7f6", "454a0529f84e")),
+    Revision("69285b0faaa", ("36c31e4e1c37", "3a3b24a31b57")),
+    Revision("3b0452c64639", "2f1a0f3667f3"),
+    Revision("2d9d787a496", "135b5fd31062"),
+    Revision("184f65ed83af", "3b0452c64639"),
+    Revision("430074f99c29", "54f871bfe0b0"),
+    Revision("3ffb59981d9a", "519c9f3ce294"),
+    Revision("454a0529f84e", ("40f6508e4373", "38a936c6ab11")),
+    Revision("24c2620b2e3f", ("430074f99c29", "1f5ceb1ec255")),
+    Revision("169a948471a9", "247ad6880f93"),
+    Revision("2f1a0f3667f3", "17dd0f165262"),
+    Revision("27227dc4fda8", "2a66d7c4d8a1"),
+    Revision("4b2ad1ffe2e7", ("3b409f268da4", "4f8a9b79a063")),
+    Revision("124ef6a17781", "2529684536da"),
+    Revision("4789d9c82ca7", "593b8076fb2c"),
+    Revision("64ed798bcc3", ("44ed1bf512a0", "169a948471a9")),
+    Revision("2588a3c36a0f", "50c7b21c9089"),
+    Revision("359329c2ebb", ("5810e9eff996", "339faa12616")),
+    Revision("540bc5634bd", "3a5db5f31209"),
+    Revision("20fe477817d2", "53d5ff905573"),
+    Revision("4f8a9b79a063", ("3cf34fcd6473", "300209d8594")),
+    Revision("6918589deaf", "3314c17f6e35"),
+    Revision("1755e3b1481c", ("17b66754be21", "31b1d4b7fc95")),
+    Revision("58c988e1aa4e", ("219240032b88", "f067f0b825c")),
+    Revision("593b8076fb2c", "1d94175d221b"),
+    Revision("38d069994064", ("46b70a57edc0", "3ed56beabfb7")),
+    Revision("3e2f6c6d1182", "7f96a01461b"),
+    Revision("1f6969597fe7", "1811bdae9e63"),
+    Revision("17dd0f165262", "3cf02a593a68"),
+    Revision("3cf02a593a68", "25a7ef58d293"),
+    Revision("34dfac7edb2d", "28f4dd53ad3a"),
+    Revision("4009c533e05d", "42ded7355da2"),
+    Revision("5a0003c3b09c", ("3ed56beabfb7", "2028d94d3863")),
+    Revision("38a936c6ab11", "2588a3c36a0f"),
+    Revision("59223c5b7b36", "2f93dd880bae"),
+    Revision("4121bd6e99e9", "540bc5634bd"),
+    Revision("260714a3f2de", "6918589deaf"),
+    Revision("ae77a2ed69b", "274fd2642933"),
+    Revision("18ff1ab3b4c4", "430133b6d46c"),
+    Revision("2b9a327527a9", ("359329c2ebb", "593b8076fb2c")),
+    Revision("4e6167c75ed0", "325b273d61bd"),
+    Revision("21ab11a7c5c4", ("3da31f3323ec", "22f26011d635")),
+    Revision("3b93e98481b1", "4e28e2f4fe2f"),
+    Revision("145d8f1e334d", "b4143d129e"),
+    Revision("135b5fd31062", "1d94175d221b"),
+    Revision("300209d8594", ("52804033910e", "593b8076fb2c")),
+    Revision("8dca95cce28", "f034666cd80"),
+    Revision("46b70a57edc0", ("145d8f1e334d", "4cc2960cbe19")),
+    Revision("4d45e479fbb9", "2d9d787a496"),
+    Revision("22f085bf8bbd", "540bc5634bd"),
+    Revision("263e91fd17d8", "2b9a327527a9"),
+    Revision("219240032b88", ("300209d8594", "2b9a327527a9")),
+    Revision("325b273d61bd", "4b2ad1ffe2e7"),
+    Revision("199943ccc774", "1aa674ccfa4e"),
+    Revision("247ad6880f93", "1f6969597fe7"),
+    Revision("4878cb1cb7f6", "28f4dd53ad3a"),
+    Revision("2a66d7c4d8a1", "23f1ccb18d6d"),
+    Revision("42b079245b55", "593b8076fb2c"),
+    Revision("1cccf82219cb", ("20fe477817d2", "915c67915c2")),
+    Revision("b4143d129e", ("159331d6f484", "504d5168afe1")),
+    Revision("53d5ff905573", "3013877bf5bd"),
+    Revision("1f5ceb1ec255", "3ffb59981d9a"),
+    Revision("ef1c1c1531f", "4738812e6ece"),
+    Revision("1f6963d1ae02", "247ad6880f93"),
+    Revision("44d58f1d31f0", "18ff1ab3b4c4"),
+    Revision("c3ebe64dfb5", ("3409c57b0da", "31f352e77045")),
+    Revision("f067f0b825c", "359329c2ebb"),
+    Revision("52ab2d3b57ce", "96d590bd82e"),
+    Revision("3b409f268da4", ("20e90eb3eeb6", "263e91fd17d8")),
+    Revision("5a4ca8889674", "4e6167c75ed0"),
+    Revision("5810e9eff996", ("2d30d79c4093", "52804033910e")),
+    Revision("40f6508e4373", "4ed16fad67a7"),
+    Revision("1811bdae9e63", "260714a3f2de"),
+    Revision("3013877bf5bd", ("8dca95cce28", "3fc8a578bc0a")),
+    Revision("16426dbea880", "28f4dd53ad3a"),
+    Revision("22f26011d635", ("4c93d063d2ba", "3b93e98481b1")),
+    Revision("3409c57b0da", "17b66754be21"),
+    Revision("44373001000f", ("42b079245b55", "219240032b88")),
+    Revision("28f4dd53ad3a", "2e71fd90eb9d"),
+    Revision("4cc2960cbe19", "504d5168afe1"),
+    Revision("31f352e77045", ("17b66754be21", "22f085bf8bbd")),
+    Revision("4ed16fad67a7", "f034666cd80"),
+    Revision("3da31f3323ec", "4c93d063d2ba"),
+    Revision("31b1d4b7fc95", "1cc4459fd115"),
+    Revision("11bc0ff42f87", "28f4dd53ad3a"),
+    Revision("3a5db5f31209", "59742a546b84"),
+    Revision("20e90eb3eeb6", ("58c988e1aa4e", "44373001000f")),
+    Revision("23f1ccb18d6d", "52ab2d3b57ce"),
+    Revision("1d94175d221b", "21ab11a7c5c4"),
+    Revision("36f1a410ed", "54f871bfe0b0"),
+    Revision("181a149173e", "2ee35cac4c62"),
+    Revision("171ad2f0c672", "4a4e0838e206"),
+    Revision("2f93dd880bae", "540bc5634bd"),
+    Revision("25a7ef58d293", None),
+    Revision("7f96a01461b", "184f65ed83af"),
+    Revision("b21f22233f", "3e2f6c6d1182"),
+    Revision("52804033910e", "1d94175d221b"),
+    Revision("1e6240aba5b3", ("4121bd6e99e9", "2c50d8bab6ee")),
+    Revision("1cc4459fd115", "1e6240aba5b3"),
+    Revision("274fd2642933", "4009c533e05d"),
+    Revision("1aa674ccfa4e", ("59223c5b7b36", "42050bf030fd")),
+    Revision("4e28e2f4fe2f", "596d7b9e11"),
+    Revision("49ddec8c7a5e", ("124ef6a17781", "47578179e766")),
+    Revision("3e9bb349cc46", "ef1c1c1531f"),
+    Revision("2028d94d3863", "504d5168afe1"),
+    Revision("159331d6f484", "34dfac7edb2d"),
+    Revision("596d7b9e11", "171ad2f0c672"),
+    Revision("3b96bcc8da76", "f034666cd80"),
+    Revision("4738812e6ece", "78982bf5499"),
+    Revision("3314c17f6e35", "27227dc4fda8"),
+    Revision("30931c545bf", "2e71fd90eb9d"),
+    Revision("2e71fd90eb9d", ("c3ebe64dfb5", "1755e3b1481c")),
+    Revision("3ed56beabfb7", ("11bc0ff42f87", "69285b0faaa")),
+    Revision("96d590bd82e", "3e9bb349cc46"),
+    Revision("339faa12616", "4d45e479fbb9"),
+    Revision("47578179e766", "2529684536da"),
+    Revision("2ee35cac4c62", "b21f22233f"),
+    Revision("50c7b21c9089", ("4ed16fad67a7", "3b96bcc8da76")),
+    Revision("78982bf5499", "ae77a2ed69b"),
+    Revision("519c9f3ce294", "2c50d8bab6ee"),
+    Revision("2720fc75e5fd", "1cccf82219cb"),
+    Revision("21638ec787ba", "44d58f1d31f0"),
+    Revision("59742a546b84", "49ddec8c7a5e"),
+    Revision("2d30d79c4093", "135b5fd31062"),
+    Revision("f034666cd80", ("5a0003c3b09c", "38d069994064")),
+    Revision("430133b6d46c", "181a149173e"),
+    Revision("3a3b24a31b57", ("16426dbea880", "4cc2960cbe19")),
+    Revision("2529684536da", ("64ed798bcc3", "1f6963d1ae02")),
+    Revision("17b66754be21", ("19e0db9d806a", "24c2620b2e3f")),
+    Revision("3cf34fcd6473", ("52804033910e", "4789d9c82ca7")),
+    Revision("36c31e4e1c37", "504d5168afe1"),
+    Revision("54f871bfe0b0", "519c9f3ce294"),
+    Revision("4a4e0838e206", "2a7f37cf7770"),
+    Revision("19e0db9d806a", ("430074f99c29", "36f1a410ed")),
+    Revision("44ed1bf512a0", "247ad6880f93"),
+    Revision("42050bf030fd", "2f93dd880bae"),
+    Revision("2c50d8bab6ee", "199943ccc774"),
+    Revision("504d5168afe1", ("28f4dd53ad3a", "30931c545bf")),
+    Revision("915c67915c2", "3fc8a578bc0a"),
+    Revision("2a7f37cf7770", "2720fc75e5fd"),
+    Revision("4c93d063d2ba", "4e28e2f4fe2f"),
+    Revision("42ded7355da2", "21638ec787ba"),
 ]
 
-map_ = RevisionMap(
-    lambda: data
-)
-
-
+map_ = RevisionMap(lambda: data)
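
_large_map.py is a stress fixture: a hundred-plus Revision entries forming a heavily branched graph, handed to RevisionMap through a callable so the map can rebuild itself on demand. The same construction on a toy graph, with invented identifiers (bases, heads and get_revision are the accessors the traversal tests lean on):

    from alembic.script.revision import Revision, RevisionMap

    # one base, two parallel branches, one merge point
    toy = [
        Revision("base1", None),
        Revision("aaaa1", "base1"),
        Revision("bbbb1", "base1"),
        Revision("merge1", ("aaaa1", "bbbb1")),
    ]
    toy_map = RevisionMap(lambda: toy)

    print(toy_map.bases)                   # ('base1',)
    print(toy_map.heads)                   # ('merge1',)
    print(toy_map.get_revision("merge1"))  # the merge Revision object
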
index 608d90399a659a395d82927763d86fbfff50bcd7..6cf770d745a572cd66642b4a54bc1b1de199eb10 100755 (executable)
@@ -11,12 +11,16 @@ import os
 # use bootstrapping so that test plugins are loaded
 # without touching the main library before coverage starts
 bootstrap_file = os.path.join(
-    os.path.dirname(__file__), "..", "alembic",
-    "testing", "plugin", "bootstrap.py"
+    os.path.dirname(__file__),
+    "..",
+    "alembic",
+    "testing",
+    "plugin",
+    "bootstrap.py",
 )
 
 with open(bootstrap_file) as f:
-    code = compile(f.read(), "bootstrap.py", 'exec')
+    code = compile(f.read(), "bootstrap.py", "exec")
     to_bootstrap = "pytest"
     exec(code, globals(), locals())
     from pytestplugin import *  # noqa
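
The compile()/exec() pair is the entire bootstrap trick: the plugin source runs in conftest's namespace without an `import alembic.testing.plugin...` statement ever executing, so coverage measurement can start before any package module is imported. The idiom in isolation, with a placeholder file name:

    import os

    # execute a file's source directly; giving compile() a real filename
    # keeps tracebacks pointing at that file instead of "<string>"
    plugin_path = os.path.join(os.path.dirname(__file__), "myplugin.py")
    with open(plugin_path) as f:
        code = compile(f.read(), os.path.basename(plugin_path), "exec")
    exec(code, globals(), locals())
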
index 0be95eb3659dd8d78f0596cea431523e905b25c7..74a7d83e9be879f2d2453e40d561ffc4a4e9dad7 100644 (file)
@@ -5,16 +5,12 @@ from alembic.util import sqla_compat
 
 
 class DefaultRequirements(SuiteRequirements):
-
     @property
     def schemas(self):
         """Target database must support external schemas, and have one
         named 'test_schema'."""
 
-        return exclusions.skip_if([
-            "sqlite",
-            "firebird"
-        ], "no schema support")
+        return exclusions.skip_if(["sqlite", "firebird"], "no schema support")
 
     @property
     def no_referential_integrity(self):
@@ -48,12 +44,12 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def unnamed_constraints(self):
         """constraints without names are supported."""
-        return exclusions.only_on(['sqlite'])
+        return exclusions.only_on(["sqlite"])
 
     @property
     def fk_names(self):
         """foreign key constraints always have names in the DB"""
-        return exclusions.fails_on('sqlite')
+        return exclusions.fails_on("sqlite")
 
     @property
     def no_name_normalize(self):
@@ -63,20 +59,24 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def reflects_fk_options(self):
-        return exclusions.only_on([
-            'postgresql', 'mysql',
-            lambda config: util.sqla_110 and
-            exclusions.against(config, 'sqlite')])
+        return exclusions.only_on(
+            [
+                "postgresql",
+                "mysql",
+                lambda config: util.sqla_110
+                and exclusions.against(config, "sqlite"),
+            ]
+        )
 
     @property
     def fk_initially(self):
         """backend supports INITIALLY option in foreign keys"""
-        return exclusions.only_on(['postgresql'])
+        return exclusions.only_on(["postgresql"])
 
     @property
     def fk_deferrable(self):
         """backend supports DEFERRABLE option in foreign keys"""
-        return exclusions.only_on(['postgresql'])
+        return exclusions.only_on(["postgresql"])
 
     @property
     def flexible_fk_cascades(self):
@@ -84,8 +84,7 @@ class DefaultRequirements(SuiteRequirements):
         full range of keywords (e.g. NO ACTION, etc.)"""
 
         return exclusions.skip_if(
-            ['oracle'],
-            'target backend has poor FK cascade syntax'
+            ["oracle"], "target backend has poor FK cascade syntax"
         )
 
     @property
@@ -97,10 +96,13 @@ class DefaultRequirements(SuiteRequirements):
         """Target driver reflects the name of primary key constraints."""
 
         return exclusions.fails_on_everything_except(
-            'postgresql', 'oracle', 'mssql', 'sybase',
+            "postgresql",
+            "oracle",
+            "mssql",
+            "sybase",
             lambda config: (
                 util.sqla_110 and exclusions.against(config, "sqlite")
-            )
+            ),
         )
 
     @property
@@ -122,8 +124,10 @@ class DefaultRequirements(SuiteRequirements):
                 return False
             count = config.db.scalar(
                 "SELECT count(*) FROM pg_extension "
-                "WHERE extname='%s'" % name)
+                "WHERE extname='%s'" % name
+            )
             return bool(count)
+
         return exclusions.only_if(check, "needs %s extension" % name)
 
     @property
@@ -134,7 +138,6 @@ class DefaultRequirements(SuiteRequirements):
     def btree_gist(self):
         return self._has_pg_extension("btree_gist")
 
-
     @property
     def autoincrement_on_composite_pk(self):
         return exclusions.skip_if(["sqlite"], "not supported by database")
@@ -153,34 +156,42 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def mysql_check_reflection_or_none(self):
         def go(config):
-            return not self._mariadb_102(config) \
-                or self.sqlalchemy_1115.enabled
+            return (
+                not self._mariadb_102(config) or self.sqlalchemy_1115.enabled
+            )
+
         return exclusions.succeeds_if(go)
 
     @property
     def mysql_timestamp_reflection(self):
         def go(config):
-            return not self._mariadb_102(config) \
-                or self.sqlalchemy_1115.enabled
+            return (
+                not self._mariadb_102(config) or self.sqlalchemy_1115.enabled
+            )
+
         return exclusions.only_if(go)
 
     def _mariadb_102(self, config):
-        return exclusions.against(config, "mysql") and \
-            sqla_compat._is_mariadb(config.db.dialect) and \
-            sqla_compat._mariadb_normalized_version_info(
-                config.db.dialect) > (10, 2)
+        return (
+            exclusions.against(config, "mysql")
+            and sqla_compat._is_mariadb(config.db.dialect)
+            and sqla_compat._mariadb_normalized_version_info(config.db.dialect)
+            > (10, 2)
+        )
 
     def _mariadb_only_102(self, config):
-        return exclusions.against(config, "mysql") and \
-            sqla_compat._is_mariadb(config.db.dialect) and \
-            sqla_compat._mariadb_normalized_version_info(
-                config.db.dialect) >= (10, 2) and \
-            sqla_compat._mariadb_normalized_version_info(
-                config.db.dialect) < (10, 3)
+        return (
+            exclusions.against(config, "mysql")
+            and sqla_compat._is_mariadb(config.db.dialect)
+            and sqla_compat._mariadb_normalized_version_info(config.db.dialect)
+            >= (10, 2)
+            and sqla_compat._mariadb_normalized_version_info(config.db.dialect)
+            < (10, 3)
+        )
 
     def _mysql_not_mariadb_102(self, config):
         return exclusions.against(config, "mysql") and (
-            not sqla_compat._is_mariadb(config.db.dialect) or
-            sqla_compat._mariadb_normalized_version_info(
-                config.db.dialect) < (10, 2)
+            not sqla_compat._is_mariadb(config.db.dialect)
+            or sqla_compat._mariadb_normalized_version_info(config.db.dialect)
+            < (10, 2)
         )
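
Each property on DefaultRequirements returns an exclusions rule, and test classes opt in by naming the property in __requires__ (as ModelOne does above with "unique_constraint_reflection"). A minimal sketch of defining and consuming one such rule; the requirement and test names are invented:

    from alembic.testing import exclusions
    from alembic.testing.requirements import SuiteRequirements


    class MyRequirements(SuiteRequirements):
        @property
        def named_fk_constraints(self):
            # skip backends that discard names of unnamed FK constraints
            return exclusions.fails_on("sqlite")


    class FKRoundTripTest(object):
        # resolved against the active requirements object at collection time
        __requires__ = ("named_fk_constraints",)
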
index c53631750a753fbc935b3c7f11f7bcebcb1d9944..7694cbd6bf0aa1202d214cf8c8433c6baf8b63cf 100644 (file)
@@ -9,64 +9,73 @@ from ._autogen_fixtures import AutogenTest, ModelOne, _default_include_object
 
 
 class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     def test_render_nothing(self):
         context = MigrationContext.configure(
             connection=self.bind.connect(),
             opts={
-                'compare_type': True,
-                'compare_server_default': True,
-                'target_metadata': self.m1,
-                'upgrade_token': "upgrades",
-                'downgrade_token': "downgrades",
-            }
+                "compare_type": True,
+                "compare_server_default": True,
+                "target_metadata": self.m1,
+                "upgrade_token": "upgrades",
+                "downgrade_token": "downgrades",
+            },
         )
         template_args = {}
         autogenerate._render_migration_diffs(context, template_args)
 
-        eq_(re.sub(r"u'", "'", template_args['upgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["upgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     pass
-    # ### end Alembic commands ###""")
-        eq_(re.sub(r"u'", "'", template_args['downgrades']),
+    # ### end Alembic commands ###""",
+        )
+        eq_(
+            re.sub(r"u'", "'", template_args["downgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     pass
-    # ### end Alembic commands ###""")
+    # ### end Alembic commands ###""",
+        )
 
     def test_render_nothing_batch(self):
         context = MigrationContext.configure(
             connection=self.bind.connect(),
             opts={
-                'compare_type': True,
-                'compare_server_default': True,
-                'target_metadata': self.m1,
-                'upgrade_token': "upgrades",
-                'downgrade_token': "downgrades",
-                'alembic_module_prefix': 'op.',
-                'sqlalchemy_module_prefix': 'sa.',
-                'render_as_batch': True,
-                'include_symbol': lambda name, schema: False
-            }
+                "compare_type": True,
+                "compare_server_default": True,
+                "target_metadata": self.m1,
+                "upgrade_token": "upgrades",
+                "downgrade_token": "downgrades",
+                "alembic_module_prefix": "op.",
+                "sqlalchemy_module_prefix": "sa.",
+                "render_as_batch": True,
+                "include_symbol": lambda name, schema: False,
+            },
         )
         template_args = {}
         autogenerate._render_migration_diffs(context, template_args)
 
-        eq_(re.sub(r"u'", "'", template_args['upgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["upgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     pass
-    # ### end Alembic commands ###""")
-        eq_(re.sub(r"u'", "'", template_args['downgrades']),
+    # ### end Alembic commands ###""",
+        )
+        eq_(
+            re.sub(r"u'", "'", template_args["downgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     pass
-    # ### end Alembic commands ###""")
+    # ### end Alembic commands ###""",
+        )
 
     def test_render_diffs_standard(self):
         """test a full render including indentation"""
 
         template_args = {}
         autogenerate._render_migration_diffs(self.context, template_args)
-        eq_(re.sub(r"u'", "'", template_args['upgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["upgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     op.create_table('item',
     sa.Column('id', sa.Integer(), nullable=False),
@@ -96,9 +105,11 @@ nullable=True))
                nullable=False)
     op.drop_index('pw_idx', table_name='user')
     op.drop_column('user', 'pw')
-    # ### end Alembic commands ###""")
+    # ### end Alembic commands ###""",
+        )
 
-        eq_(re.sub(r"u'", "'", template_args['downgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["downgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     op.add_column('user', sa.Column('pw', sa.VARCHAR(length=50), \
 nullable=True))
@@ -125,16 +136,18 @@ nullable=True))
     sa.ForeignKeyConstraint(['uid'], ['user.id'], )
     )
     op.drop_table('item')
-    # ### end Alembic commands ###""")
+    # ### end Alembic commands ###""",
+        )
 
     def test_render_diffs_batch(self):
         """test a full render in batch mode including indentation"""
 
         template_args = {}
-        self.context.opts['render_as_batch'] = True
+        self.context.opts["render_as_batch"] = True
         autogenerate._render_migration_diffs(self.context, template_args)
 
-        eq_(re.sub(r"u'", "'", template_args['upgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["upgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     op.create_table('item',
     sa.Column('id', sa.Integer(), nullable=False),
@@ -169,9 +182,11 @@ nullable=True))
         batch_op.drop_index('pw_idx')
         batch_op.drop_column('pw')
 
-    # ### end Alembic commands ###""")
+    # ### end Alembic commands ###""",
+        )
 
-        eq_(re.sub(r"u'", "'", template_args['downgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["downgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     with op.batch_alter_table('user', schema=None) as batch_op:
         batch_op.add_column(sa.Column('pw', sa.VARCHAR(length=50), nullable=True))
@@ -203,74 +218,80 @@ nullable=True))
     sa.ForeignKeyConstraint(['uid'], ['user.id'], )
     )
     op.drop_table('item')
-    # ### end Alembic commands ###""")
+    # ### end Alembic commands ###""",
+        )
 
     def test_imports_maintined(self):
         template_args = {}
-        self.context.opts['render_as_batch'] = True
+        self.context.opts["render_as_batch"] = True
 
         def render_item(type_, col, autogen_context):
             autogen_context.imports.add(
                 "from mypackage import my_special_import"
             )
-            autogen_context.imports.add(
-                "from foobar import bat"
-            )
+            autogen_context.imports.add("from foobar import bat")
 
         self.context.opts["render_item"] = render_item
         autogenerate._render_migration_diffs(self.context, template_args)
         eq_(
+            set(template_args["imports"].split("\n")),
             set(
-                template_args['imports'].split("\n")
+                [
+                    "from foobar import bat",
+                    "from mypackage import my_special_import",
+                ]
             ),
-            set([
-                "from foobar import bat",
-                "from mypackage import my_special_import"
-            ])
         )
 
 
 class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     schema = "test_schema"
 
     def test_render_nothing(self):
         context = MigrationContext.configure(
             connection=self.bind.connect(),
             opts={
-                'compare_type': True,
-                'compare_server_default': True,
-                'target_metadata': self.m1,
-                'upgrade_token': "upgrades",
-                'downgrade_token': "downgrades",
-                'alembic_module_prefix': 'op.',
-                'sqlalchemy_module_prefix': 'sa.',
-                'include_symbol': lambda name, schema: False
-            }
+                "compare_type": True,
+                "compare_server_default": True,
+                "target_metadata": self.m1,
+                "upgrade_token": "upgrades",
+                "downgrade_token": "downgrades",
+                "alembic_module_prefix": "op.",
+                "sqlalchemy_module_prefix": "sa.",
+                "include_symbol": lambda name, schema: False,
+            },
         )
         template_args = {}
         autogenerate._render_migration_diffs(context, template_args)
 
-        eq_(re.sub(r"u'", "'", template_args['upgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["upgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     pass
-    # ### end Alembic commands ###""")
-        eq_(re.sub(r"u'", "'", template_args['downgrades']),
+    # ### end Alembic commands ###""",
+        )
+        eq_(
+            re.sub(r"u'", "'", template_args["downgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     pass
-    # ### end Alembic commands ###""")
+    # ### end Alembic commands ###""",
+        )
 
     def test_render_diffs_extras(self):
         """test a full render including indentation (include and schema)"""
 
         template_args = {}
-        self.context.opts.update({
-            'include_object': _default_include_object,
-            'include_schemas': True
-        })
+        self.context.opts.update(
+            {
+                "include_object": _default_include_object,
+                "include_schemas": True,
+            }
+        )
         autogenerate._render_migration_diffs(self.context, template_args)
 
-        eq_(re.sub(r"u'", "'", template_args['upgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["upgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     op.create_table('item',
     sa.Column('id', sa.Integer(), nullable=False),
@@ -307,9 +328,12 @@ source_schema='%(schema)s', referent_schema='%(schema)s')
                schema='%(schema)s')
     op.drop_index('pw_idx', table_name='user', schema='test_schema')
     op.drop_column('user', 'pw', schema='%(schema)s')
-    # ### end Alembic commands ###""" % {"schema": self.schema})
+    # ### end Alembic commands ###"""
+            % {"schema": self.schema},
+        )
 
-        eq_(re.sub(r"u'", "'", template_args['downgrades']),
+        eq_(
+            re.sub(r"u'", "'", template_args["downgrades"]),
             """# ### commands auto generated by Alembic - please adjust! ###
     op.add_column('user', sa.Column('pw', sa.VARCHAR(length=50), \
 autoincrement=False, nullable=True), schema='%(schema)s')
@@ -341,5 +365,6 @@ name='extra_uid_fkey'),
     schema='%(schema)s'
     )
     op.drop_table('item', schema='%(schema)s')
-    # ### end Alembic commands ###""" % {"schema": self.schema})
-
+    # ### end Alembic commands ###"""
+            % {"schema": self.schema},
+        )
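
These composition tests pin the exact text of rendered upgrade/downgrade bodies. The public pairing that produces such text is produce_migrations() plus render_python_code(), both exported from alembic.autogenerate; a minimal sketch against in-memory SQLite (the "widget" table is illustrative):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine
    from alembic.migration import MigrationContext
    from alembic.autogenerate import produce_migrations, render_python_code

    engine = create_engine("sqlite://")

    target = MetaData()
    Table("widget", target, Column("id", Integer, primary_key=True))

    with engine.connect() as conn:
        ctx = MigrationContext.configure(connection=conn)
        script = produce_migrations(ctx, target)

    # emits op.create_table('widget', ...) with the same op./sa. prefixes
    # asserted in the expected strings above
    print(render_python_code(script.upgrade_ops))
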
index 38e06b6cfe5c19cb15ccb0555e035080877da2cc..af2e2c28addcdacd3ad6a17483b21d9ea2b5aada 100644 (file)
@@ -1,10 +1,30 @@
 import sys
 
-from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
-    Numeric, CHAR, ForeignKey, INTEGER, Index, UniqueConstraint, \
-    TypeDecorator, CheckConstraint, text, PrimaryKeyConstraint, \
-    ForeignKeyConstraint, VARCHAR, DECIMAL, DateTime, BigInteger, BIGINT, \
-    SmallInteger
+from sqlalchemy import (
+    MetaData,
+    Column,
+    Table,
+    Integer,
+    String,
+    Text,
+    Numeric,
+    CHAR,
+    ForeignKey,
+    INTEGER,
+    Index,
+    UniqueConstraint,
+    TypeDecorator,
+    CheckConstraint,
+    text,
+    PrimaryKeyConstraint,
+    ForeignKeyConstraint,
+    VARCHAR,
+    DECIMAL,
+    DateTime,
+    BigInteger,
+    BIGINT,
+    SmallInteger,
+)
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.types import NULLTYPE, VARBINARY
 from sqlalchemy.engine.reflection import Inspector
@@ -20,62 +40,41 @@ from alembic.testing import eq_, is_, is_not_
 from alembic.util import CommandError
 from ._autogen_fixtures import AutogenTest, AutogenFixtureTest
 
-py3k = sys.version_info >= (3, )
+py3k = sys.version_info >= (3,)
 
 
 class AutogenCrossSchemaTest(AutogenTest, TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
 
     @classmethod
     def _get_db_schema(cls):
         m = MetaData()
-        Table('t1', m,
-              Column('x', Integer)
-              )
-        Table('t2', m,
-              Column('y', Integer),
-              schema=config.test_schema
-              )
-        Table('t6', m,
-              Column('u', Integer)
-              )
-        Table('t7', m,
-              Column('v', Integer),
-              schema=config.test_schema
-              )
+        Table("t1", m, Column("x", Integer))
+        Table("t2", m, Column("y", Integer), schema=config.test_schema)
+        Table("t6", m, Column("u", Integer))
+        Table("t7", m, Column("v", Integer), schema=config.test_schema)
 
         return m
 
     @classmethod
     def _get_model_schema(cls):
         m = MetaData()
-        Table('t3', m,
-              Column('q', Integer)
-              )
-        Table('t4', m,
-              Column('z', Integer),
-              schema=config.test_schema
-              )
-        Table('t6', m,
-              Column('u', Integer)
-              )
-        Table('t7', m,
-              Column('v', Integer),
-              schema=config.test_schema
-              )
+        Table("t3", m, Column("q", Integer))
+        Table("t4", m, Column("z", Integer), schema=config.test_schema)
+        Table("t6", m, Column("u", Integer))
+        Table("t7", m, Column("v", Integer), schema=config.test_schema)
         return m
 
     def test_default_schema_omitted_upgrade(self):
-
         def include_object(obj, name, type_, reflected, compare_to):
             if type_ == "table":
                 return name == "t3"
             else:
                 return True
+
         self._update_context(
-            object_filters=include_object,
-            include_schemas=True,
+            object_filters=include_object, include_schemas=True
         )
         uo = ops.UpgradeOps(ops=[])
         autogenerate._produce_net_changes(self.autogen_context, uo)
@@ -85,7 +84,6 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
         eq_(diffs[0][1].schema, None)
 
     def test_alt_schema_included_upgrade(self):
-
         def include_object(obj, name, type_, reflected, compare_to):
             if type_ == "table":
                 return name == "t4"
@@ -93,8 +91,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
                 return True
 
         self._update_context(
-            object_filters=include_object,
-            include_schemas=True,
+            object_filters=include_object, include_schemas=True
         )
         uo = ops.UpgradeOps(ops=[])
         autogenerate._produce_net_changes(self.autogen_context, uo)
@@ -109,9 +106,9 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
                 return name == "t1"
             else:
                 return True
+
         self._update_context(
-            object_filters=include_object,
-            include_schemas=True,
+            object_filters=include_object, include_schemas=True
         )
         uo = ops.UpgradeOps(ops=[])
         autogenerate._produce_net_changes(self.autogen_context, uo)
@@ -121,15 +118,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
         eq_(diffs[0][1].schema, None)
 
     def test_alt_schema_included_downgrade(self):
-
         def include_object(obj, name, type_, reflected, compare_to):
             if type_ == "table":
                 return name == "t2"
             else:
                 return True
+
         self._update_context(
-            object_filters=include_object,
-            include_schemas=True,
+            object_filters=include_object, include_schemas=True
         )
         uo = ops.UpgradeOps(ops=[])
         autogenerate._produce_net_changes(self.autogen_context, uo)
@@ -139,7 +135,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestBase):
 
 
 class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
 
     def test_uses_explcit_schema_in_default_one(self):
@@ -149,8 +145,8 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('a', m1, Column('x', String(50)))
-        Table('a', m2, Column('x', String(50)), schema=default_schema)
+        Table("a", m1, Column("x", String(50)))
+        Table("a", m2, Column("x", String(50)), schema=default_schema)
 
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(diffs, [])
@@ -162,15 +158,15 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('a', m1, Column('x', String(50)))
-        Table('a', m2, Column('x', String(50)), schema=default_schema)
-        Table('a', m2, Column('y', String(50)), schema="test_schema")
+        Table("a", m1, Column("x", String(50)))
+        Table("a", m2, Column("x", String(50)), schema=default_schema)
+        Table("a", m2, Column("y", String(50)), schema="test_schema")
 
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(len(diffs), 1)
         eq_(diffs[0][0], "add_table")
         eq_(diffs[0][1].schema, "test_schema")
-        eq_(diffs[0][1].c.keys(), ['y'])
+        eq_(diffs[0][1].c.keys(), ["y"])
 
     def test_uses_explcit_schema_in_default_three(self):
 
@@ -179,20 +175,20 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('a', m1, Column('y', String(50)), schema="test_schema")
+        Table("a", m1, Column("y", String(50)), schema="test_schema")
 
-        Table('a', m2, Column('x', String(50)), schema=default_schema)
-        Table('a', m2, Column('y', String(50)), schema="test_schema")
+        Table("a", m2, Column("x", String(50)), schema=default_schema)
+        Table("a", m2, Column("y", String(50)), schema="test_schema")
 
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(len(diffs), 1)
         eq_(diffs[0][0], "add_table")
         eq_(diffs[0][1].schema, default_schema)
-        eq_(diffs[0][1].c.keys(), ['x'])
+        eq_(diffs[0][1].c.keys(), ["x"])
 
 
 class AutogenDefaultSchemaIsNoneTest(AutogenFixtureTest, TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     def setUp(self):
         super(AutogenDefaultSchemaIsNoneTest, self).setUp()
@@ -205,23 +201,23 @@ class AutogenDefaultSchemaIsNoneTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('a', m1, Column('x', String(50)))
-        Table('a', m2, Column('x', String(50)))
+        Table("a", m1, Column("x", String(50)))
+        Table("a", m2, Column("x", String(50)))
 
         def _include_object(obj, name, type_, reflected, compare_to):
             if type_ == "table":
-                return name in 'a' and obj.schema != 'main'
+                return name in "a" and obj.schema != "main"
             else:
                 return True
 
         diffs = self._fixture(
-            m1, m2, include_schemas=True,
-            object_filters=_include_object)
+            m1, m2, include_schemas=True, object_filters=_include_object
+        )
         eq_(len(diffs), 0)
 
 
 class ModelOne(object):
-    __requires__ = ('unique_constraint_reflection', )
+    __requires__ = ("unique_constraint_reflection",)
 
     schema = None
 
@@ -231,30 +227,42 @@ class ModelOne(object):
 
         m = MetaData(schema=schema)
 
-        Table('user', m,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50)),
-              Column('a1', Text),
-              Column("pw", String(50)),
-              Index('pw_idx', 'pw')
-              )
-
-        Table('address', m,
-              Column('id', Integer, primary_key=True),
-              Column('email_address', String(100), nullable=False),
-              )
-
-        Table('order', m,
-              Column('order_id', Integer, primary_key=True),
-              Column("amount", Numeric(8, 2), nullable=False,
-                     server_default=text("0")),
-              CheckConstraint('amount >= 0', name='ck_order_amount')
-              )
-
-        Table('extra', m,
-              Column("x", CHAR),
-              Column('uid', Integer, ForeignKey('user.id'))
-              )
+        Table(
+            "user",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+            Column("a1", Text),
+            Column("pw", String(50)),
+            Index("pw_idx", "pw"),
+        )
+
+        Table(
+            "address",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("email_address", String(100), nullable=False),
+        )
+
+        Table(
+            "order",
+            m,
+            Column("order_id", Integer, primary_key=True),
+            Column(
+                "amount",
+                Numeric(8, 2),
+                nullable=False,
+                server_default=text("0"),
+            ),
+            CheckConstraint("amount >= 0", name="ck_order_amount"),
+        )
+
+        Table(
+            "extra",
+            m,
+            Column("x", CHAR),
+            Column("uid", Integer, ForeignKey("user.id")),
+        )
 
         return m
 
@@ -264,38 +272,50 @@ class ModelOne(object):
 
         m = MetaData(schema=schema)
 
-        Table('user', m,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', Text, server_default="x")
-              )
-
-        Table('address', m,
-              Column('id', Integer, primary_key=True),
-              Column('email_address', String(100), nullable=False),
-              Column('street', String(50)),
-              UniqueConstraint('email_address', name="uq_email")
-              )
-
-        Table('order', m,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True,
-                     server_default=text("0")),
-              Column('user_id', Integer, ForeignKey('user.id')),
-              CheckConstraint('amount > -1', name='ck_order_amount'),
-              )
-
-        Table('item', m,
-              Column('id', Integer, primary_key=True),
-              Column('description', String(100)),
-              Column('order_id', Integer, ForeignKey('order.order_id')),
-              CheckConstraint('len(description) > 5')
-              )
+        Table(
+            "user",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", Text, server_default="x"),
+        )
+
+        Table(
+            "address",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("email_address", String(100), nullable=False),
+            Column("street", String(50)),
+            UniqueConstraint("email_address", name="uq_email"),
+        )
+
+        Table(
+            "order",
+            m,
+            Column("order_id", Integer, primary_key=True),
+            Column(
+                "amount",
+                Numeric(10, 2),
+                nullable=True,
+                server_default=text("0"),
+            ),
+            Column("user_id", Integer, ForeignKey("user.id")),
+            CheckConstraint("amount > -1", name="ck_order_amount"),
+        )
+
+        Table(
+            "item",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("description", String(100)),
+            Column("order_id", Integer, ForeignKey("order.order_id")),
+            CheckConstraint("len(description) > 5"),
+        )
         return m
 
 
 class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     def test_diffs(self):
         """test generation of diff rules"""
@@ -304,23 +324,18 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         uo = ops.UpgradeOps(ops=[])
         ctx = self.autogen_context
 
-        autogenerate._produce_net_changes(
-            ctx, uo
-        )
+        autogenerate._produce_net_changes(ctx, uo)
 
         diffs = uo.as_diffs()
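+        # each diff entry is either a tuple such as ("add_table", Table) or
+        # ("add_column", schema, tblname, col), or, for column alterations,
+        # a list of ("modify_*", ...) tuples, as the assertions below show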
-        eq_(
-            diffs[0],
-            ('add_table', metadata.tables['item'])
-        )
+        eq_(diffs[0], ("add_table", metadata.tables["item"]))
 
-        eq_(diffs[1][0], 'remove_table')
+        eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "extra")
 
         eq_(diffs[2][0], "add_column")
         eq_(diffs[2][1], None)
         eq_(diffs[2][2], "address")
-        eq_(diffs[2][3], metadata.tables['address'].c.street)
+        eq_(diffs[2][3], metadata.tables["address"].c.street)
 
         eq_(diffs[3][0], "add_constraint")
         eq_(diffs[3][1].name, "uq_email")
@@ -328,7 +343,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         eq_(diffs[4][0], "add_column")
         eq_(diffs[4][1], None)
         eq_(diffs[4][2], "order")
-        eq_(diffs[4][3], metadata.tables['order'].c.user_id)
+        eq_(diffs[4][3], metadata.tables["order"].c.user_id)
 
         eq_(diffs[5][0][0], "modify_type")
         eq_(diffs[5][0][1], None)
@@ -338,9 +353,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
 
         self._assert_fk_diff(
-            diffs[6], "add_fk",
-            "order", ["user_id"],
-            "user", ["id"]
+            diffs[6], "add_fk", "order", ["user_id"], "user", ["id"]
         )
 
         eq_(diffs[7][0][0], "modify_default")
@@ -349,45 +362,47 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         eq_(diffs[7][0][3], "a1")
         eq_(diffs[7][0][6].arg, "x")
 
-        eq_(diffs[8][0][0], 'modify_nullable')
+        eq_(diffs[8][0][0], "modify_nullable")
         eq_(diffs[8][0][5], True)
         eq_(diffs[8][0][6], False)
 
-        eq_(diffs[9][0], 'remove_index')
-        eq_(diffs[9][1].name, 'pw_idx')
+        eq_(diffs[9][0], "remove_index")
+        eq_(diffs[9][1].name, "pw_idx")
 
-        eq_(diffs[10][0], 'remove_column')
-        eq_(diffs[10][3].name, 'pw')
-        eq_(diffs[10][3].table.name, 'user')
-        assert isinstance(
-            diffs[10][3].type, String
-        )
+        eq_(diffs[10][0], "remove_column")
+        eq_(diffs[10][3].name, "pw")
+        eq_(diffs[10][3].table.name, "user")
+        assert isinstance(diffs[10][3].type, String)
 
     def test_include_symbol(self):
 
         diffs = []
 
         def include_symbol(name, schema=None):
-            return name in ('address', 'order')
+            return name in ("address", "order")
 
         context = MigrationContext.configure(
             connection=self.bind.connect(),
             opts={
-                'compare_type': True,
-                'compare_server_default': True,
-                'target_metadata': self.m2,
-                'include_symbol': include_symbol,
-            }
+                "compare_type": True,
+                "compare_server_default": True,
+                "target_metadata": self.m2,
+                "include_symbol": include_symbol,
+            },
         )
 
         diffs = autogenerate.compare_metadata(
-            context, context.opts['target_metadata'])
+            context, context.opts["target_metadata"]
+        )
 
-        alter_cols = set([
-            d[2] for d in self._flatten_diffs(diffs)
-            if d[0].startswith('modify')
-        ])
-        eq_(alter_cols, set(['order']))
+        alter_cols = set(
+            [
+                d[2]
+                for d in self._flatten_diffs(diffs)
+                if d[0].startswith("modify")
+            ]
+        )
+        eq_(alter_cols, set(["order"]))
 
     def test_include_object(self):
         def include_object(obj, name, type_, reflected, compare_to):
@@ -410,33 +425,46 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         context = MigrationContext.configure(
             connection=self.bind.connect(),
             opts={
-                'compare_type': True,
-                'compare_server_default': True,
-                'target_metadata': self.m2,
-                'include_object': include_object,
-            }
+                "compare_type": True,
+                "compare_server_default": True,
+                "target_metadata": self.m2,
+                "include_object": include_object,
+            },
         )
 
         diffs = autogenerate.compare_metadata(
-            context, context.opts['target_metadata'])
+            context, context.opts["target_metadata"]
+        )
 
-        alter_cols = set([
-            d[2] for d in self._flatten_diffs(diffs)
-            if d[0].startswith('modify')
-        ]).union(
-            d[3].name for d in self._flatten_diffs(diffs)
-            if d[0] == 'add_column'
-        ).union(
-            d[1].name for d in self._flatten_diffs(diffs)
-            if d[0] == 'add_table'
+        alter_cols = (
+            set(
+                [
+                    d[2]
+                    for d in self._flatten_diffs(diffs)
+                    if d[0].startswith("modify")
+                ]
+            )
+            .union(
+                d[3].name
+                for d in self._flatten_diffs(diffs)
+                if d[0] == "add_column"
+            )
+            .union(
+                d[1].name
+                for d in self._flatten_diffs(diffs)
+                if d[0] == "add_table"
+            )
         )
-        eq_(alter_cols, set(['user_id', 'order', 'user']))
+        eq_(alter_cols, set(["user_id", "order", "user"]))
 
     def test_skip_null_type_comparison_reflected(self):
         ac = ops.AlterColumnOp("sometable", "somecol")
         autogenerate.compare._compare_type(
-            self.autogen_context, ac,
-            None, "sometable", "somecol",
+            self.autogen_context,
+            ac,
+            None,
+            "sometable",
+            "somecol",
             Column("somecol", NULLTYPE),
             Column("somecol", Integer()),
         )
@@ -446,8 +474,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
     def test_skip_null_type_comparison_local(self):
         ac = ops.AlterColumnOp("sometable", "somecol")
         autogenerate.compare._compare_type(
-            self.autogen_context, ac,
-            None, "sometable", "somecol",
+            self.autogen_context,
+            ac,
+            None,
+            "sometable",
+            "somecol",
             Column("somecol", Integer()),
             Column("somecol", NULLTYPE),
         )
@@ -463,8 +494,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
 
         ac = ops.AlterColumnOp("sometable", "somecol")
         autogenerate.compare._compare_type(
-            self.autogen_context, ac,
-            None, "sometable", "somecol",
+            self.autogen_context,
+            ac,
+            None,
+            "sometable",
+            "somecol",
             Column("somecol", INTEGER()),
             Column("somecol", MyType()),
         )
@@ -473,56 +507,63 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
 
         ac = ops.AlterColumnOp("sometable", "somecol")
         autogenerate.compare._compare_type(
-            self.autogen_context, ac,
-            None, "sometable", "somecol",
+            self.autogen_context,
+            ac,
+            None,
+            "sometable",
+            "somecol",
             Column("somecol", String()),
             Column("somecol", MyType()),
         )
         diff = ac.to_diff_tuple()
-        eq_(
-            diff[0][0:4],
-            ('modify_type', None, 'sometable', 'somecol')
-        )
+        eq_(diff[0][0:4], ("modify_type", None, "sometable", "somecol"))
 
     def test_affinity_typedec(self):
         class MyType(TypeDecorator):
             impl = CHAR
 
             def load_dialect_impl(self, dialect):
-                if dialect.name == 'sqlite':
+                if dialect.name == "sqlite":
                     return dialect.type_descriptor(Integer())
                 else:
                     return dialect.type_descriptor(CHAR(32))
 
-        uo = ops.AlterColumnOp('sometable', 'somecol')
+        uo = ops.AlterColumnOp("sometable", "somecol")
         autogenerate.compare._compare_type(
-            self.autogen_context, uo,
-            None, "sometable", "somecol",
+            self.autogen_context,
+            uo,
+            None,
+            "sometable",
+            "somecol",
             Column("somecol", Integer, nullable=True),
-            Column("somecol", MyType())
+            Column("somecol", MyType()),
         )
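+        # on sqlite, MyType's dialect impl resolves to Integer, so comparing
+        # it against an Integer column reports no change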
         assert not uo.has_changes()
 
     def test_dont_barf_on_already_reflected(self):
         from sqlalchemy.util import OrderedSet
+
         inspector = Inspector.from_engine(self.bind)
         uo = ops.UpgradeOps(ops=[])
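+        # first set: table names already present in the database; second:
+        # names in the model metadata (empty here, so both get dropped)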
         autogenerate.compare._compare_tables(
-            OrderedSet([(None, 'extra'), (None, 'user')]),
-            OrderedSet(), inspector,
-            uo, self.autogen_context
+            OrderedSet([(None, "extra"), (None, "user")]),
+            OrderedSet(),
+            inspector,
+            uo,
+            self.autogen_context,
         )
         eq_(
             [(rec[0], rec[1].name) for rec in uo.as_diffs()],
             [
-                ('remove_table', 'extra'),
-                ('remove_index', 'pw_idx'),
-                ('remove_table', 'user'), ]
+                ("remove_table", "extra"),
+                ("remove_index", "pw_idx"),
+                ("remove_table", "user"),
+            ],
         )
 
 
 class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
     schema = "test_schema"
 
@@ -531,26 +572,21 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
 
         metadata = self.m2
 
-        self._update_context(
-            include_schemas=True,
-        )
+        self._update_context(include_schemas=True)
         uo = ops.UpgradeOps(ops=[])
         autogenerate._produce_net_changes(self.autogen_context, uo)
 
         diffs = uo.as_diffs()
 
-        eq_(
-            diffs[0],
-            ('add_table', metadata.tables['%s.item' % self.schema])
-        )
+        eq_(diffs[0], ("add_table", metadata.tables["%s.item" % self.schema]))
 
-        eq_(diffs[1][0], 'remove_table')
+        eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "extra")
 
         eq_(diffs[2][0], "add_column")
         eq_(diffs[2][1], self.schema)
         eq_(diffs[2][2], "address")
-        eq_(diffs[2][3], metadata.tables['%s.address' % self.schema].c.street)
+        eq_(diffs[2][3], metadata.tables["%s.address" % self.schema].c.street)
 
         eq_(diffs[3][0], "add_constraint")
         eq_(diffs[3][1].name, "uq_email")
@@ -558,7 +594,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
         eq_(diffs[4][0], "add_column")
         eq_(diffs[4][1], self.schema)
         eq_(diffs[4][2], "order")
-        eq_(diffs[4][3], metadata.tables['%s.order' % self.schema].c.user_id)
+        eq_(diffs[4][3], metadata.tables["%s.order" % self.schema].c.user_id)
 
         eq_(diffs[5][0][0], "modify_type")
         eq_(diffs[5][0][1], self.schema)
@@ -568,10 +604,13 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
         eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
 
         self._assert_fk_diff(
-            diffs[6], "add_fk",
-            "order", ["user_id"],
-            "user", ["id"],
-            source_schema=config.test_schema
+            diffs[6],
+            "add_fk",
+            "order",
+            ["user_id"],
+            "user",
+            ["id"],
+            source_schema=config.test_schema,
         )
 
         eq_(diffs[7][0][0], "modify_default")
@@ -580,15 +619,15 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
         eq_(diffs[7][0][3], "a1")
         eq_(diffs[7][0][6].arg, "x")
 
-        eq_(diffs[8][0][0], 'modify_nullable')
+        eq_(diffs[8][0][0], "modify_nullable")
         eq_(diffs[8][0][5], True)
         eq_(diffs[8][0][6], False)
 
-        eq_(diffs[9][0], 'remove_index')
-        eq_(diffs[9][1].name, 'pw_idx')
+        eq_(diffs[9][0], "remove_index")
+        eq_(diffs[9][1].name, "pw_idx")
 
-        eq_(diffs[10][0], 'remove_column')
-        eq_(diffs[10][3].name, 'pw')
+        eq_(diffs[10][0], "remove_column")
+        eq_(diffs[10][3].name, "pw")
 
 
 class CompareTypeSpecificityTest(TestBase):
@@ -597,10 +636,10 @@ class CompareTypeSpecificityTest(TestBase):
         from sqlalchemy.engine import default
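+        # DefaultImpl positional args: dialect, connection, as_sql,
+        # transactional_ddl, output_buffer, context_opts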
 
         return impl.DefaultImpl(
-            default.DefaultDialect(), None, False, True, None, {})
+            default.DefaultDialect(), None, False, True, None, {}
+        )
 
     def test_typedec_to_nonstandard(self):
-
         class PasswordType(TypeDecorator):
             impl = VARBINARY
 
@@ -608,7 +647,7 @@ class CompareTypeSpecificityTest(TestBase):
                 return PasswordType(self.impl.length)
 
             def load_dialect_impl(self, dialect):
-                if dialect.name == 'default':
+                if dialect.name == "default":
                     impl = sqlite.NUMERIC(self.length)
                 else:
                     impl = VARBINARY(self.length)
@@ -616,8 +655,8 @@ class CompareTypeSpecificityTest(TestBase):
 
         impl = self._fixture()
         impl.compare_type(
-            Column('x', sqlite.NUMERIC(50)),
-            Column('x', PasswordType(50)))
+            Column("x", sqlite.NUMERIC(50)), Column("x", PasswordType(50))
+        )
 
     def test_string(self):
         t1 = String(30)
@@ -626,9 +665,9 @@ class CompareTypeSpecificityTest(TestBase):
         t4 = Integer
 
         impl = self._fixture()
-        is_(impl.compare_type(Column('x', t3), Column('x', t1)), False)
-        is_(impl.compare_type(Column('x', t3), Column('x', t2)), True)
-        is_(impl.compare_type(Column('x', t3), Column('x', t4)), True)
+        is_(impl.compare_type(Column("x", t3), Column("x", t1)), False)
+        is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
+        is_(impl.compare_type(Column("x", t3), Column("x", t4)), True)
 
     def test_numeric(self):
         t1 = Numeric(10, 5)
@@ -637,16 +676,16 @@ class CompareTypeSpecificityTest(TestBase):
         t4 = DateTime
 
         impl = self._fixture()
-        is_(impl.compare_type(Column('x', t3), Column('x', t1)), False)
-        is_(impl.compare_type(Column('x', t3), Column('x', t2)), True)
-        is_(impl.compare_type(Column('x', t3), Column('x', t4)), True)
+        is_(impl.compare_type(Column("x", t3), Column("x", t1)), False)
+        is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
+        is_(impl.compare_type(Column("x", t3), Column("x", t4)), True)
 
     def test_numeric_noprecision(self):
         t1 = Numeric()
         t2 = Numeric(scale=5)
 
         impl = self._fixture()
-        is_(impl.compare_type(Column('x', t1), Column('x', t2)), False)
+        is_(impl.compare_type(Column("x", t1), Column("x", t2)), False)
 
     def test_integer(self):
         t1 = Integer()
@@ -657,12 +696,12 @@ class CompareTypeSpecificityTest(TestBase):
         t6 = BigInteger()
 
         impl = self._fixture()
-        is_(impl.compare_type(Column('x', t5), Column('x', t1)), False)
-        is_(impl.compare_type(Column('x', t3), Column('x', t1)), True)
-        is_(impl.compare_type(Column('x', t3), Column('x', t6)), False)
-        is_(impl.compare_type(Column('x', t3), Column('x', t2)), True)
-        is_(impl.compare_type(Column('x', t5), Column('x', t2)), True)
-        is_(impl.compare_type(Column('x', t1), Column('x', t4)), True)
+        is_(impl.compare_type(Column("x", t5), Column("x", t1)), False)
+        is_(impl.compare_type(Column("x", t3), Column("x", t1)), True)
+        is_(impl.compare_type(Column("x", t3), Column("x", t6)), False)
+        is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
+        is_(impl.compare_type(Column("x", t5), Column("x", t2)), True)
+        is_(impl.compare_type(Column("x", t1), Column("x", t4)), True)
 
     def test_datetime(self):
         t1 = DateTime()
@@ -670,22 +709,19 @@ class CompareTypeSpecificityTest(TestBase):
         t3 = DateTime(timezone=True)
 
         impl = self._fixture()
-        is_(impl.compare_type(Column('x', t1), Column('x', t2)), False)
-        is_(impl.compare_type(Column('x', t1), Column('x', t3)), True)
-        is_(impl.compare_type(Column('x', t2), Column('x', t3)), True)
+        is_(impl.compare_type(Column("x", t1), Column("x", t2)), False)
+        is_(impl.compare_type(Column("x", t1), Column("x", t3)), True)
+        is_(impl.compare_type(Column("x", t2), Column("x", t3)), True)
 
 
 class AutogenSystemColTest(AutogenTest, TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
 
     @classmethod
     def _get_db_schema(cls):
         m = MetaData()
 
-        Table(
-            'sometable', m,
-            Column('id', Integer, primary_key=True),
-        )
+        Table("sometable", m, Column("id", Integer, primary_key=True))
         return m
 
     @classmethod
@@ -695,9 +731,10 @@ class AutogenSystemColTest(AutogenTest, TestBase):
         # 'xmin' is implicitly present; when added to the model it should
         # produce no change
         Table(
-            'sometable', m,
-            Column('id', Integer, primary_key=True),
-            Column('xmin', Integer, system=True)
+            "sometable",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("xmin", Integer, system=True),
         )
         return m
 
@@ -715,30 +752,38 @@ class AutogenerateVariantCompareTest(AutogenTest, TestBase):
     # 1.0.13 and lower fail on Postgresql due to variant / bigserial issue
     # #3739
 
-    __requires__ = ('sqlalchemy_1014', )
+    __requires__ = ("sqlalchemy_1014",)
 
     @classmethod
     def _get_db_schema(cls):
         m = MetaData()
 
-        Table('sometable', m,
-              Column(
-                  'id',
-                  BigInteger().with_variant(Integer, "sqlite"),
-                  primary_key=True),
-              Column('value', String(50)))
+        Table(
+            "sometable",
+            m,
+            Column(
+                "id",
+                BigInteger().with_variant(Integer, "sqlite"),
+                primary_key=True,
+            ),
+            Column("value", String(50)),
+        )
         return m
 
     @classmethod
     def _get_model_schema(cls):
         m = MetaData()
 
-        Table('sometable', m,
-              Column(
-                  'id',
-                  BigInteger().with_variant(Integer, "sqlite"),
-                  primary_key=True),
-              Column('value', String(50)))
+        Table(
+            "sometable",
+            m,
+            Column(
+                "id",
+                BigInteger().with_variant(Integer, "sqlite"),
+                primary_key=True,
+            ),
+            Column("value", String(50)),
+        )
         return m
 
     def test_variant_no_issue(self):
@@ -750,24 +795,30 @@ class AutogenerateVariantCompareTest(AutogenTest, TestBase):
 
 
 class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     @classmethod
     def _get_db_schema(cls):
         m = MetaData()
 
-        Table('sometable', m,
-              Column('id', Integer, primary_key=True),
-              Column('value', Integer))
+        Table(
+            "sometable",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("value", Integer),
+        )
         return m
 
     @classmethod
     def _get_model_schema(cls):
         m = MetaData()
 
-        Table('sometable', m,
-              Column('id', Integer, primary_key=True),
-              Column('value', String))
+        Table(
+            "sometable",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("value", String),
+        )
         return m
 
     def test_uses_custom_compare_type_function(self):
@@ -779,15 +830,20 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
         ctx = self.autogen_context
         autogenerate._produce_net_changes(ctx, uo)
 
-        first_table = self.m2.tables['sometable']
-        first_column = first_table.columns['id']
+        first_table = self.m2.tables["sometable"]
+        first_column = first_table.columns["id"]
 
         eq_(len(my_compare_type.mock_calls), 2)
 
         # We'll just test the first call
         _, args, _ = my_compare_type.mock_calls[0]
-        (context, inspected_column, metadata_column,
-         inspected_type, metadata_type) = args
+        (
+            context,
+            inspected_column,
+            metadata_column,
+            inspected_type,
+            metadata_type,
+        ) = args
         eq_(context, self.context)
         eq_(metadata_column, first_column)
         eq_(metadata_type, first_column.type)
@@ -816,8 +872,8 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
         autogenerate._produce_net_changes(ctx, uo)
         diffs = uo.as_diffs()
 
-        eq_(diffs[0][0][0], 'modify_type')
-        eq_(diffs[1][0][0], 'modify_type')
+        eq_(diffs[0][0][0], "modify_type")
+        eq_(diffs[1][0][0], "modify_type")
 
 
 class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase):
@@ -829,10 +885,11 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase):
         m = MetaData()
 
         Table(
-            'person_to_role', m,
-            Column('person_id', Integer, autoincrement=False),
-            Column('role_id', Integer, autoincrement=False),
-            PrimaryKeyConstraint('person_id', 'role_id')
+            "person_to_role",
+            m,
+            Column("person_id", Integer, autoincrement=False),
+            Column("role_id", Integer, autoincrement=False),
+            PrimaryKeyConstraint("person_id", "role_id"),
         )
         return m
 
@@ -849,34 +906,40 @@ class PKConstraintUpgradesIgnoresNullableTest(AutogenTest, TestBase):
 
 
 class AutogenKeyTest(AutogenTest, TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     @classmethod
     def _get_db_schema(cls):
         m = MetaData()
 
-        Table('someothertable', m,
-              Column('id', Integer, primary_key=True),
-              Column('value', Integer, key="somekey"),
-              )
+        Table(
+            "someothertable",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("value", Integer, key="somekey"),
+        )
         return m
 
     @classmethod
     def _get_model_schema(cls):
         m = MetaData()
 
-        Table('sometable', m,
-              Column('id', Integer, primary_key=True),
-              Column('value', Integer, key="someotherkey"),
-              )
-        Table('someothertable', m,
-              Column('id', Integer, primary_key=True),
-              Column('value', Integer, key="somekey"),
-              Column("othervalue", Integer, key="otherkey")
-              )
+        Table(
+            "sometable",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("value", Integer, key="someotherkey"),
+        )
+        Table(
+            "someothertable",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("value", Integer, key="somekey"),
+            Column("othervalue", Integer, key="otherkey"),
+        )
         return m
 
-    symbols = ['someothertable', 'sometable']
+    symbols = ["someothertable", "sometable"]
 
     def test_autogen(self):
 
@@ -892,16 +955,19 @@ class AutogenKeyTest(AutogenTest, TestBase):
 
 
 class AutogenVersionTableTest(AutogenTest, TestBase):
-    __only_on__ = 'sqlite'
-    version_table_name = 'alembic_version'
+    __only_on__ = "sqlite"
+    version_table_name = "alembic_version"
     version_table_schema = None
 
     @classmethod
     def _get_db_schema(cls):
         m = MetaData()
         Table(
-            cls.version_table_name, m,
-            Column('x', Integer), schema=cls.version_table_schema)
+            cls.version_table_name,
+            m,
+            Column("x", Integer),
+            schema=cls.version_table_schema,
+        )
         return m
 
     @classmethod
@@ -919,7 +985,10 @@ class AutogenVersionTableTest(AutogenTest, TestBase):
     def test_version_table_in_target(self):
         Table(
             self.version_table_name,
-            self.m2, Column('x', Integer), schema=self.version_table_schema)
+            self.m2,
+            Column("x", Integer),
+            schema=self.version_table_schema,
+        )
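+        # the version table is expected to be skipped by autogenerate even
+        # when it appears in the target metadata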
 
         ctx = self.autogen_context
         uo = ops.UpgradeOps(ops=[])
@@ -928,29 +997,30 @@ class AutogenVersionTableTest(AutogenTest, TestBase):
 
 
 class AutogenCustomVersionTableSchemaTest(AutogenVersionTableTest):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
-    version_table_schema = 'test_schema'
-    configure_opts = {'version_table_schema': 'test_schema'}
+    version_table_schema = "test_schema"
+    configure_opts = {"version_table_schema": "test_schema"}
 
 
 class AutogenCustomVersionTableTest(AutogenVersionTableTest):
-    version_table_name = 'my_version_table'
-    configure_opts = {'version_table': 'my_version_table'}
+    version_table_name = "my_version_table"
+    configure_opts = {"version_table": "my_version_table"}
 
 
 class AutogenCustomVersionTableAndSchemaTest(AutogenVersionTableTest):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
-    version_table_name = 'my_version_table'
-    version_table_schema = 'test_schema'
+    version_table_name = "my_version_table"
+    version_table_schema = "test_schema"
     configure_opts = {
-        'version_table': 'my_version_table',
-        'version_table_schema': 'test_schema'}
+        "version_table": "my_version_table",
+        "version_table_schema": "test_schema",
+    }
 
 
 class AutogenerateDiffOrderTest(AutogenTest, TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     @classmethod
     def _get_db_schema(cls):
@@ -959,13 +1029,11 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase):
     @classmethod
     def _get_model_schema(cls):
         m = MetaData()
-        Table('parent', m,
-              Column('id', Integer, primary_key=True)
-              )
+        Table("parent", m, Column("id", Integer, primary_key=True))
 
-        Table('child', m,
-              Column('parent_id', Integer, ForeignKey('parent.id')),
-              )
+        Table(
+            "child", m, Column("parent_id", Integer, ForeignKey("parent.id"))
+        )
 
         return m
 
@@ -980,32 +1048,29 @@ class AutogenerateDiffOrderTest(AutogenTest, TestBase):
         autogenerate._produce_net_changes(ctx, uo)
         diffs = uo.as_diffs()
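+        # tables are ordered so "parent" is created before "child", whose
+        # foreign key references parent.id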
 
-        eq_(diffs[0][0], 'add_table')
+        eq_(diffs[0][0], "add_table")
         eq_(diffs[0][1].name, "parent")
-        eq_(diffs[1][0], 'add_table')
+        eq_(diffs[1][0], "add_table")
         eq_(diffs[1][1].name, "child")
 
 
 class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     def test_compare_metadata(self):
         metadata = self.m2
 
         diffs = autogenerate.compare_metadata(self.context, metadata)
 
-        eq_(
-            diffs[0],
-            ('add_table', metadata.tables['item'])
-        )
+        eq_(diffs[0], ("add_table", metadata.tables["item"]))
 
-        eq_(diffs[1][0], 'remove_table')
+        eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "extra")
 
         eq_(diffs[2][0], "add_column")
         eq_(diffs[2][1], None)
         eq_(diffs[2][2], "address")
-        eq_(diffs[2][3], metadata.tables['address'].c.street)
+        eq_(diffs[2][3], metadata.tables["address"].c.street)
 
         eq_(diffs[3][0], "add_constraint")
         eq_(diffs[3][1].name, "uq_email")
@@ -1013,7 +1078,7 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
         eq_(diffs[4][0], "add_column")
         eq_(diffs[4][1], None)
         eq_(diffs[4][2], "order")
-        eq_(diffs[4][3], metadata.tables['order'].c.user_id)
+        eq_(diffs[4][3], metadata.tables["order"].c.user_id)
 
         eq_(diffs[5][0][0], "modify_type")
         eq_(diffs[5][0][1], None)
@@ -1023,9 +1088,7 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
         eq_(repr(diffs[5][0][6]), "Numeric(precision=10, scale=2)")
 
         self._assert_fk_diff(
-            diffs[6], "add_fk",
-            "order", ["user_id"],
-            "user", ["id"]
+            diffs[6], "add_fk", "order", ["user_id"], "user", ["id"]
         )
 
         eq_(diffs[7][0][0], "modify_default")
@@ -1034,15 +1097,15 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
         eq_(diffs[7][0][3], "a1")
         eq_(diffs[7][0][6].arg, "x")
 
-        eq_(diffs[8][0][0], 'modify_nullable')
+        eq_(diffs[8][0][0], "modify_nullable")
         eq_(diffs[8][0][5], True)
         eq_(diffs[8][0][6], False)
 
-        eq_(diffs[9][0], 'remove_index')
-        eq_(diffs[9][1].name, 'pw_idx')
+        eq_(diffs[9][0], "remove_index")
+        eq_(diffs[9][1].name, "pw_idx")
 
-        eq_(diffs[10][0], 'remove_column')
-        eq_(diffs[10][3].name, 'pw')
+        eq_(diffs[10][0], "remove_column")
+        eq_(diffs[10][3].name, "pw")
 
     def test_compare_metadata_include_object(self):
         metadata = self.m2
@@ -1058,46 +1121,46 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
         context = MigrationContext.configure(
             connection=self.bind.connect(),
             opts={
-                'compare_type': True,
-                'compare_server_default': True,
-                'include_object': include_object,
-            }
+                "compare_type": True,
+                "compare_server_default": True,
+                "include_object": include_object,
+            },
         )
 
         diffs = autogenerate.compare_metadata(context, metadata)
 
-        eq_(diffs[0][0], 'remove_table')
+        eq_(diffs[0][0], "remove_table")
         eq_(diffs[0][1].name, "extra")
 
         eq_(diffs[1][0], "add_column")
         eq_(diffs[1][1], None)
         eq_(diffs[1][2], "order")
-        eq_(diffs[1][3], metadata.tables['order'].c.user_id)
+        eq_(diffs[1][3], metadata.tables["order"].c.user_id)
 
     def test_compare_metadata_include_symbol(self):
         metadata = self.m2
 
         def include_symbol(table_name, schema_name):
-            return table_name in ('extra', 'order')
+            return table_name in ("extra", "order")
 
         context = MigrationContext.configure(
             connection=self.bind.connect(),
             opts={
-                'compare_type': True,
-                'compare_server_default': True,
-                'include_symbol': include_symbol,
-            }
+                "compare_type": True,
+                "compare_server_default": True,
+                "include_symbol": include_symbol,
+            },
         )
 
         diffs = autogenerate.compare_metadata(context, metadata)
 
-        eq_(diffs[0][0], 'remove_table')
+        eq_(diffs[0][0], "remove_table")
         eq_(diffs[0][1].name, "extra")
 
         eq_(diffs[1][0], "add_column")
         eq_(diffs[1][1], None)
         eq_(diffs[1][2], "order")
-        eq_(diffs[1][3], metadata.tables['order'].c.user_id)
+        eq_(diffs[1][3], metadata.tables["order"].c.user_id)
 
         eq_(diffs[2][0][0], "modify_type")
         eq_(diffs[2][0][1], None)
@@ -1106,15 +1169,14 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
         eq_(repr(diffs[2][0][5]), "NUMERIC(precision=8, scale=2)")
         eq_(repr(diffs[2][0][6]), "Numeric(precision=10, scale=2)")
 
-        eq_(diffs[2][1][0], 'modify_nullable')
-        eq_(diffs[2][1][2], 'order')
+        eq_(diffs[2][1][0], "modify_nullable")
+        eq_(diffs[2][1][2], "order")
         eq_(diffs[2][1][5], False)
         eq_(diffs[2][1][6], True)
 
     def test_compare_metadata_as_sql(self):
         context = MigrationContext.configure(
-            connection=self.bind.connect(),
-            opts={'as_sql': True}
+            connection=self.bind.connect(), opts={"as_sql": True}
         )
         metadata = self.m2
 
@@ -1122,12 +1184,14 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
             CommandError,
             "autogenerate can't use as_sql=True as it prevents "
             "querying the database for schema information",
-            autogenerate.compare_metadata, context, metadata
+            autogenerate.compare_metadata,
+            context,
+            metadata,
         )
 
 
 class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
     schema = "test_schema"
 
@@ -1135,26 +1199,20 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
         metadata = self.m2
 
         context = MigrationContext.configure(
-            connection=self.bind.connect(),
-            opts={
-                "include_schemas": True
-            }
+            connection=self.bind.connect(), opts={"include_schemas": True}
         )
 
         diffs = autogenerate.compare_metadata(context, metadata)
 
-        eq_(
-            diffs[0],
-            ('add_table', metadata.tables['test_schema.item'])
-        )
+        eq_(diffs[0], ("add_table", metadata.tables["test_schema.item"]))
 
-        eq_(diffs[1][0], 'remove_table')
+        eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "extra")
 
         eq_(diffs[2][0], "add_column")
         eq_(diffs[2][1], "test_schema")
         eq_(diffs[2][2], "address")
-        eq_(diffs[2][3], metadata.tables['test_schema.address'].c.street)
+        eq_(diffs[2][3], metadata.tables["test_schema.address"].c.street)
 
         eq_(diffs[3][0], "add_constraint")
         eq_(diffs[3][1].name, "uq_email")
@@ -1162,27 +1220,25 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
         eq_(diffs[4][0], "add_column")
         eq_(diffs[4][1], "test_schema")
         eq_(diffs[4][2], "order")
-        eq_(diffs[4][3], metadata.tables['test_schema.order'].c.user_id)
+        eq_(diffs[4][3], metadata.tables["test_schema.order"].c.user_id)
 
-        eq_(diffs[5][0][0], 'modify_nullable')
+        eq_(diffs[5][0][0], "modify_nullable")
         eq_(diffs[5][0][5], False)
         eq_(diffs[5][0][6], True)
 
+
 class OrigObjectTest(TestBase):
     def setUp(self):
         self.metadata = m = MetaData()
         t = Table(
-            't', m,
-            Column('id', Integer(), primary_key=True),
-            Column('x', Integer())
-        )
-        self.ix = Index('ix1', t.c.id)
-        fk = ForeignKeyConstraint(['t_id'], ['t.id'])
-        q = Table(
-            'q', m,
-            Column('t_id', Integer()),
-            fk
+            "t",
+            m,
+            Column("id", Integer(), primary_key=True),
+            Column("x", Integer()),
         )
+        self.ix = Index("ix1", t.c.id)
+        fk = ForeignKeyConstraint(["t_id"], ["t.id"])
+        q = Table("q", m, Column("t_id", Integer()), fk)
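+        # keep references to the originals; the tests below expect each op
+        # to hand back these exact objects via to_constraint() / to_column()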
         self.table = t
         self.fk = fk
         self.ck = CheckConstraint(t.c.x > 5)
@@ -1232,10 +1288,10 @@ class OrigObjectTest(TestBase):
         is_not_(None, op.to_constraint().table)
 
     def test_add_pk_no_orig(self):
-        op = ops.CreatePrimaryKeyOp('pk1', 't', ['x', 'y'])
+        op = ops.CreatePrimaryKeyOp("pk1", "t", ["x", "y"])
         pk = op.to_constraint()
-        eq_(pk.name, 'pk1')
-        eq_(pk.table.name, 't')
+        eq_(pk.name, "pk1")
+        eq_(pk.table.name, "t")
 
     def test_add_pk(self):
         pk = self.pk
@@ -1254,7 +1310,7 @@ class OrigObjectTest(TestBase):
     def test_drop_column(self):
         t = self.table
 
-        op = ops.DropColumnOp.from_column_and_tablename(None, 't', t.c.x)
+        op = ops.DropColumnOp.from_column_and_tablename(None, "t", t.c.x)
         is_(op.to_column(), t.c.x)
         is_(op.reverse().to_column(), t.c.x)
         is_not_(None, op.to_column().table)
@@ -1262,7 +1318,7 @@ class OrigObjectTest(TestBase):
     def test_add_column(self):
         t = self.table
 
-        op = ops.AddColumnOp.from_column_and_tablename(None, 't', t.c.x)
+        op = ops.AddColumnOp.from_column_and_tablename(None, "t", t.c.x)
         is_(op.to_column(), t.c.x)
         is_(op.reverse().to_column(), t.c.x)
         is_not_(None, op.to_column().table)
@@ -1304,25 +1360,33 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase):
         m2b = MetaData()
         m2c = MetaData()
 
-        Table('a', m1a, Column('id', Integer, primary_key=True))
-        Table('b1', m1b, Column('id', Integer, primary_key=True))
-        Table('b2', m1b, Column('id', Integer, primary_key=True))
-        Table('c1', m1c, Column('id', Integer, primary_key=True),
-              Column('x', Integer))
+        Table("a", m1a, Column("id", Integer, primary_key=True))
+        Table("b1", m1b, Column("id", Integer, primary_key=True))
+        Table("b2", m1b, Column("id", Integer, primary_key=True))
+        Table(
+            "c1",
+            m1c,
+            Column("id", Integer, primary_key=True),
+            Column("x", Integer),
+        )
 
-        a = Table('a', m2a, Column('id', Integer, primary_key=True),
-                  Column('q', Integer))
-        Table('b1', m2b, Column('id', Integer, primary_key=True))
-        Table('c1', m2c, Column('id', Integer, primary_key=True))
-        c2 = Table('c2', m2c, Column('id', Integer, primary_key=True))
+        a = Table(
+            "a",
+            m2a,
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer),
+        )
+        Table("b1", m2b, Column("id", Integer, primary_key=True))
+        Table("c1", m2c, Column("id", Integer, primary_key=True))
+        c2 = Table("c2", m2c, Column("id", Integer, primary_key=True))
 
         diffs = self._fixture([m1a, m1b, m1c], [m2a, m2b, m2c])
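+        # a sequence of MetaData objects is compared as one merged namespace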
-        eq_(diffs[0], ('add_table', c2))
-        eq_(diffs[1][0], 'remove_table')
-        eq_(diffs[1][1].name, 'b2')
-        eq_(diffs[2], ('add_column', None, 'a', a.c.q))
-        eq_(diffs[3][0:3], ('remove_column', None, 'c1'))
-        eq_(diffs[3][3].name, 'x')
+        eq_(diffs[0], ("add_table", c2))
+        eq_(diffs[1][0], "remove_table")
+        eq_(diffs[1][1].name, "b2")
+        eq_(diffs[2], ("add_column", None, "a", a.c.q))
+        eq_(diffs[3][0:3], ("remove_column", None, "c1"))
+        eq_(diffs[3][3].name, "x")
 
     def test_empty_list(self):
         # because users are going to do it anyway...
@@ -1339,18 +1403,19 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase):
         m2a = MetaData()
         m2b = MetaData()
 
-        Table('a', m1a, Column('id', Integer, primary_key=True))
-        Table('b', m1b, Column('id', Integer, primary_key=True))
+        Table("a", m1a, Column("id", Integer, primary_key=True))
+        Table("b", m1b, Column("id", Integer, primary_key=True))
 
-        Table('a', m2a, Column('id', Integer, primary_key=True))
-        b = Table('b', m2b, Column('id', Integer, primary_key=True),
-                  Column('q', Integer))
+        Table("a", m2a, Column("id", Integer, primary_key=True))
+        b = Table(
+            "b",
+            m2b,
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer),
+        )
 
         diffs = self._fixture((m1a, m1b), (m2a, m2b))
-        eq_(
-            diffs,
-            [('add_column', None, 'b', b.c.q)]
-        )
+        eq_(diffs, [("add_column", None, "b", b.c.q)])
 
     def test_raise_on_dupe(self):
         m1a = MetaData()
@@ -1359,116 +1424,123 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase):
         m2a = MetaData()
         m2b = MetaData()
 
-        Table('a', m1a, Column('id', Integer, primary_key=True))
-        Table('b1', m1b, Column('id', Integer, primary_key=True))
-        Table('b2', m1b, Column('id', Integer, primary_key=True))
-        Table('b3', m1b, Column('id', Integer, primary_key=True))
+        Table("a", m1a, Column("id", Integer, primary_key=True))
+        Table("b1", m1b, Column("id", Integer, primary_key=True))
+        Table("b2", m1b, Column("id", Integer, primary_key=True))
+        Table("b3", m1b, Column("id", Integer, primary_key=True))
 
-        Table('a', m2a, Column('id', Integer, primary_key=True))
-        Table('a', m2b, Column('id', Integer, primary_key=True))
-        Table('b1', m2b, Column('id', Integer, primary_key=True))
-        Table('b2', m2a, Column('id', Integer, primary_key=True))
-        Table('b2', m2b, Column('id', Integer, primary_key=True))
+        Table("a", m2a, Column("id", Integer, primary_key=True))
+        Table("a", m2b, Column("id", Integer, primary_key=True))
+        Table("b1", m2b, Column("id", Integer, primary_key=True))
+        Table("b2", m2a, Column("id", Integer, primary_key=True))
+        Table("b2", m2b, Column("id", Integer, primary_key=True))
 
         assert_raises_message(
             ValueError,
             'Duplicate table keys across multiple MetaData objects: "a", "b2"',
             self._fixture,
-            [m1a, m1b], [m2a, m2b]
+            [m1a, m1b],
+            [m2a, m2b],
         )
 
 
 class AutoincrementTest(AutogenFixtureTest, TestBase):
     __backend__ = True
-    __requires__ = 'integer_subtype_comparisons',
+    __requires__ = ("integer_subtype_comparisons",)
 
     def test_alter_column_autoincrement_none(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('a', m1, Column('x', Integer, nullable=False))
-        Table('a', m2, Column('x', Integer, nullable=True))
+        Table("a", m1, Column("x", Integer, nullable=False))
+        Table("a", m2, Column("x", Integer, nullable=True))
 
         ops = self._fixture(m1, m2, return_ops=True)
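+        # the returned UpgradeOps nests a ModifyTableOps; its first op is
+        # the AlterColumnOp, whose .kw dict carries flags like autoincrement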
-        assert 'autoincrement' not in ops.ops[0].ops[0].kw
+        assert "autoincrement" not in ops.ops[0].ops[0].kw
 
     def test_alter_column_autoincrement_pk_false(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table(
-            'a', m1,
-            Column('x', Integer, primary_key=True, autoincrement=False))
+            "a",
+            m1,
+            Column("x", Integer, primary_key=True, autoincrement=False),
+        )
         Table(
-            'a', m2,
-            Column('x', BigInteger, primary_key=True, autoincrement=False))
+            "a",
+            m2,
+            Column("x", BigInteger, primary_key=True, autoincrement=False),
+        )
 
         ops = self._fixture(m1, m2, return_ops=True)
-        is_(ops.ops[0].ops[0].kw['autoincrement'], False)
+        is_(ops.ops[0].ops[0].kw["autoincrement"], False)
 
     def test_alter_column_autoincrement_pk_implicit_true(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table(
-            'a', m1,
-            Column('x', Integer, primary_key=True))
-        Table(
-            'a', m2,
-            Column('x', BigInteger, primary_key=True))
+        Table("a", m1, Column("x", Integer, primary_key=True))
+        Table("a", m2, Column("x", BigInteger, primary_key=True))
 
         ops = self._fixture(m1, m2, return_ops=True)
-        is_(ops.ops[0].ops[0].kw['autoincrement'], True)
+        is_(ops.ops[0].ops[0].kw["autoincrement"], True)
 
     def test_alter_column_autoincrement_pk_explicit_true(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table(
-            'a', m1,
-            Column('x', Integer, primary_key=True, autoincrement=True))
+            "a", m1, Column("x", Integer, primary_key=True, autoincrement=True)
+        )
         Table(
-            'a', m2,
-            Column('x', BigInteger, primary_key=True, autoincrement=True))
+            "a",
+            m2,
+            Column("x", BigInteger, primary_key=True, autoincrement=True),
+        )
 
         ops = self._fixture(m1, m2, return_ops=True)
-        is_(ops.ops[0].ops[0].kw['autoincrement'], True)
+        is_(ops.ops[0].ops[0].kw["autoincrement"], True)
 
     def test_alter_column_autoincrement_nonpk_false(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table(
-            'a', m1,
-            Column('id', Integer, primary_key=True),
-            Column('x', Integer, autoincrement=False)
+            "a",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("x", Integer, autoincrement=False),
         )
         Table(
-            'a', m2,
-            Column('id', Integer, primary_key=True),
-            Column('x', BigInteger, autoincrement=False)
+            "a",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("x", BigInteger, autoincrement=False),
         )
 
         ops = self._fixture(m1, m2, return_ops=True)
-        is_(ops.ops[0].ops[0].kw['autoincrement'], False)
+        is_(ops.ops[0].ops[0].kw["autoincrement"], False)
 
     def test_alter_column_autoincrement_nonpk_implicit_false(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table(
-            'a', m1,
-            Column('id', Integer, primary_key=True),
-            Column('x', Integer)
+            "a",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("x", Integer),
         )
         Table(
-            'a', m2,
-            Column('id', Integer, primary_key=True),
-            Column('x', BigInteger)
+            "a",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("x", BigInteger),
         )
 
         ops = self._fixture(m1, m2, return_ops=True)
-        assert 'autoincrement' not in ops.ops[0].ops[0].kw
+        assert "autoincrement" not in ops.ops[0].ops[0].kw
 
     @config.requirements.fail_before_sqla_110
     def test_alter_column_autoincrement_nonpk_explicit_true(self):
@@ -1476,54 +1548,60 @@ class AutoincrementTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         Table(
-            'a', m1,
-            Column('id', Integer, primary_key=True),
-            Column('x', Integer, autoincrement=True)
+            "a",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("x", Integer, autoincrement=True),
         )
         Table(
-            'a', m2,
-            Column('id', Integer, primary_key=True),
-            Column('x', BigInteger, autoincrement=True)
+            "a",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("x", BigInteger, autoincrement=True),
         )
 
         ops = self._fixture(m1, m2, return_ops=True)
-        is_(ops.ops[0].ops[0].kw['autoincrement'], True)
+        is_(ops.ops[0].ops[0].kw["autoincrement"], True)
 
     def test_alter_column_autoincrement_compositepk_false(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table(
-            'a', m1,
-            Column('id', Integer, primary_key=True),
-            Column('x', Integer, primary_key=True, autoincrement=False)
+            "a",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("x", Integer, primary_key=True, autoincrement=False),
         )
         Table(
-            'a', m2,
-            Column('id', Integer, primary_key=True),
-            Column('x', BigInteger, primary_key=True, autoincrement=False)
+            "a",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("x", BigInteger, primary_key=True, autoincrement=False),
         )
 
         ops = self._fixture(m1, m2, return_ops=True)
-        is_(ops.ops[0].ops[0].kw['autoincrement'], False)
+        is_(ops.ops[0].ops[0].kw["autoincrement"], False)
 
     def test_alter_column_autoincrement_compositepk_implicit_false(self):
         m1 = MetaData()
         m2 = MetaData()
 
         Table(
-            'a', m1,
-            Column('id', Integer, primary_key=True),
-            Column('x', Integer, primary_key=True)
+            "a",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("x", Integer, primary_key=True),
         )
         Table(
-            'a', m2,
-            Column('id', Integer, primary_key=True),
-            Column('x', BigInteger, primary_key=True)
+            "a",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("x", BigInteger, primary_key=True),
         )
 
         ops = self._fixture(m1, m2, return_ops=True)
-        assert 'autoincrement' not in ops.ops[0].ops[0].kw
+        assert "autoincrement" not in ops.ops[0].ops[0].kw
 
     @config.requirements.autoincrement_on_composite_pk
     def test_alter_column_autoincrement_compositepk_explicit_true(self):
@@ -1531,20 +1609,22 @@ class AutoincrementTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         Table(
-            'a', m1,
-            Column('id', Integer, primary_key=True, autoincrement=False),
-            Column('x', Integer, primary_key=True, autoincrement=True),
+            "a",
+            m1,
+            Column("id", Integer, primary_key=True, autoincrement=False),
+            Column("x", Integer, primary_key=True, autoincrement=True),
             # on SQLA 1.0 and earlier, this being present
             # trips the "add KEY for the primary key" so that the
             # AUTO_INCREMENT keyword is accepted by MySQL.  On SQLA 1.1
             # and greater, the columns are just reorganized.
-            mysql_engine='InnoDB'
+            mysql_engine="InnoDB",
         )
         Table(
-            'a', m2,
-            Column('id', Integer, primary_key=True, autoincrement=False),
-            Column('x', BigInteger, primary_key=True, autoincrement=True)
+            "a",
+            m2,
+            Column("id", Integer, primary_key=True, autoincrement=False),
+            Column("x", BigInteger, primary_key=True, autoincrement=True),
         )
 
         ops = self._fixture(m1, m2, return_ops=True)
-        is_(ops.ops[0].ops[0].kw['autoincrement'], True)
+        is_(ops.ops[0].ops[0].kw["autoincrement"], True)
index 3dd66aea284a26116a9937fa853568a2d4b0b870..66c5ac44bcf1c321dcb125334e737e82ac562bab 100644 (file)
@@ -1,8 +1,14 @@
 import sys
 from alembic.testing import TestBase, config, mock
 
-from sqlalchemy import MetaData, Column, Table, Integer, String, \
-    ForeignKeyConstraint
+from sqlalchemy import (
+    MetaData,
+    Column,
+    Table,
+    Integer,
+    String,
+    ForeignKeyConstraint,
+)
 from alembic.testing import eq_
 
 py3k = sys.version_info.major >= 3
@@ -17,105 +23,141 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('test', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('test2', String(10)),
-              ForeignKeyConstraint(['test2'], ['some_table.test']),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('test', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('test2', String(10)),
-              mysql_engine='InnoDB'
-              )
+        Table(
+            "some_table",
+            m1,
+            Column("test", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("test2", String(10)),
+            ForeignKeyConstraint(["test2"], ["some_table.test"]),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("test", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("test2", String(10)),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
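+        # m2 omits the FK; "servergenerated" lets the fixture accept a
+        # backend-generated name for the dropped constraint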
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            "user", ['test2'],
-            'some_table', ['test'],
-            conditional_name="servergenerated"
+            diffs[0],
+            "remove_fk",
+            "user",
+            ["test2"],
+            "some_table",
+            ["test"],
+            conditional_name="servergenerated",
         )
 
     def test_add_fk(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('id', Integer, primary_key=True),
-              Column('test', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('test2', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('id', Integer, primary_key=True),
-              Column('test', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('test2', String(10)),
-              ForeignKeyConstraint(['test2'], ['some_table.test']),
-              mysql_engine='InnoDB')
+        Table(
+            "some_table",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("test", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("test2", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("test", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("test2", String(10)),
+            ForeignKeyConstraint(["test2"], ["some_table.test"]),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
 
         self._assert_fk_diff(
-            diffs[0], "add_fk",
-            "user", ["test2"],
-            "some_table", ["test"]
+            diffs[0], "add_fk", "user", ["test2"], "some_table", ["test"]
         )
 
     def test_no_change(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('id', Integer, primary_key=True),
-              Column('test', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('test2', Integer),
-              ForeignKeyConstraint(['test2'], ['some_table.id']),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('id', Integer, primary_key=True),
-              Column('test', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('test2', Integer),
-              ForeignKeyConstraint(['test2'], ['some_table.id']),
-              mysql_engine='InnoDB')
+        Table(
+            "some_table",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("test", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("test2", Integer),
+            ForeignKeyConstraint(["test2"], ["some_table.id"]),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("test", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("test2", Integer),
+            ForeignKeyConstraint(["test2"], ["some_table.id"]),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
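
A note on the mechanical pattern running through these hunks: black first tries to fit a call within the 79-column limit on one line; when it cannot, it puts one argument per line and appends a trailing "magic" comma so the exploded layout stays stable on re-runs. A small illustration, reusing call shapes from the tests above rather than introducing anything new:

    # fits in 79 columns, so black keeps the arguments together
    self._assert_fk_diff(
        diffs[0], "add_fk", "user", ["test2"], "some_table", ["test"]
    )

    # would overflow 79 columns, so black explodes it -- one argument per
    # line, plus a trailing comma that makes a second black run a no-op
    self._assert_fk_diff(
        diffs[0],
        "remove_fk",
        "user",
        ["test2"],
        "some_table",
        ["test"],
        conditional_name="servergenerated",
    )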
 
@@ -125,36 +167,51 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('id_1', String(10), primary_key=True),
-              Column('id_2', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)),
-              ForeignKeyConstraint(['other_id_1', 'other_id_2'],
-                                   ['some_table.id_1', 'some_table.id_2']),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('id_1', String(10), primary_key=True),
-              Column('id_2', String(10), primary_key=True),
-              mysql_engine='InnoDB'
-              )
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)),
-              ForeignKeyConstraint(['other_id_1', 'other_id_2'],
-                                   ['some_table.id_1', 'some_table.id_2']),
-              mysql_engine='InnoDB')
+        Table(
+            "some_table",
+            m1,
+            Column("id_1", String(10), primary_key=True),
+            Column("id_2", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("other_id_1", String(10)),
+            Column("other_id_2", String(10)),
+            ForeignKeyConstraint(
+                ["other_id_1", "other_id_2"],
+                ["some_table.id_1", "some_table.id_2"],
+            ),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("id_1", String(10), primary_key=True),
+            Column("id_2", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("other_id_1", String(10)),
+            Column("other_id_2", String(10)),
+            ForeignKeyConstraint(
+                ["other_id_1", "other_id_2"],
+                ["some_table.id_1", "some_table.id_2"],
+            ),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
 
@@ -164,42 +221,59 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('id_1', String(10), primary_key=True),
-              Column('id_2', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('id_1', String(10), primary_key=True),
-              Column('id_2', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)),
-              ForeignKeyConstraint(['other_id_1', 'other_id_2'],
-                                   ['some_table.id_1', 'some_table.id_2'],
-                                   name='fk_test_name'),
-              mysql_engine='InnoDB')
+        Table(
+            "some_table",
+            m1,
+            Column("id_1", String(10), primary_key=True),
+            Column("id_2", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("other_id_1", String(10)),
+            Column("other_id_2", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("id_1", String(10), primary_key=True),
+            Column("id_2", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("other_id_1", String(10)),
+            Column("other_id_2", String(10)),
+            ForeignKeyConstraint(
+                ["other_id_1", "other_id_2"],
+                ["some_table.id_1", "some_table.id_2"],
+                name="fk_test_name",
+            ),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
 
         self._assert_fk_diff(
-            diffs[0], "add_fk",
-            "user", ['other_id_1', 'other_id_2'],
-            'some_table', ['id_1', 'id_2'],
-            name="fk_test_name"
+            diffs[0],
+            "add_fk",
+            "user",
+            ["other_id_1", "other_id_2"],
+            "some_table",
+            ["id_1", "id_2"],
+            name="fk_test_name",
         )
 
     @config.requirements.no_name_normalize
@@ -207,111 +281,160 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('id_1', String(10), primary_key=True),
-              Column('id_2', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)),
-              ForeignKeyConstraint(['other_id_1', 'other_id_2'],
-                                   ['some_table.id_1', 'some_table.id_2'],
-                                   name='fk_test_name'),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('id_1', String(10), primary_key=True),
-              Column('id_2', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)),
-              mysql_engine='InnoDB')
+        Table(
+            "some_table",
+            m1,
+            Column("id_1", String(10), primary_key=True),
+            Column("id_2", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("other_id_1", String(10)),
+            Column("other_id_2", String(10)),
+            ForeignKeyConstraint(
+                ["other_id_1", "other_id_2"],
+                ["some_table.id_1", "some_table.id_2"],
+                name="fk_test_name",
+            ),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("id_1", String(10), primary_key=True),
+            Column("id_2", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            Column("other_id_1", String(10)),
+            Column("other_id_2", String(10)),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            "user", ['other_id_1', 'other_id_2'],
-            "some_table", ['id_1', 'id_2'],
-            conditional_name="fk_test_name"
+            diffs[0],
+            "remove_fk",
+            "user",
+            ["other_id_1", "other_id_2"],
+            "some_table",
+            ["id_1", "id_2"],
+            conditional_name="fk_test_name",
         )
 
     def test_add_fk_colkeys(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('id_1', String(10), primary_key=True),
-              Column('id_2', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('id_1', String(10), key='tid1', primary_key=True),
-              Column('id_2', String(10), key='tid2', primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('other_id_1', String(10), key='oid1'),
-              Column('other_id_2', String(10), key='oid2'),
-              ForeignKeyConstraint(['oid1', 'oid2'],
-                                   ['some_table.tid1', 'some_table.tid2'],
-                                   name='fk_test_name'),
-              mysql_engine='InnoDB')
+        Table(
+            "some_table",
+            m1,
+            Column("id_1", String(10), primary_key=True),
+            Column("id_2", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("other_id_1", String(10)),
+            Column("other_id_2", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("id_1", String(10), key="tid1", primary_key=True),
+            Column("id_2", String(10), key="tid2", primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("other_id_1", String(10), key="oid1"),
+            Column("other_id_2", String(10), key="oid2"),
+            ForeignKeyConstraint(
+                ["oid1", "oid2"],
+                ["some_table.tid1", "some_table.tid2"],
+                name="fk_test_name",
+            ),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
 
         self._assert_fk_diff(
-            diffs[0], "add_fk",
-            "user", ['other_id_1', 'other_id_2'],
-            'some_table', ['id_1', 'id_2'],
-            name="fk_test_name"
+            diffs[0],
+            "add_fk",
+            "user",
+            ["other_id_1", "other_id_2"],
+            "some_table",
+            ["id_1", "id_2"],
+            name="fk_test_name",
         )
 
     def test_no_change_colkeys(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('id_1', String(10), primary_key=True),
-              Column('id_2', String(10), primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('other_id_1', String(10)),
-              Column('other_id_2', String(10)),
-              ForeignKeyConstraint(['other_id_1', 'other_id_2'],
-                                   ['some_table.id_1', 'some_table.id_2']),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('id_1', String(10), key='tid1', primary_key=True),
-              Column('id_2', String(10), key='tid2', primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('other_id_1', String(10), key='oid1'),
-              Column('other_id_2', String(10), key='oid2'),
-              ForeignKeyConstraint(['oid1', 'oid2'],
-                                   ['some_table.tid1', 'some_table.tid2']),
-              mysql_engine='InnoDB')
+        Table(
+            "some_table",
+            m1,
+            Column("id_1", String(10), primary_key=True),
+            Column("id_2", String(10), primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("other_id_1", String(10)),
+            Column("other_id_2", String(10)),
+            ForeignKeyConstraint(
+                ["other_id_1", "other_id_2"],
+                ["some_table.id_1", "some_table.id_2"],
+            ),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("id_1", String(10), key="tid1", primary_key=True),
+            Column("id_2", String(10), key="tid2", primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("other_id_1", String(10), key="oid1"),
+            Column("other_id_2", String(10), key="oid2"),
+            ForeignKeyConstraint(
+                ["oid1", "oid2"], ["some_table.tid1", "some_table.tid2"]
+            ),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
 
@@ -320,7 +443,7 @@ class AutogenerateForeignKeysTest(AutogenFixtureTest, TestBase):
 
 class IncludeHooksTest(AutogenFixtureTest, TestBase):
     __backend__ = True
-    __requires__ = 'fk_names',
+    __requires__ = ("fk_names",)
 
     @config.requirements.no_name_normalize
     def test_remove_connection_fk(self):
@@ -328,11 +451,18 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         ref = Table(
-            'ref', m1, Column('id', Integer, primary_key=True),
-            mysql_engine='InnoDB')
+            "ref",
+            m1,
+            Column("id", Integer, primary_key=True),
+            mysql_engine="InnoDB",
+        )
         t1 = Table(
-            't', m1, Column('x', Integer), Column('y', Integer),
-            mysql_engine='InnoDB')
+            "t",
+            m1,
+            Column("x", Integer),
+            Column("y", Integer),
+            mysql_engine="InnoDB",
+        )
         t1.append_constraint(
             ForeignKeyConstraint([t1.c.x], [ref.c.id], name="fk1")
         )
@@ -341,24 +471,37 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         )
 
         ref = Table(
-            'ref', m2, Column('id', Integer, primary_key=True),
-            mysql_engine='InnoDB')
+            "ref",
+            m2,
+            Column("id", Integer, primary_key=True),
+            mysql_engine="InnoDB",
+        )
         Table(
-            't', m2, Column('x', Integer), Column('y', Integer),
-            mysql_engine='InnoDB')
+            "t",
+            m2,
+            Column("x", Integer),
+            Column("y", Integer),
+            mysql_engine="InnoDB",
+        )
 
         def include_object(object_, name, type_, reflected, compare_to):
             return not (
-                isinstance(object_, ForeignKeyConstraint) and
-                type_ == 'foreign_key_constraint'
-                and reflected and name == 'fk1')
+                isinstance(object_, ForeignKeyConstraint)
+                and type_ == "foreign_key_constraint"
+                and reflected
+                and name == "fk1"
+            )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            't', ['y'], 'ref', ['id'],
-            conditional_name='fk2'
+            diffs[0],
+            "remove_fk",
+            "t",
+            ["y"],
+            "ref",
+            ["id"],
+            conditional_name="fk2",
         )
         eq_(len(diffs), 1)
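
The include_object callable these tests feed through object_filters is the same hook Alembic exposes to applications. A minimal sketch of the env.py wiring, assuming the usual connection/target_metadata boilerplate that is not part of this diff:

    from alembic import context

    def include_object(object_, name, type_, reflected, compare_to):
        # skip a known reflected FK so autogenerate never proposes
        # dropping it; mirrors the filter in test_remove_connection_fk
        return not (
            type_ == "foreign_key_constraint"
            and reflected
            and name == "fk1"
        )

    context.configure(
        connection=connection,            # assumed: connectable.connect()
        target_metadata=target_metadata,  # assumed: application MetaData
        include_object=include_object,
    )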
 
@@ -367,18 +510,32 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         Table(
-            'ref', m1,
-            Column('id', Integer, primary_key=True), mysql_engine='InnoDB')
+            "ref",
+            m1,
+            Column("id", Integer, primary_key=True),
+            mysql_engine="InnoDB",
+        )
         Table(
-            't', m1,
-            Column('x', Integer), Column('y', Integer), mysql_engine='InnoDB')
+            "t",
+            m1,
+            Column("x", Integer),
+            Column("y", Integer),
+            mysql_engine="InnoDB",
+        )
 
         ref = Table(
-            'ref', m2, Column('id', Integer, primary_key=True),
-            mysql_engine='InnoDB')
+            "ref",
+            m2,
+            Column("id", Integer, primary_key=True),
+            mysql_engine="InnoDB",
+        )
         t2 = Table(
-            't', m2, Column('x', Integer), Column('y', Integer),
-            mysql_engine='InnoDB')
+            "t",
+            m2,
+            Column("x", Integer),
+            Column("y", Integer),
+            mysql_engine="InnoDB",
+        )
         t2.append_constraint(
             ForeignKeyConstraint([t2.c.x], [ref.c.id], name="fk1")
         )
@@ -388,16 +545,16 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
 
         def include_object(object_, name, type_, reflected, compare_to):
             return not (
-                isinstance(object_, ForeignKeyConstraint) and
-                type_ == 'foreign_key_constraint'
-                and not reflected and name == 'fk1')
+                isinstance(object_, ForeignKeyConstraint)
+                and type_ == "foreign_key_constraint"
+                and not reflected
+                and name == "fk1"
+            )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
         self._assert_fk_diff(
-            diffs[0], "add_fk",
-            't', ['y'], 'ref', ['id'],
-            name='fk2'
+            diffs[0], "add_fk", "t", ["y"], "ref", ["id"], name="fk2"
         )
         eq_(len(diffs), 1)
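
A flake8 aside on the boolean rewrites above: black breaks lines before binary operators, where the old code broke after them. pycodestyle has a warning for each style (W503 for break-before, W504 for break-after), and both sit in flake8's default ignore list, so a stock flake8 run accepts either -- including black's output here. Side by side, with the filter reduced to two clauses:

    from sqlalchemy import ForeignKeyConstraint

    def include_object_old(object_, name, type_, reflected, compare_to):
        # removed style: the operator trails the line (W504 territory)
        return not (
            isinstance(object_, ForeignKeyConstraint) and
            type_ == "foreign_key_constraint"
        )

    def include_object_black(object_, name, type_, reflected, compare_to):
        # black's style: the operator leads the continuation (W503)
        return not (
            isinstance(object_, ForeignKeyConstraint)
            and type_ == "foreign_key_constraint"
        )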
 
@@ -407,20 +564,26 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         r1a = Table(
-            'ref_a', m1,
-            Column('a', Integer, primary_key=True),
-            mysql_engine='InnoDB'
+            "ref_a",
+            m1,
+            Column("a", Integer, primary_key=True),
+            mysql_engine="InnoDB",
         )
         Table(
-            'ref_b', m1,
-            Column('a', Integer, primary_key=True),
-            Column('b', Integer, primary_key=True),
-            mysql_engine='InnoDB'
+            "ref_b",
+            m1,
+            Column("a", Integer, primary_key=True),
+            Column("b", Integer, primary_key=True),
+            mysql_engine="InnoDB",
         )
         t1 = Table(
-            't', m1, Column('x', Integer),
-            Column('y', Integer), Column('z', Integer),
-            mysql_engine='InnoDB')
+            "t",
+            m1,
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("z", Integer),
+            mysql_engine="InnoDB",
+        )
         t1.append_constraint(
             ForeignKeyConstraint([t1.c.x], [r1a.c.a], name="fk1")
         )
@@ -429,82 +592,104 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         )
 
         Table(
-            'ref_a', m2,
-            Column('a', Integer, primary_key=True),
-            mysql_engine='InnoDB'
+            "ref_a",
+            m2,
+            Column("a", Integer, primary_key=True),
+            mysql_engine="InnoDB",
         )
         r2b = Table(
-            'ref_b', m2,
-            Column('a', Integer, primary_key=True),
-            Column('b', Integer, primary_key=True),
-            mysql_engine='InnoDB'
+            "ref_b",
+            m2,
+            Column("a", Integer, primary_key=True),
+            Column("b", Integer, primary_key=True),
+            mysql_engine="InnoDB",
         )
         t2 = Table(
-            't', m2, Column('x', Integer),
-            Column('y', Integer), Column('z', Integer),
-            mysql_engine='InnoDB')
+            "t",
+            m2,
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("z", Integer),
+            mysql_engine="InnoDB",
+        )
         t2.append_constraint(
             ForeignKeyConstraint(
-                [t2.c.x, t2.c.z], [r2b.c.a, r2b.c.b], name="fk1")
+                [t2.c.x, t2.c.z], [r2b.c.a, r2b.c.b], name="fk1"
+            )
         )
         t2.append_constraint(
             ForeignKeyConstraint(
-                [t2.c.y, t2.c.z], [r2b.c.a, r2b.c.b], name="fk2")
+                [t2.c.y, t2.c.z], [r2b.c.a, r2b.c.b], name="fk2"
+            )
         )
 
         def include_object(object_, name, type_, reflected, compare_to):
             return not (
-                isinstance(object_, ForeignKeyConstraint) and
-                type_ == 'foreign_key_constraint'
-                and name == 'fk1'
+                isinstance(object_, ForeignKeyConstraint)
+                and type_ == "foreign_key_constraint"
+                and name == "fk1"
             )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            't', ['y'], 'ref_a', ['a'],
-            name='fk2'
+            diffs[0], "remove_fk", "t", ["y"], "ref_a", ["a"], name="fk2"
         )
         self._assert_fk_diff(
-            diffs[1], "add_fk",
-            't', ['y', 'z'], 'ref_b', ['a', 'b'],
-            name='fk2'
+            diffs[1],
+            "add_fk",
+            "t",
+            ["y", "z"],
+            "ref_b",
+            ["a", "b"],
+            name="fk2",
         )
         eq_(len(diffs), 2)
 
 
 class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
     __backend__ = True
-    __requires__ = ('flexible_fk_cascades', )
+    __requires__ = ("flexible_fk_cascades",)
 
     def _fk_opts_fixture(self, old_opts, new_opts):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('some_table', m1,
-              Column('id', Integer, primary_key=True),
-              Column('test', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('tid', Integer),
-              ForeignKeyConstraint(['tid'], ['some_table.id'], **old_opts),
-              mysql_engine='InnoDB')
-
-        Table('some_table', m2,
-              Column('id', Integer, primary_key=True),
-              Column('test', String(10)),
-              mysql_engine='InnoDB')
-
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('tid', Integer),
-              ForeignKeyConstraint(['tid'], ['some_table.id'], **new_opts),
-              mysql_engine='InnoDB')
+        Table(
+            "some_table",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("test", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("tid", Integer),
+            ForeignKeyConstraint(["tid"], ["some_table.id"], **old_opts),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "some_table",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("test", String(10)),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("tid", Integer),
+            ForeignKeyConstraint(["tid"], ["some_table.id"], **new_opts),
+            mysql_engine="InnoDB",
+        )
 
         return self._fixture(m1, m2)
 
@@ -521,47 +706,55 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
         return True
 
     def test_add_ondelete(self):
-        diffs = self._fk_opts_fixture(
-            {}, {"ondelete": "cascade"}
-        )
+        diffs = self._fk_opts_fixture({}, {"ondelete": "cascade"})
 
         if self._expect_opts_supported():
             self._assert_fk_diff(
-                diffs[0], "remove_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[0],
+                "remove_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 ondelete=None,
-                conditional_name="servergenerated"
+                conditional_name="servergenerated",
             )
 
             self._assert_fk_diff(
-                diffs[1], "add_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
-                ondelete="cascade"
+                diffs[1],
+                "add_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
+                ondelete="cascade",
             )
         else:
             eq_(diffs, [])
 
     def test_remove_ondelete(self):
-        diffs = self._fk_opts_fixture(
-            {"ondelete": "CASCADE"}, {}
-        )
+        diffs = self._fk_opts_fixture({"ondelete": "CASCADE"}, {})
 
         if self._expect_opts_supported():
             self._assert_fk_diff(
-                diffs[0], "remove_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[0],
+                "remove_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 ondelete="CASCADE",
-                conditional_name="servergenerated"
+                conditional_name="servergenerated",
             )
 
             self._assert_fk_diff(
-                diffs[1], "add_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
-                ondelete=None
+                diffs[1],
+                "add_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
+                ondelete=None,
             )
         else:
             eq_(diffs, [])
@@ -574,47 +767,55 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
         eq_(diffs, [])
 
     def test_add_onupdate(self):
-        diffs = self._fk_opts_fixture(
-            {}, {"onupdate": "cascade"}
-        )
+        diffs = self._fk_opts_fixture({}, {"onupdate": "cascade"})
 
         if self._expect_opts_supported():
             self._assert_fk_diff(
-                diffs[0], "remove_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[0],
+                "remove_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 onupdate=None,
-                conditional_name="servergenerated"
+                conditional_name="servergenerated",
             )
 
             self._assert_fk_diff(
-                diffs[1], "add_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
-                onupdate="cascade"
+                diffs[1],
+                "add_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
+                onupdate="cascade",
             )
         else:
             eq_(diffs, [])
 
     def test_remove_onupdate(self):
-        diffs = self._fk_opts_fixture(
-            {"onupdate": "CASCADE"}, {}
-        )
+        diffs = self._fk_opts_fixture({"onupdate": "CASCADE"}, {})
 
         if self._expect_opts_supported():
             self._assert_fk_diff(
-                diffs[0], "remove_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[0],
+                "remove_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 onupdate="CASCADE",
-                conditional_name="servergenerated"
+                conditional_name="servergenerated",
             )
 
             self._assert_fk_diff(
-                diffs[1], "add_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
-                onupdate=None
+                diffs[1],
+                "add_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
+                onupdate=None,
             )
         else:
             eq_(diffs, [])
@@ -668,20 +869,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
         )
         if self._expect_opts_supported():
             self._assert_fk_diff(
-                diffs[0], "remove_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[0],
+                "remove_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 onupdate=None,
                 ondelete=mock.ANY,  # MySQL reports None, PG reports RESTRICT
-                conditional_name="servergenerated"
+                conditional_name="servergenerated",
             )
 
             self._assert_fk_diff(
-                diffs[1], "add_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[1],
+                "add_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 onupdate=None,
-                ondelete="cascade"
+                ondelete="cascade",
             )
         else:
             eq_(diffs, [])
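
The mock.ANY sentinel in the expected diffs above compares equal to anything, which is how one assertion covers MySQL (reflects the implicit referential action as None) and PostgreSQL (reflects it as RESTRICT) at once:

    from unittest import mock

    # ANY.__eq__ always returns True, so either backend's reflected
    # value satisfies the same expected tuple
    assert mock.ANY == None  # noqa: E711
    assert mock.ANY == "RESTRICT"
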
@@ -696,20 +903,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
         )
         if self._expect_opts_supported():
             self._assert_fk_diff(
-                diffs[0], "remove_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[0],
+                "remove_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 onupdate=mock.ANY,  # MySQL reports None, PG reports RESTRICT
                 ondelete=None,
-                conditional_name="servergenerated"
+                conditional_name="servergenerated",
             )
 
             self._assert_fk_diff(
-                diffs[1], "add_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[1],
+                "add_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 onupdate="cascade",
-                ondelete=None
+                ondelete=None,
             )
         else:
             eq_(diffs, [])
@@ -717,70 +930,84 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
     def test_ondelete_onupdate_combo(self):
         diffs = self._fk_opts_fixture(
             {"onupdate": "CASCADE", "ondelete": "SET NULL"},
-            {"onupdate": "RESTRICT", "ondelete": "RESTRICT"}
+            {"onupdate": "RESTRICT", "ondelete": "RESTRICT"},
         )
 
         if self._expect_opts_supported():
             self._assert_fk_diff(
-                diffs[0], "remove_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[0],
+                "remove_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 onupdate="CASCADE",
                 ondelete="SET NULL",
-                conditional_name="servergenerated"
+                conditional_name="servergenerated",
             )
 
             self._assert_fk_diff(
-                diffs[1], "add_fk",
-                "user", ["tid"],
-                "some_table", ["id"],
+                diffs[1],
+                "add_fk",
+                "user",
+                ["tid"],
+                "some_table",
+                ["id"],
                 onupdate="RESTRICT",
-                ondelete="RESTRICT"
+                ondelete="RESTRICT",
             )
         else:
             eq_(diffs, [])
 
     @config.requirements.fk_initially
     def test_add_initially_deferred(self):
-        diffs = self._fk_opts_fixture(
-            {}, {"initially": "deferred"}
-        )
+        diffs = self._fk_opts_fixture({}, {"initially": "deferred"})
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
+            diffs[0],
+            "remove_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
             initially=None,
-            conditional_name="servergenerated"
+            conditional_name="servergenerated",
         )
 
         self._assert_fk_diff(
-            diffs[1], "add_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
-            initially="deferred"
+            diffs[1],
+            "add_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
+            initially="deferred",
         )
 
     @config.requirements.fk_initially
     def test_remove_initially_deferred(self):
-        diffs = self._fk_opts_fixture(
-            {"initially": "deferred"}, {}
-        )
+        diffs = self._fk_opts_fixture({"initially": "deferred"}, {})
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
+            diffs[0],
+            "remove_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
             initially="DEFERRED",
             deferrable=True,
-            conditional_name="servergenerated"
+            conditional_name="servergenerated",
         )
 
         self._assert_fk_diff(
-            diffs[1], "add_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
-            initially=None
+            diffs[1],
+            "add_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
+            initially=None,
         )
 
     @config.requirements.fk_deferrable
@@ -791,19 +1018,25 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
         )
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
+            diffs[0],
+            "remove_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
             initially=None,
-            conditional_name="servergenerated"
+            conditional_name="servergenerated",
         )
 
         self._assert_fk_diff(
-            diffs[1], "add_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
+            diffs[1],
+            "add_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
             initially="immediate",
-            deferrable=True
+            deferrable=True,
         )
 
     @config.requirements.fk_deferrable
@@ -814,20 +1047,26 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
         )
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
+            diffs[0],
+            "remove_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
             initially=None,  # immediate is the default
             deferrable=True,
-            conditional_name="servergenerated"
+            conditional_name="servergenerated",
         )
 
         self._assert_fk_diff(
-            diffs[1], "add_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
+            diffs[1],
+            "add_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
             initially=None,
-            deferrable=None
+            deferrable=None,
         )
 
     @config.requirements.fk_initially
@@ -835,7 +1074,7 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
     def test_add_initially_deferrable_nochange_one(self):
         diffs = self._fk_opts_fixture(
             {"deferrable": True, "initially": "immediate"},
-            {"deferrable": True, "initially": "immediate"}
+            {"deferrable": True, "initially": "immediate"},
         )
 
         eq_(diffs, [])
@@ -845,7 +1084,7 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
     def test_add_initially_deferrable_nochange_two(self):
         diffs = self._fk_opts_fixture(
             {"deferrable": True, "initially": "deferred"},
-            {"deferrable": True, "initially": "deferred"}
+            {"deferrable": True, "initially": "deferred"},
         )
 
         eq_(diffs, [])
@@ -855,49 +1094,57 @@ class AutogenerateFKOptionsTest(AutogenFixtureTest, TestBase):
     def test_add_initially_deferrable_nochange_three(self):
         diffs = self._fk_opts_fixture(
             {"deferrable": None, "initially": "deferred"},
-            {"deferrable": None, "initially": "deferred"}
+            {"deferrable": None, "initially": "deferred"},
         )
 
         eq_(diffs, [])
 
     @config.requirements.fk_deferrable
     def test_add_deferrable(self):
-        diffs = self._fk_opts_fixture(
-            {}, {"deferrable": True}
-        )
+        diffs = self._fk_opts_fixture({}, {"deferrable": True})
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
+            diffs[0],
+            "remove_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
             deferrable=None,
-            conditional_name="servergenerated"
+            conditional_name="servergenerated",
         )
 
         self._assert_fk_diff(
-            diffs[1], "add_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
-            deferrable=True
+            diffs[1],
+            "add_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
+            deferrable=True,
         )
 
     @config.requirements.fk_deferrable
     def test_remove_deferrable(self):
-        diffs = self._fk_opts_fixture(
-            {"deferrable": True}, {}
-        )
+        diffs = self._fk_opts_fixture({"deferrable": True}, {})
 
         self._assert_fk_diff(
-            diffs[0], "remove_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
+            diffs[0],
+            "remove_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
             deferrable=True,
-            conditional_name="servergenerated"
+            conditional_name="servergenerated",
         )
 
         self._assert_fk_diff(
-            diffs[1], "add_fk",
-            "user", ["tid"],
-            "some_table", ["id"],
-            deferrable=None
+            diffs[1],
+            "add_fk",
+            "user",
+            ["tid"],
+            "some_table",
+            ["id"],
+            deferrable=None,
         )
diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py
index b588cbeb333e3ecb96ff7d978d54e425ecd1c609..f03155fb1be1d17dbc179d5d1716b0e76b787044 100644 (file)
@@ -3,14 +3,24 @@ from alembic.testing import TestBase
 from alembic.testing import config
 from alembic.testing import assertions
 
-from sqlalchemy import MetaData, Column, Table, Integer, String, \
-    Numeric, UniqueConstraint, Index, ForeignKeyConstraint,\
-    ForeignKey, func
+from sqlalchemy import (
+    MetaData,
+    Column,
+    Table,
+    Integer,
+    String,
+    Numeric,
+    UniqueConstraint,
+    Index,
+    ForeignKeyConstraint,
+    ForeignKey,
+    func,
+)
 from alembic.testing import engines
 from alembic.testing import eq_
 from alembic.testing.env import staging_env
 
-py3k = sys.version_info >= (3, )
+py3k = sys.version_info >= (3,)
 
 from ._autogen_fixtures import AutogenFixtureTest
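
The (3, ) -> (3,) and ("fk_names",) style rewrites scattered through this commit change spelling only: the trailing comma, not the parentheses, is what makes a one-element tuple; black merely normalizes the spacing and makes the parentheses explicit.

    import sys

    py3k_old = sys.version_info >= (3, )  # pre-black spelling
    py3k_new = sys.version_info >= (3,)   # black's spelling, same tuple
    assert py3k_old == py3k_new
    assert ("fk_names",) == ("fk_names", )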
 
@@ -24,6 +34,7 @@ class NoUqReflection(object):
 
         def unimpl(*arg, **kw):
             raise NotImplementedError()
+
         eng.dialect.get_unique_constraints = unimpl
 
     def test_add_ix_on_table_create(self):
@@ -37,25 +48,29 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     reports_unique_constraints = True
     reports_unique_constraints_as_indexes = False
 
-    __requires__ = ('unique_constraint_reflection', )
-    __only_on__ = 'sqlite'
+    __requires__ = ("unique_constraint_reflection",)
+    __only_on__ = "sqlite"
 
     def test_index_flag_becomes_named_unique_constraint(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('user', m1,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False, index=True),
-              Column('a1', String(10), server_default="x")
-              )
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False, index=True),
+            Column("a1", String(10), server_default="x"),
+        )
 
-        Table('user', m2,
-              Column('id', Integer, primary_key=True),
-              Column('name', String(50), nullable=False),
-              Column('a1', String(10), server_default="x"),
-              UniqueConstraint("name", name="uq_user_name")
-              )
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), nullable=False),
+            Column("a1", String(10), server_default="x"),
+            UniqueConstraint("name", name="uq_user_name"),
+        )
 
         diffs = self._fixture(m1, m2)
 
@@ -72,17 +87,21 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_add_unique_constraint(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('address', m1,
-              Column('id', Integer, primary_key=True),
-              Column('email_address', String(100), nullable=False),
-              Column('qpr', String(10), index=True),
-              )
-        Table('address', m2,
-              Column('id', Integer, primary_key=True),
-              Column('email_address', String(100), nullable=False),
-              Column('qpr', String(10), index=True),
-              UniqueConstraint("email_address", name="uq_email_address")
-              )
+        Table(
+            "address",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("email_address", String(100), nullable=False),
+            Column("qpr", String(10), index=True),
+        )
+        Table(
+            "address",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("email_address", String(100), nullable=False),
+            Column("qpr", String(10), index=True),
+            UniqueConstraint("email_address", name="uq_email_address"),
+        )
 
         diffs = self._fixture(m1, m2)
 
@@ -96,17 +115,21 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('unq_idx', m1,
-              Column('id', Integer, primary_key=True),
-              Column('x', String(20)),
-              Index('x', 'x', unique=True)
-              )
+        Table(
+            "unq_idx",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("x", String(20)),
+            Index("x", "x", unique=True),
+        )
 
-        Table('unq_idx', m2,
-              Column('id', Integer, primary_key=True),
-              Column('x', String(20)),
-              Index('x', 'x', unique=True)
-              )
+        Table(
+            "unq_idx",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("x", String(20)),
+            Index("x", "x", unique=True),
+        )
 
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
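
All of these eq_(diffs, []) checks funnel through autogenerate's comparison API; stripped of the fixture machinery, the underlying call looks roughly like this (a standalone sketch against an in-memory SQLite database, not code from this diff):

    from sqlalchemy import MetaData, create_engine
    from alembic.autogenerate import compare_metadata
    from alembic.migration import MigrationContext

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        mc = MigrationContext.configure(conn)
        # an empty MetaData versus an empty database: no differences
        assert compare_metadata(mc, MetaData()) == []
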
@@ -114,27 +137,31 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_index_becomes_unique(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('order', m1,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True),
-              Column('user_id', Integer),
-              UniqueConstraint('order_id', 'user_id',
-                               name='order_order_id_user_id_unique'
-                               ),
-              Index('order_user_id_amount_idx', 'user_id', 'amount')
-              )
-
-        Table('order', m2,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True),
-              Column('user_id', Integer),
-              UniqueConstraint('order_id', 'user_id',
-                               name='order_order_id_user_id_unique'
-                               ),
-              Index(
-                  'order_user_id_amount_idx', 'user_id',
-                  'amount', unique=True),
-              )
+        Table(
+            "order",
+            m1,
+            Column("order_id", Integer, primary_key=True),
+            Column("amount", Numeric(10, 2), nullable=True),
+            Column("user_id", Integer),
+            UniqueConstraint(
+                "order_id", "user_id", name="order_order_id_user_id_unique"
+            ),
+            Index("order_user_id_amount_idx", "user_id", "amount"),
+        )
+
+        Table(
+            "order",
+            m2,
+            Column("order_id", Integer, primary_key=True),
+            Column("amount", Numeric(10, 2), nullable=True),
+            Column("user_id", Integer),
+            UniqueConstraint(
+                "order_id", "user_id", name="order_order_id_user_id_unique"
+            ),
+            Index(
+                "order_user_id_amount_idx", "user_id", "amount", unique=True
+            ),
+        )
 
         diffs = self._fixture(m1, m2)
         eq_(diffs[0][0], "remove_index")
@@ -148,16 +175,16 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_mismatch_db_named_col_flag(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('item', m1,
-              Column('x', Integer),
-              UniqueConstraint('x', name="db_generated_name")
-              )
+        Table(
+            "item",
+            m1,
+            Column("x", Integer),
+            UniqueConstraint("x", name="db_generated_name"),
+        )
 
         # test mismatch between unique=True and
         # named uq constraint
-        Table('item', m2,
-              Column('x', Integer, unique=True)
-              )
+        Table("item", m2, Column("x", Integer, unique=True))
 
         diffs = self._fixture(m1, m2)
 
@@ -166,11 +193,13 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_new_table_added(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('extra', m2,
-              Column('foo', Integer, index=True),
-              Column('bar', Integer),
-              Index('newtable_idx', 'bar')
-              )
+        Table(
+            "extra",
+            m2,
+            Column("foo", Integer, index=True),
+            Column("bar", Integer),
+            Index("newtable_idx", "bar"),
+        )
 
         diffs = self._fixture(m1, m2)
 
@@ -185,16 +214,20 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_named_cols_changed(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('col_change', m1,
-              Column('x', Integer),
-              Column('y', Integer),
-              UniqueConstraint('x', name="nochange")
-              )
-        Table('col_change', m2,
-              Column('x', Integer),
-              Column('y', Integer),
-              UniqueConstraint('x', 'y', name="nochange")
-              )
+        Table(
+            "col_change",
+            m1,
+            Column("x", Integer),
+            Column("y", Integer),
+            UniqueConstraint("x", name="nochange"),
+        )
+        Table(
+            "col_change",
+            m2,
+            Column("x", Integer),
+            Column("y", Integer),
+            UniqueConstraint("x", "y", name="nochange"),
+        )
 
         diffs = self._fixture(m1, m2)
 
@@ -211,13 +244,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('nothing_changed', m1,
-              Column('x', String(20), unique=True, index=True)
-              )
+        Table(
+            "nothing_changed",
+            m1,
+            Column("x", String(20), unique=True, index=True),
+        )
 
-        Table('nothing_changed', m2,
-              Column('x', String(20), unique=True, index=True)
-              )
+        Table(
+            "nothing_changed",
+            m2,
+            Column("x", String(20), unique=True, index=True),
+        )
 
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
@@ -226,35 +263,43 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('nothing_changed', m1,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              Column('x', String(20), unique=True),
-              mysql_engine='InnoDB'
-              )
-        Table('nothing_changed_related', m1,
-              Column('id1', Integer),
-              Column('id2', Integer),
-              ForeignKeyConstraint(
-                  ['id1', 'id2'],
-                  ['nothing_changed.id1', 'nothing_changed.id2']),
-              mysql_engine='InnoDB'
-              )
-
-        Table('nothing_changed', m2,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              Column('x', String(20), unique=True),
-              mysql_engine='InnoDB'
-              )
-        Table('nothing_changed_related', m2,
-              Column('id1', Integer),
-              Column('id2', Integer),
-              ForeignKeyConstraint(
-                  ['id1', 'id2'],
-                  ['nothing_changed.id1', 'nothing_changed.id2']),
-              mysql_engine='InnoDB'
-              )
+        Table(
+            "nothing_changed",
+            m1,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            Column("x", String(20), unique=True),
+            mysql_engine="InnoDB",
+        )
+        Table(
+            "nothing_changed_related",
+            m1,
+            Column("id1", Integer),
+            Column("id2", Integer),
+            ForeignKeyConstraint(
+                ["id1", "id2"], ["nothing_changed.id1", "nothing_changed.id2"]
+            ),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "nothing_changed",
+            m2,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            Column("x", String(20), unique=True),
+            mysql_engine="InnoDB",
+        )
+        Table(
+            "nothing_changed_related",
+            m2,
+            Column("id1", Integer),
+            Column("id2", Integer),
+            ForeignKeyConstraint(
+                ["id1", "id2"], ["nothing_changed.id1", "nothing_changed.id2"]
+            ),
+            mysql_engine="InnoDB",
+        )
 
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
@@ -263,15 +308,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('nothing_changed', m1,
-              Column('x', String(20), key='nx'),
-              UniqueConstraint('nx')
-              )
+        Table(
+            "nothing_changed",
+            m1,
+            Column("x", String(20), key="nx"),
+            UniqueConstraint("nx"),
+        )
 
-        Table('nothing_changed', m2,
-              Column('x', String(20), key='nx'),
-              UniqueConstraint('nx')
-              )
+        Table(
+            "nothing_changed",
+            m2,
+            Column("x", String(20), key="nx"),
+            UniqueConstraint("nx"),
+        )
 
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
@@ -280,15 +329,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('nothing_changed', m1,
-              Column('x', String(20), key='nx'),
-              Index('foobar', 'nx')
-              )
+        Table(
+            "nothing_changed",
+            m1,
+            Column("x", String(20), key="nx"),
+            Index("foobar", "nx"),
+        )
 
-        Table('nothing_changed', m2,
-              Column('x', String(20), key='nx'),
-              Index('foobar', 'nx')
-              )
+        Table(
+            "nothing_changed",
+            m2,
+            Column("x", String(20), key="nx"),
+            Index("foobar", "nx"),
+        )
 
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
@@ -297,19 +350,23 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('nothing_changed', m1,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              Column('x', String(20)),
-              Index('x', 'x')
-              )
+        Table(
+            "nothing_changed",
+            m1,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            Column("x", String(20)),
+            Index("x", "x"),
+        )
 
-        Table('nothing_changed', m2,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              Column('x', String(20)),
-              Index('x', 'x')
-              )
+        Table(
+            "nothing_changed",
+            m2,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            Column("x", String(20)),
+            Index("x", "x"),
+        )
 
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
@@ -318,29 +375,43 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table("nothing_changed", m1,
-              Column('id', Integer, primary_key=True),
-              Column('other_id',
-                     ForeignKey('nc2.id',
-                                name='fk_my_table_other_table'
-                                ),
-                     nullable=False),
-              Column('foo', Integer),
-              mysql_engine='InnoDB')
-        Table('nc2', m1,
-              Column('id', Integer, primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table("nothing_changed", m2,
-              Column('id', Integer, primary_key=True),
-              Column('other_id', ForeignKey('nc2.id',
-                                            name='fk_my_table_other_table'),
-                     nullable=False),
-              Column('foo', Integer),
-              mysql_engine='InnoDB')
-        Table('nc2', m2,
-              Column('id', Integer, primary_key=True),
-              mysql_engine='InnoDB')
+        Table(
+            "nothing_changed",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "other_id",
+                ForeignKey("nc2.id", name="fk_my_table_other_table"),
+                nullable=False,
+            ),
+            Column("foo", Integer),
+            mysql_engine="InnoDB",
+        )
+        Table(
+            "nc2",
+            m1,
+            Column("id", Integer, primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "nothing_changed",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "other_id",
+                ForeignKey("nc2.id", name="fk_my_table_other_table"),
+                nullable=False,
+            ),
+            Column("foo", Integer),
+            mysql_engine="InnoDB",
+        )
+        Table(
+            "nc2",
+            m2,
+            Column("id", Integer, primary_key=True),
+            mysql_engine="InnoDB",
+        )
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
 
@@ -348,35 +419,49 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table("nothing_changed", m1,
-              Column('id', Integer, primary_key=True),
-              Column('other_id_1', Integer),
-              Column('other_id_2', Integer),
-              Column('foo', Integer),
-              ForeignKeyConstraint(
-                  ['other_id_1', 'other_id_2'], ['nc2.id1', 'nc2.id2'],
-                  name='fk_my_table_other_table'
-              ),
-              mysql_engine='InnoDB')
-        Table('nc2', m1,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              mysql_engine='InnoDB')
-
-        Table("nothing_changed", m2,
-              Column('id', Integer, primary_key=True),
-              Column('other_id_1', Integer),
-              Column('other_id_2', Integer),
-              Column('foo', Integer),
-              ForeignKeyConstraint(
-                  ['other_id_1', 'other_id_2'], ['nc2.id1', 'nc2.id2'],
-                  name='fk_my_table_other_table'
-              ),
-              mysql_engine='InnoDB')
-        Table('nc2', m2,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              mysql_engine='InnoDB')
+        Table(
+            "nothing_changed",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("other_id_1", Integer),
+            Column("other_id_2", Integer),
+            Column("foo", Integer),
+            ForeignKeyConstraint(
+                ["other_id_1", "other_id_2"],
+                ["nc2.id1", "nc2.id2"],
+                name="fk_my_table_other_table",
+            ),
+            mysql_engine="InnoDB",
+        )
+        Table(
+            "nc2",
+            m1,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            mysql_engine="InnoDB",
+        )
+
+        Table(
+            "nothing_changed",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("other_id_1", Integer),
+            Column("other_id_2", Integer),
+            Column("foo", Integer),
+            ForeignKeyConstraint(
+                ["other_id_1", "other_id_2"],
+                ["nc2.id1", "nc2.id2"],
+                name="fk_my_table_other_table",
+            ),
+            mysql_engine="InnoDB",
+        )
+        Table(
+            "nc2",
+            m2,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            mysql_engine="InnoDB",
+        )
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
 
@@ -384,65 +469,73 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('new_idx', m1,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              Column('x', String(20)),
-              )
+        Table(
+            "new_idx",
+            m1,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            Column("x", String(20)),
+        )
 
-        idx = Index('x', 'x')
-        Table('new_idx', m2,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              Column('x', String(20)),
-              idx
-              )
+        idx = Index("x", "x")
+        Table(
+            "new_idx",
+            m2,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            Column("x", String(20)),
+            idx,
+        )
 
         diffs = self._fixture(m1, m2)
-        eq_(diffs, [('add_index', idx)])
+        eq_(diffs, [("add_index", idx)])
 
     def test_removed_idx_index_named_as_column(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        idx = Index('x', 'x')
-        Table('new_idx', m1,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              Column('x', String(20)),
-              idx
-              )
+        idx = Index("x", "x")
+        Table(
+            "new_idx",
+            m1,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            Column("x", String(20)),
+            idx,
+        )
 
-        Table('new_idx', m2,
-              Column('id1', Integer, primary_key=True),
-              Column('id2', Integer, primary_key=True),
-              Column('x', String(20))
-              )
+        Table(
+            "new_idx",
+            m2,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True),
+            Column("x", String(20)),
+        )
 
         diffs = self._fixture(m1, m2)
-        eq_(diffs[0][0], 'remove_index')
+        eq_(diffs[0][0], "remove_index")
 
     def test_drop_table_w_indexes(self):
         m1 = MetaData()
         m2 = MetaData()
 
         t = Table(
-            'some_table', m1,
-            Column('id', Integer, primary_key=True),
-            Column('x', String(20)),
-            Column('y', String(20)),
+            "some_table",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("x", String(20)),
+            Column("y", String(20)),
         )
-        Index('xy_idx', t.c.x, t.c.y)
-        Index('y_idx', t.c.y)
+        Index("xy_idx", t.c.x, t.c.y)
+        Index("y_idx", t.c.y)
 
         diffs = self._fixture(m1, m2)
-        eq_(diffs[0][0], 'remove_index')
-        eq_(diffs[1][0], 'remove_index')
-        eq_(diffs[2][0], 'remove_table')
+        eq_(diffs[0][0], "remove_index")
+        eq_(diffs[1][0], "remove_index")
+        eq_(diffs[2][0], "remove_table")
 
         eq_(
-            set([diffs[0][1].name, diffs[1][1].name]),
-            set(['xy_idx', 'y_idx'])
+            set([diffs[0][1].name, diffs[1][1].name]), set(["xy_idx", "y_idx"])
         )
 
     # this simply doesn't fully work before we had
@@ -453,11 +546,12 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         Table(
-            'some_table', m1,
-            Column('id', Integer, primary_key=True),
-            Column('x', String(20)),
-            Column('y', String(20)),
-            UniqueConstraint('y', name='uq_y')
+            "some_table",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("x", String(20)),
+            Column("y", String(20)),
+            UniqueConstraint("y", name="uq_y"),
         )
 
         diffs = self._fixture(m1, m2)
@@ -465,65 +559,80 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         if self.reports_unique_constraints_as_indexes:
             # for MySQL this UQ will look like an index, so
             # make sure it at least sets it up correctly
-            eq_(diffs[0][0], 'remove_index')
-            eq_(diffs[1][0], 'remove_table')
+            eq_(diffs[0][0], "remove_index")
+            eq_(diffs[1][0], "remove_table")
             eq_(len(diffs), 2)
 
-            constraints = [c for c in diffs[1][1].constraints
-                           if isinstance(c, UniqueConstraint)]
+            constraints = [
+                c
+                for c in diffs[1][1].constraints
+                if isinstance(c, UniqueConstraint)
+            ]
             eq_(len(constraints), 0)
         else:
-            eq_(diffs[0][0], 'remove_table')
+            eq_(diffs[0][0], "remove_table")
             eq_(len(diffs), 1)
 
-            constraints = [c for c in diffs[0][1].constraints
-                           if isinstance(c, UniqueConstraint)]
+            constraints = [
+                c
+                for c in diffs[0][1].constraints
+                if isinstance(c, UniqueConstraint)
+            ]
             if self.reports_unique_constraints:
                 eq_(len(constraints), 1)
 
     def test_unnamed_cols_changed(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('col_change', m1,
-              Column('x', Integer),
-              Column('y', Integer),
-              UniqueConstraint('x')
-              )
-        Table('col_change', m2,
-              Column('x', Integer),
-              Column('y', Integer),
-              UniqueConstraint('x', 'y')
-              )
+        Table(
+            "col_change",
+            m1,
+            Column("x", Integer),
+            Column("y", Integer),
+            UniqueConstraint("x"),
+        )
+        Table(
+            "col_change",
+            m2,
+            Column("x", Integer),
+            Column("y", Integer),
+            UniqueConstraint("x", "y"),
+        )
 
         diffs = self._fixture(m1, m2)
 
-        diffs = set((cmd,
-                     ('x' in obj.name) if obj.name is not None else False)
-                    for cmd, obj in diffs)
+        diffs = set(
+            (cmd, ("x" in obj.name) if obj.name is not None else False)
+            for cmd, obj in diffs
+        )
         if self.reports_unnamed_constraints:
             if self.reports_unique_constraints_as_indexes:
                 eq_(
                     diffs,
-                    set([("remove_index", True), ("add_constraint", False)])
+                    set([("remove_index", True), ("add_constraint", False)]),
                 )
             else:
                 eq_(
                     diffs,
-                    set([("remove_constraint", True),
-                         ("add_constraint", False)])
+                    set(
+                        [
+                            ("remove_constraint", True),
+                            ("add_constraint", False),
+                        ]
+                    ),
                 )
 
     def test_remove_named_unique_index(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('remove_idx', m1,
-              Column('x', Integer),
-              Index('xidx', 'x', unique=True)
-              )
-        Table('remove_idx', m2,
-              Column('x', Integer)
-              )
+        Table(
+            "remove_idx",
+            m1,
+            Column("x", Integer),
+            Index("xidx", "x", unique=True),
+        )
+        Table("remove_idx", m2, Column("x", Integer))
 
         diffs = self._fixture(m1, m2)
 
@@ -537,13 +646,13 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('remove_idx', m1,
-              Column('x', Integer),
-              UniqueConstraint('x', name='xidx')
-              )
-        Table('remove_idx', m2,
-              Column('x', Integer),
-              )
+        Table(
+            "remove_idx",
+            m1,
+            Column("x", Integer),
+            UniqueConstraint("x", name="xidx"),
+        )
+        Table("remove_idx", m2, Column("x", Integer))
 
         diffs = self._fixture(m1, m2)
 
@@ -559,46 +668,49 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_dont_add_uq_on_table_create(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('no_uq', m2, Column('x', String(50), unique=True))
+        Table("no_uq", m2, Column("x", String(50), unique=True))
         diffs = self._fixture(m1, m2)
 
         eq_(diffs[0][0], "add_table")
         eq_(len(diffs), 1)
         assert UniqueConstraint in set(
-            type(c) for c in diffs[0][1].constraints)
+            type(c) for c in diffs[0][1].constraints
+        )
 
     def test_add_uq_ix_on_table_create(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('add_ix', m2, Column('x', String(50), unique=True, index=True))
+        Table("add_ix", m2, Column("x", String(50), unique=True, index=True))
         diffs = self._fixture(m1, m2)
 
         eq_(diffs[0][0], "add_table")
         eq_(len(diffs), 2)
         assert UniqueConstraint not in set(
-            type(c) for c in diffs[0][1].constraints)
+            type(c) for c in diffs[0][1].constraints
+        )
         eq_(diffs[1][0], "add_index")
         eq_(diffs[1][1].unique, True)
 
     def test_add_ix_on_table_create(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('add_ix', m2, Column('x', String(50), index=True))
+        Table("add_ix", m2, Column("x", String(50), index=True))
         diffs = self._fixture(m1, m2)
 
         eq_(diffs[0][0], "add_table")
         eq_(len(diffs), 2)
         assert UniqueConstraint not in set(
-            type(c) for c in diffs[0][1].constraints)
+            type(c) for c in diffs[0][1].constraints
+        )
         eq_(diffs[1][0], "add_index")
         eq_(diffs[1][1].unique, False)
 
     def test_add_idx_non_col(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('add_ix', m1, Column('x', String(50)))
-        t2 = Table('add_ix', m2, Column('x', String(50)))
-        Index('foo_idx', t2.c.x.desc())
+        Table("add_ix", m1, Column("x", String(50)))
+        t2 = Table("add_ix", m2, Column("x", String(50)))
+        Index("foo_idx", t2.c.x.desc())
         diffs = self._fixture(m1, m2)
 
         eq_(diffs[0][0], "add_index")
@@ -606,10 +718,10 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_unchanged_idx_non_col(self):
         m1 = MetaData()
         m2 = MetaData()
-        t1 = Table('add_ix', m1, Column('x', String(50)))
-        Index('foo_idx', t1.c.x.desc())
-        t2 = Table('add_ix', m2, Column('x', String(50)))
-        Index('foo_idx', t2.c.x.desc())
+        t1 = Table("add_ix", m1, Column("x", String(50)))
+        Index("foo_idx", t1.c.x.desc())
+        t2 = Table("add_ix", m2, Column("x", String(50)))
+        Index("foo_idx", t2.c.x.desc())
         diffs = self._fixture(m1, m2)
 
         eq_(diffs, [])
@@ -622,8 +734,8 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_unchanged_case_sensitive_implicit_idx(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('add_ix', m1, Column('regNumber', String(50), index=True))
-        Table('add_ix', m2, Column('regNumber', String(50), index=True))
+        Table("add_ix", m1, Column("regNumber", String(50), index=True))
+        Table("add_ix", m2, Column("regNumber", String(50), index=True))
         diffs = self._fixture(m1, m2)
 
         eq_(diffs, [])
@@ -631,10 +743,10 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     def test_unchanged_case_sensitive_explicit_idx(self):
         m1 = MetaData()
         m2 = MetaData()
-        t1 = Table('add_ix', m1, Column('reg_number', String(50)))
-        Index('regNumber_idx', t1.c.reg_number)
-        t2 = Table('add_ix', m2, Column('reg_number', String(50)))
-        Index('regNumber_idx', t2.c.reg_number)
+        t1 = Table("add_ix", m1, Column("reg_number", String(50)))
+        Index("regNumber_idx", t1.c.reg_number)
+        t2 = Table("add_ix", m2, Column("reg_number", String(50)))
+        Index("regNumber_idx", t2.c.reg_number)
 
         diffs = self._fixture(m1, m2)
 
@@ -649,21 +761,36 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
     def test_idx_added_schema(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('add_ix', m1, Column('x', String(50)), schema="test_schema")
-        Table('add_ix', m2, Column('x', String(50)),
-              Index('ix_1', 'x'), schema="test_schema")
+        Table("add_ix", m1, Column("x", String(50)), schema="test_schema")
+        Table(
+            "add_ix",
+            m2,
+            Column("x", String(50)),
+            Index("ix_1", "x"),
+            schema="test_schema",
+        )
 
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(diffs[0][0], "add_index")
-        eq_(diffs[0][1].name, 'ix_1')
+        eq_(diffs[0][1].name, "ix_1")
 
     def test_idx_unchanged_schema(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('add_ix', m1, Column('x', String(50)), Index('ix_1', 'x'),
-              schema="test_schema")
-        Table('add_ix', m2, Column('x', String(50)),
-              Index('ix_1', 'x'), schema="test_schema")
+        Table(
+            "add_ix",
+            m1,
+            Column("x", String(50)),
+            Index("ix_1", "x"),
+            schema="test_schema",
+        )
+        Table(
+            "add_ix",
+            m2,
+            Column("x", String(50)),
+            Index("ix_1", "x"),
+            schema="test_schema",
+        )
 
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(diffs, [])
@@ -671,23 +798,36 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
     def test_uq_added_schema(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('add_uq', m1, Column('x', String(50)), schema="test_schema")
-        Table('add_uq', m2, Column('x', String(50)),
-              UniqueConstraint('x', name='ix_1'), schema="test_schema")
+        Table("add_uq", m1, Column("x", String(50)), schema="test_schema")
+        Table(
+            "add_uq",
+            m2,
+            Column("x", String(50)),
+            UniqueConstraint("x", name="ix_1"),
+            schema="test_schema",
+        )
 
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(diffs[0][0], "add_constraint")
-        eq_(diffs[0][1].name, 'ix_1')
+        eq_(diffs[0][1].name, "ix_1")
 
     def test_uq_unchanged_schema(self):
         m1 = MetaData()
         m2 = MetaData()
-        Table('add_uq', m1, Column('x', String(50)),
-              UniqueConstraint('x', name='ix_1'),
-              schema="test_schema")
-        Table('add_uq', m2, Column('x', String(50)),
-              UniqueConstraint('x', name='ix_1'),
-              schema="test_schema")
+        Table(
+            "add_uq",
+            m1,
+            Column("x", String(50)),
+            UniqueConstraint("x", name="ix_1"),
+            schema="test_schema",
+        )
+        Table(
+            "add_uq",
+            m2,
+            Column("x", String(50)),
+            UniqueConstraint("x", name="ix_1"),
+            schema="test_schema",
+        )
 
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(diffs, [])
@@ -701,17 +841,19 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
         m2 = MetaData()
 
         Table(
-            'add_excl', m1,
-            Column('id', Integer, primary_key=True),
-            Column('period', TSRANGE),
-            ExcludeConstraint(('period', '&&'), name='quarters_period_excl')
+            "add_excl",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("period", TSRANGE),
+            ExcludeConstraint(("period", "&&"), name="quarters_period_excl"),
         )
 
         Table(
-            'add_excl', m2,
-            Column('id', Integer, primary_key=True),
-            Column('period', TSRANGE),
-            ExcludeConstraint(('period', '&&'), name='quarters_period_excl')
+            "add_excl",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("period", TSRANGE),
+            ExcludeConstraint(("period", "&&"), name="quarters_period_excl"),
         )
 
         diffs = self._fixture(m1, m2)
@@ -721,10 +863,10 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('add_ix', m1, Column('x', String(50)), Index('ix_1', 'x'))
+        Table("add_ix", m1, Column("x", String(50)), Index("ix_1", "x"))
 
-        Table('add_ix', m2, Column('x', String(50)), Index('ix_1', 'x'))
-        Table('add_ix', m2, Column('x', String(50)), schema="test_schema")
+        Table("add_ix", m2, Column("x", String(50)), Index("ix_1", "x"))
+        Table("add_ix", m2, Column("x", String(50)), schema="test_schema")
 
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(diffs[0][0], "add_table")
@@ -734,15 +876,17 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
         m1 = MetaData()
         m2 = MetaData()
         Table(
-            'add_uq', m1,
-            Column('id', Integer, primary_key=True),
-            Column('name', String),
-            UniqueConstraint('name', name='uq_name')
+            "add_uq",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+            UniqueConstraint("name", name="uq_name"),
         )
         Table(
-            'add_uq', m2,
-            Column('id', Integer, primary_key=True),
-            Column('name', String),
+            "add_uq",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
         )
         diffs = self._fixture(m1, m2, include_schemas=True)
         eq_(diffs[0][0], "remove_constraint")
@@ -754,22 +898,24 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
         m2 = MetaData()
 
         t1 = Table(
-            'foo', m1,
-            Column('id', Integer, primary_key=True),
-            Column('email', String(50))
+            "foo",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("email", String(50)),
         )
         Index("email_idx", func.lower(t1.c.email), unique=True)
 
         t2 = Table(
-            'foo', m2,
-            Column('id', Integer, primary_key=True),
-            Column('email', String(50))
+            "foo",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("email", String(50)),
         )
         Index("email_idx", func.lower(t2.c.email), unique=True)
 
         with assertions.expect_warnings(
-                "Skipped unsupported reflection",
-                "autogenerate skipping functional index"
+            "Skipped unsupported reflection",
+            "autogenerate skipping functional index",
         ):
             diffs = self._fixture(m1, m2)
         eq_(diffs, [])
@@ -779,28 +925,34 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
         m2 = MetaData()
 
         t1 = Table(
-            'foo', m1,
-            Column('id', Integer, primary_key=True),
-            Column('email', String(50)),
-            Column('name', String(50))
+            "foo",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("email", String(50)),
+            Column("name", String(50)),
         )
         Index(
             "email_idx",
-            func.coalesce(t1.c.email, t1.c.name).desc(), unique=True)
+            func.coalesce(t1.c.email, t1.c.name).desc(),
+            unique=True,
+        )
 
         t2 = Table(
-            'foo', m2,
-            Column('id', Integer, primary_key=True),
-            Column('email', String(50)),
-            Column('name', String(50))
+            "foo",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("email", String(50)),
+            Column("name", String(50)),
         )
         Index(
             "email_idx",
-            func.coalesce(t2.c.email, t2.c.name).desc(), unique=True)
+            func.coalesce(t2.c.email, t2.c.name).desc(),
+            unique=True,
+        )
 
         with assertions.expect_warnings(
-                "Skipped unsupported reflection",
-                "autogenerate skipping functional index"
+            "Skipped unsupported reflection",
+            "autogenerate skipping functional index",
         ):
             diffs = self._fixture(m1, m2)
         eq_(diffs, [])
@@ -809,13 +961,14 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
 class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest):
     reports_unnamed_constraints = True
     reports_unique_constraints_as_indexes = True
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True
 
     def test_removed_idx_index_named_as_column(self):
         try:
-            super(MySQLUniqueIndexTest,
-                  self).test_removed_idx_index_named_as_column()
+            super(
+                MySQLUniqueIndexTest, self
+            ).test_removed_idx_index_named_as_column()
         except IndexError:
             assert True
         else:
@@ -828,61 +981,70 @@ class OracleUniqueIndexTest(AutogenerateUniqueIndexTest):
     __only_on__ = "oracle"
     __backend__ = True
 
+
 class NoUqReflectionIndexTest(NoUqReflection, AutogenerateUniqueIndexTest):
     reports_unique_constraints = False
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     def test_unique_not_reported(self):
         m1 = MetaData()
-        Table('order', m1,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True),
-              Column('user_id', Integer),
-              UniqueConstraint('order_id', 'user_id',
-                               name='order_order_id_user_id_unique'
-                               )
-              )
+        Table(
+            "order",
+            m1,
+            Column("order_id", Integer, primary_key=True),
+            Column("amount", Numeric(10, 2), nullable=True),
+            Column("user_id", Integer),
+            UniqueConstraint(
+                "order_id", "user_id", name="order_order_id_user_id_unique"
+            ),
+        )
 
         diffs = self._fixture(m1, m1)
         eq_(diffs, [])
 
     def test_remove_unique_index_not_reported(self):
         m1 = MetaData()
-        Table('order', m1,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True),
-              Column('user_id', Integer),
-              Index('oid_ix', 'order_id', 'user_id',
-                    unique=True
-                    )
-              )
+        Table(
+            "order",
+            m1,
+            Column("order_id", Integer, primary_key=True),
+            Column("amount", Numeric(10, 2), nullable=True),
+            Column("user_id", Integer),
+            Index("oid_ix", "order_id", "user_id", unique=True),
+        )
         m2 = MetaData()
-        Table('order', m2,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True),
-              Column('user_id', Integer),
-              )
+        Table(
+            "order",
+            m2,
+            Column("order_id", Integer, primary_key=True),
+            Column("amount", Numeric(10, 2), nullable=True),
+            Column("user_id", Integer),
+        )
 
         diffs = self._fixture(m1, m2)
         eq_(diffs, [])
 
     def test_remove_plain_index_is_reported(self):
         m1 = MetaData()
-        Table('order', m1,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True),
-              Column('user_id', Integer),
-              Index('oid_ix', 'order_id', 'user_id')
-              )
+        Table(
+            "order",
+            m1,
+            Column("order_id", Integer, primary_key=True),
+            Column("amount", Numeric(10, 2), nullable=True),
+            Column("user_id", Integer),
+            Index("oid_ix", "order_id", "user_id"),
+        )
         m2 = MetaData()
-        Table('order', m2,
-              Column('order_id', Integer, primary_key=True),
-              Column('amount', Numeric(10, 2), nullable=True),
-              Column('user_id', Integer),
-              )
+        Table(
+            "order",
+            m2,
+            Column("order_id", Integer, primary_key=True),
+            Column("amount", Numeric(10, 2), nullable=True),
+            Column("user_id", Integer),
+        )
 
         diffs = self._fixture(m1, m2)
-        eq_(diffs[0][0], 'remove_index')
+        eq_(diffs[0][0], "remove_index")
 
 
 class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
@@ -899,7 +1061,7 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
 
     """
 
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     @classmethod
     def _get_bind(cls):
@@ -916,7 +1078,7 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
             for uq in _get_unique_constraints(
                 self, connection, tablename, **kw
             ):
-                uq['unique'] = True
+                uq["unique"] = True
                 indexes.append(uq)
             return indexes
 
@@ -932,23 +1094,26 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        t1 = Table('t', m1, Column('x', Integer), Column('y', Integer))
-        Index('ix1', t1.c.x)
-        Index('ix2', t1.c.y)
+        t1 = Table("t", m1, Column("x", Integer), Column("y", Integer))
+        Index("ix1", t1.c.x)
+        Index("ix2", t1.c.y)
 
-        Table('t', m2, Column('x', Integer), Column('y', Integer))
+        Table("t", m2, Column("x", Integer), Column("y", Integer))
 
         def include_object(object_, name, type_, reflected, compare_to):
-            if type_ == 'unique_constraint':
+            if type_ == "unique_constraint":
                 return False
             return not (
-                isinstance(object_, Index) and
-                type_ == 'index' and reflected and name == 'ix1')
+                isinstance(object_, Index)
+                and type_ == "index"
+                and reflected
+                and name == "ix1"
+            )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
-        eq_(diffs[0][0], 'remove_index')
-        eq_(diffs[0][1].name, 'ix2')
+        eq_(diffs[0][0], "remove_index")
+        eq_(diffs[0][1].name, "ix2")
         eq_(len(diffs), 1)
 
     @config.requirements.unique_constraint_reflection
@@ -958,45 +1123,54 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         Table(
-            't', m1, Column('x', Integer), Column('y', Integer),
-            UniqueConstraint('x', name='uq1'),
-            UniqueConstraint('y', name='uq2'),
+            "t",
+            m1,
+            Column("x", Integer),
+            Column("y", Integer),
+            UniqueConstraint("x", name="uq1"),
+            UniqueConstraint("y", name="uq2"),
         )
 
-        Table('t', m2, Column('x', Integer), Column('y', Integer))
+        Table("t", m2, Column("x", Integer), Column("y", Integer))
 
         def include_object(object_, name, type_, reflected, compare_to):
-            if type_ == 'index':
+            if type_ == "index":
                 return False
             return not (
-                isinstance(object_, UniqueConstraint) and
-                type_ == 'unique_constraint' and reflected and name == 'uq1')
+                isinstance(object_, UniqueConstraint)
+                and type_ == "unique_constraint"
+                and reflected
+                and name == "uq1"
+            )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
-        eq_(diffs[0][0], 'remove_constraint')
-        eq_(diffs[0][1].name, 'uq2')
+        eq_(diffs[0][0], "remove_constraint")
+        eq_(diffs[0][1].name, "uq2")
         eq_(len(diffs), 1)
 
     def test_add_metadata_index(self):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('t', m1, Column('x', Integer))
+        Table("t", m1, Column("x", Integer))
 
-        t2 = Table('t', m2, Column('x', Integer))
-        Index('ix1', t2.c.x)
-        Index('ix2', t2.c.x)
+        t2 = Table("t", m2, Column("x", Integer))
+        Index("ix1", t2.c.x)
+        Index("ix2", t2.c.x)
 
         def include_object(object_, name, type_, reflected, compare_to):
             return not (
-                isinstance(object_, Index) and
-                type_ == 'index' and not reflected and name == 'ix1')
+                isinstance(object_, Index)
+                and type_ == "index"
+                and not reflected
+                and name == "ix1"
+            )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
-        eq_(diffs[0][0], 'add_index')
-        eq_(diffs[0][1].name, 'ix2')
+        eq_(diffs[0][0], "add_index")
+        eq_(diffs[0][1].name, "ix2")
         eq_(len(diffs), 1)
 
     @config.requirements.unique_constraint_reflection
@@ -1004,24 +1178,28 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         m1 = MetaData()
         m2 = MetaData()
 
-        Table('t', m1, Column('x', Integer))
+        Table("t", m1, Column("x", Integer))
 
         Table(
-            't', m2, Column('x', Integer),
-            UniqueConstraint('x', name='uq1'),
-            UniqueConstraint('x', name='uq2')
+            "t",
+            m2,
+            Column("x", Integer),
+            UniqueConstraint("x", name="uq1"),
+            UniqueConstraint("x", name="uq2"),
         )
 
         def include_object(object_, name, type_, reflected, compare_to):
             return not (
-                isinstance(object_, UniqueConstraint) and
-                type_ == 'unique_constraint' and
-                not reflected and name == 'uq1')
+                isinstance(object_, UniqueConstraint)
+                and type_ == "unique_constraint"
+                and not reflected
+                and name == "uq1"
+            )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
-        eq_(diffs[0][0], 'add_constraint')
-        eq_(diffs[0][1].name, 'uq2')
+        eq_(diffs[0][0], "add_constraint")
+        eq_(diffs[0][1].name, "uq2")
         eq_(len(diffs), 1)
 
     def test_change_index(self):
@@ -1029,29 +1207,40 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         t1 = Table(
-            't', m1, Column('x', Integer),
-            Column('y', Integer), Column('z', Integer))
-        Index('ix1', t1.c.x)
-        Index('ix2', t1.c.y)
+            "t",
+            m1,
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("z", Integer),
+        )
+        Index("ix1", t1.c.x)
+        Index("ix2", t1.c.y)
 
         t2 = Table(
-            't', m2, Column('x', Integer),
-            Column('y', Integer), Column('z', Integer))
-        Index('ix1', t2.c.x, t2.c.y)
-        Index('ix2', t2.c.x, t2.c.z)
+            "t",
+            m2,
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("z", Integer),
+        )
+        Index("ix1", t2.c.x, t2.c.y)
+        Index("ix2", t2.c.x, t2.c.z)
 
         def include_object(object_, name, type_, reflected, compare_to):
             return not (
-                isinstance(object_, Index) and
-                type_ == 'index' and not reflected and name == 'ix1'
-                and isinstance(compare_to, Index))
+                isinstance(object_, Index)
+                and type_ == "index"
+                and not reflected
+                and name == "ix1"
+                and isinstance(compare_to, Index)
+            )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
-        eq_(diffs[0][0], 'remove_index')
-        eq_(diffs[0][1].name, 'ix2')
-        eq_(diffs[1][0], 'add_index')
-        eq_(diffs[1][1].name, 'ix2')
+        eq_(diffs[0][0], "remove_index")
+        eq_(diffs[0][1].name, "ix2")
+        eq_(diffs[1][0], "add_index")
+        eq_(diffs[1][1].name, "ix2")
         eq_(len(diffs), 2)
 
     @config.requirements.unique_constraint_reflection
@@ -1060,39 +1249,46 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
         m2 = MetaData()
 
         Table(
-            't', m1, Column('x', Integer),
-            Column('y', Integer), Column('z', Integer),
-            UniqueConstraint('x', name='uq1'),
-            UniqueConstraint('y', name='uq2')
+            "t",
+            m1,
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("z", Integer),
+            UniqueConstraint("x", name="uq1"),
+            UniqueConstraint("y", name="uq2"),
         )
 
         Table(
-            't', m2, Column('x', Integer), Column('y', Integer),
-            Column('z', Integer),
-            UniqueConstraint('x', 'z', name='uq1'),
-            UniqueConstraint('y', 'z', name='uq2')
+            "t",
+            m2,
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("z", Integer),
+            UniqueConstraint("x", "z", name="uq1"),
+            UniqueConstraint("y", "z", name="uq2"),
         )
 
         def include_object(object_, name, type_, reflected, compare_to):
-            if type_ == 'index':
+            if type_ == "index":
                 return False
             return not (
-                isinstance(object_, UniqueConstraint) and
-                type_ == 'unique_constraint' and
-                not reflected and name == 'uq1'
-                and isinstance(compare_to, UniqueConstraint))
+                isinstance(object_, UniqueConstraint)
+                and type_ == "unique_constraint"
+                and not reflected
+                and name == "uq1"
+                and isinstance(compare_to, UniqueConstraint)
+            )
 
         diffs = self._fixture(m1, m2, object_filters=include_object)
 
-        eq_(diffs[0][0], 'remove_constraint')
-        eq_(diffs[0][1].name, 'uq2')
-        eq_(diffs[1][0], 'add_constraint')
-        eq_(diffs[1][1].name, 'uq2')
+        eq_(diffs[0][0], "remove_constraint")
+        eq_(diffs[0][1].name, "uq2")
+        eq_(diffs[1][0], "add_constraint")
+        eq_(diffs[1][1].name, "uq2")
         eq_(len(diffs), 2)
 
 
 class TruncatedIdxTest(AutogenFixtureTest, TestBase):
-
     def setUp(self):
         self.bind = engines.testing_engine()
         self.bind.dialect.max_identifier_length = 30
@@ -1102,12 +1298,13 @@ class TruncatedIdxTest(AutogenFixtureTest, TestBase):
 
         m1 = MetaData()
         Table(
-            'q', m1,
-            Column('id', Integer, primary_key=True),
-            Column('data', Integer),
+            "q",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("data", Integer),
             Index(
-                conv("idx_q_table_this_is_more_than_thirty_characters"),
-                "data")
+                conv("idx_q_table_this_is_more_than_thirty_characters"), "data"
+            ),
         )
 
         diffs = self._fixture(m1, m1)
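
Every hunk in the file above applies the same mechanical rewrite: a call that no longer fits within the line-length limit is exploded to one argument per line, a trailing comma is appended so that later diffs touch only the argument that actually changes, and single-quoted strings are normalized to double quotes. A minimal self-contained sketch of the before/after shapes — make_table here is a hypothetical stand-in for sqlalchemy.Table, used only so the example runs on its own:

    def make_table(name, *columns, **kw):
        # Hypothetical helper standing in for sqlalchemy.Table.
        return (name, columns, tuple(sorted(kw.items())))

    # Before: single quotes, hanging indent aligned under the open paren.
    t1 = make_table('nothing_changed',
                    ('x', 'String(20)'),
                    mysql_engine='InnoDB'
                    )

    # After: double quotes, one argument per line, trailing comma.
    t2 = make_table(
        "nothing_changed",
        ("x", "String(20)"),
        mysql_engine="InnoDB",
    )

    assert t1 == t2  # the rewrite is purely syntactic; behavior is unchanged
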
index b32358fa0ad7e5dba17086866979dd72e5977fea..37e1c618c29361e268d9100ec135ccf4aca87f4a 100644 (file)
@@ -4,11 +4,31 @@ from alembic.testing import TestBase, exclusions, assert_raises
 from alembic.testing import assertions
 
 from alembic.operations import ops
-from sqlalchemy import MetaData, Column, Table, String, \
-    Numeric, CHAR, ForeignKey, DATETIME, Integer, BigInteger, \
-    CheckConstraint, Unicode, Enum, cast,\
-    DateTime, UniqueConstraint, Boolean, ForeignKeyConstraint,\
-    PrimaryKeyConstraint, Index, func, text, DefaultClause
+from sqlalchemy import (
+    MetaData,
+    Column,
+    Table,
+    String,
+    Numeric,
+    CHAR,
+    ForeignKey,
+    DATETIME,
+    Integer,
+    BigInteger,
+    CheckConstraint,
+    Unicode,
+    Enum,
+    cast,
+    DateTime,
+    UniqueConstraint,
+    Boolean,
+    ForeignKeyConstraint,
+    PrimaryKeyConstraint,
+    Index,
+    func,
+    text,
+    DefaultClause,
+)
 
 from sqlalchemy.types import TIMESTAMP
 from sqlalchemy import types
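
The import hunk above shows another normalization: a backslash-continued import statement is rewritten as a single parenthesized import list, one name per line with a trailing comma. A small sketch of the two equivalent forms, using stdlib names rather than this module's imports:

    # Backslash continuation, as in the removed lines:
    from os.path import abspath, \
        dirname

    # Equivalent parenthesized form, as emitted in the added lines:
    from os.path import (
        abspath,
        dirname,
    )

    assert abspath and dirname  # both statements bind the same two names
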
@@ -28,7 +48,7 @@ from alembic.testing.fixtures import op_fixture
 from alembic import op  # noqa
 import sqlalchemy as sa  # noqa
 
-py3k = sys.version_info >= (3, )
+py3k = sys.version_info >= (3,)
 
 
 class AutogenRenderTest(TestBase):
@@ -37,13 +57,12 @@ class AutogenRenderTest(TestBase):
 
     def setUp(self):
         ctx_opts = {
-            'sqlalchemy_module_prefix': 'sa.',
-            'alembic_module_prefix': 'op.',
-            'target_metadata': MetaData()
+            "sqlalchemy_module_prefix": "sa.",
+            "alembic_module_prefix": "op.",
+            "target_metadata": MetaData(),
         }
         context = MigrationContext.configure(
-            dialect=DefaultDialect(),
-            opts=ctx_opts
+            dialect=DefaultDialect(), opts=ctx_opts
         )
 
         self.autogen_context = api.AutogenContext(context)
@@ -53,17 +72,19 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._add_index
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
-        idx = Index('test_active_code_idx', t.c.active, t.c.code)
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
+        idx = Index("test_active_code_idx", t.c.active, t.c.code)
         op_obj = ops.CreateIndexOp.from_index(idx)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index('test_active_code_idx', 'test', "
-            "['active', 'code'], unique=False)"
+            "['active', 'code'], unique=False)",
         )
 
     def test_render_add_index_batch(self):
@@ -71,18 +92,20 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._add_index
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
-        idx = Index('test_active_code_idx', t.c.active, t.c.code)
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
+        idx = Index("test_active_code_idx", t.c.active, t.c.code)
         op_obj = ops.CreateIndexOp.from_index(idx)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
                 "batch_op.create_index('test_active_code_idx', "
-                "['active', 'code'], unique=False)"
+                "['active', 'code'], unique=False)",
             )
 
     def test_render_add_index_schema(self):
@@ -90,18 +113,20 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._add_index using schema
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  schema='CamelSchema'
-                  )
-        idx = Index('test_active_code_idx', t.c.active, t.c.code)
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+            schema="CamelSchema",
+        )
+        idx = Index("test_active_code_idx", t.c.active, t.c.code)
         op_obj = ops.CreateIndexOp.from_index(idx)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index('test_active_code_idx', 'test', "
-            "['active', 'code'], unique=False, schema='CamelSchema')"
+            "['active', 'code'], unique=False, schema='CamelSchema')",
         )
 
     def test_render_add_index_schema_batch(self):
@@ -109,73 +134,78 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._add_index using schema
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  schema='CamelSchema'
-                  )
-        idx = Index('test_active_code_idx', t.c.active, t.c.code)
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+            schema="CamelSchema",
+        )
+        idx = Index("test_active_code_idx", t.c.active, t.c.code)
         op_obj = ops.CreateIndexOp.from_index(idx)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
                 "batch_op.create_index('test_active_code_idx', "
-                "['active', 'code'], unique=False)"
+                "['active', 'code'], unique=False)",
             )
 
     def test_render_add_index_func(self):
         m = MetaData()
         t = Table(
-            'test', m,
-            Column('id', Integer, primary_key=True),
-            Column('code', String(255))
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("code", String(255)),
         )
-        idx = Index('test_lower_code_idx', func.lower(t.c.code))
+        idx = Index("test_lower_code_idx", func.lower(t.c.code))
         op_obj = ops.CreateIndexOp.from_index(idx)
 
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index('test_lower_code_idx', 'test', "
-            "[sa.text(!U'lower(code)')], unique=False)"
+            "[sa.text(!U'lower(code)')], unique=False)",
         )
 
     def test_render_add_index_cast(self):
         m = MetaData()
         t = Table(
-            'test', m,
-            Column('id', Integer, primary_key=True),
-            Column('code', String(255))
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("code", String(255)),
         )
-        idx = Index('test_lower_code_idx', cast(t.c.code, String))
+        idx = Index("test_lower_code_idx", cast(t.c.code, String))
         op_obj = ops.CreateIndexOp.from_index(idx)
 
         if config.requirements.sqlalchemy_110.enabled:
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
                 "op.create_index('test_lower_code_idx', 'test', "
-                "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)"
+                "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)",
             )
         else:
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
                 "op.create_index('test_lower_code_idx', 'test', "
-                "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)"
+                "[sa.text(!U'CAST(code AS VARCHAR)')], unique=False)",
             )
 
     def test_render_add_index_desc(self):
         m = MetaData()
         t = Table(
-            'test', m,
-            Column('id', Integer, primary_key=True),
-            Column('code', String(255))
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("code", String(255)),
         )
-        idx = Index('test_desc_code_idx', t.c.code.desc())
+        idx = Index("test_desc_code_idx", t.c.code.desc())
         op_obj = ops.CreateIndexOp.from_index(idx)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index('test_desc_code_idx', 'test', "
-            "[sa.text(!U'code DESC')], unique=False)"
+            "[sa.text(!U'code DESC')], unique=False)",
         )
 
     def test_drop_index(self):
@@ -183,16 +213,18 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._drop_index
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
-        idx = Index('test_active_code_idx', t.c.active, t.c.code)
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
+        idx = Index("test_active_code_idx", t.c.active, t.c.code)
         op_obj = ops.DropIndexOp.from_index(idx)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_index('test_active_code_idx', table_name='test')"
+            "op.drop_index('test_active_code_idx', table_name='test')",
         )
 
     def test_drop_index_batch(self):
@@ -200,17 +232,19 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._drop_index
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
-        idx = Index('test_active_code_idx', t.c.active, t.c.code)
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
+        idx = Index("test_active_code_idx", t.c.active, t.c.code)
         op_obj = ops.DropIndexOp.from_index(idx)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
-                "batch_op.drop_index('test_active_code_idx')"
+                "batch_op.drop_index('test_active_code_idx')",
             )
 
     def test_drop_index_schema(self):
@@ -218,18 +252,20 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._drop_index using schema
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  schema='CamelSchema'
-                  )
-        idx = Index('test_active_code_idx', t.c.active, t.c.code)
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+            schema="CamelSchema",
+        )
+        idx = Index("test_active_code_idx", t.c.active, t.c.code)
         op_obj = ops.DropIndexOp.from_index(idx)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_index('test_active_code_idx', " +
-            "table_name='test', schema='CamelSchema')"
+            "op.drop_index('test_active_code_idx', "
+            + "table_name='test', schema='CamelSchema')",
         )
 
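One more pattern visible in the hunk just above: when a long expression is joined with a binary operator, the line is broken before the operator, so the continuation line starts with "+" (the break-before-operator style PEP 8 recommends). A sketch reusing the expected string from that hunk:

    # Before: the "+" trails the first line.
    before = (
        "op.drop_index('test_active_code_idx', " +
        "table_name='test', schema='CamelSchema')"
    )

    # After: the "+" leads the continuation line.
    after = (
        "op.drop_index('test_active_code_idx', "
        + "table_name='test', schema='CamelSchema')"
    )

    assert before == after  # only the operator's position changes
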
     def test_drop_index_schema_batch(self):
@@ -237,18 +273,20 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._drop_index using schema
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  schema='CamelSchema'
-                  )
-        idx = Index('test_active_code_idx', t.c.active, t.c.code)
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+            schema="CamelSchema",
+        )
+        idx = Index("test_active_code_idx", t.c.active, t.c.code)
         op_obj = ops.DropIndexOp.from_index(idx)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
-                "batch_op.drop_index('test_active_code_idx')"
+                "batch_op.drop_index('test_active_code_idx')",
             )
 
     def test_add_unique_constraint(self):
@@ -256,16 +294,18 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._add_unique_constraint
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
-        uq = UniqueConstraint(t.c.code, name='uq_test_code')
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
+        uq = UniqueConstraint(t.c.code, name="uq_test_code")
         op_obj = ops.AddConstraintOp.from_constraint(uq)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.create_unique_constraint('uq_test_code', 'test', ['code'])"
+            "op.create_unique_constraint('uq_test_code', 'test', ['code'])",
         )
 
     def test_add_unique_constraint_batch(self):
@@ -273,17 +313,19 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._add_unique_constraint
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
-        uq = UniqueConstraint(t.c.code, name='uq_test_code')
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
+        uq = UniqueConstraint(t.c.code, name="uq_test_code")
         op_obj = ops.AddConstraintOp.from_constraint(uq)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
-                "batch_op.create_unique_constraint('uq_test_code', ['code'])"
+                "batch_op.create_unique_constraint('uq_test_code', ['code'])",
             )
 
     def test_add_unique_constraint_schema(self):
@@ -291,18 +333,20 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._add_unique_constraint using schema
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  schema='CamelSchema'
-                  )
-        uq = UniqueConstraint(t.c.code, name='uq_test_code')
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+            schema="CamelSchema",
+        )
+        uq = UniqueConstraint(t.c.code, name="uq_test_code")
         op_obj = ops.AddConstraintOp.from_constraint(uq)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_unique_constraint('uq_test_code', 'test', "
-            "['code'], schema='CamelSchema')"
+            "['code'], schema='CamelSchema')",
         )
 
     def test_add_unique_constraint_schema_batch(self):
@@ -310,19 +354,21 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._add_unique_constraint using schema
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  schema='CamelSchema'
-                  )
-        uq = UniqueConstraint(t.c.code, name='uq_test_code')
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+            schema="CamelSchema",
+        )
+        uq = UniqueConstraint(t.c.code, name="uq_test_code")
         op_obj = ops.AddConstraintOp.from_constraint(uq)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
                 "batch_op.create_unique_constraint('uq_test_code', "
-                "['code'])"
+                "['code'])",
             )
 
     def test_drop_unique_constraint(self):
@@ -330,16 +376,18 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._drop_constraint
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
-        uq = UniqueConstraint(t.c.code, name='uq_test_code')
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
+        uq = UniqueConstraint(t.c.code, name="uq_test_code")
         op_obj = ops.DropConstraintOp.from_constraint(uq)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_constraint('uq_test_code', 'test', type_='unique')"
+            "op.drop_constraint('uq_test_code', 'test', type_='unique')",
         )
 
     def test_drop_unique_constraint_schema(self):
@@ -347,67 +395,72 @@ class AutogenRenderTest(TestBase):
         autogenerate.render._drop_constraint using schema
         """
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  schema='CamelSchema'
-                  )
-        uq = UniqueConstraint(t.c.code, name='uq_test_code')
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+            schema="CamelSchema",
+        )
+        uq = UniqueConstraint(t.c.code, name="uq_test_code")
         op_obj = ops.DropConstraintOp.from_constraint(uq)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.drop_constraint('uq_test_code', 'test', "
-            "schema='CamelSchema', type_='unique')"
+            "schema='CamelSchema', type_='unique')",
         )
 
     def test_drop_unique_constraint_schema_reprobj(self):
         """
         autogenerate.render._drop_constraint using schema
         """
+
         class SomeObj(str):
             def __repr__(self):
                 return "foo.camel_schema"
 
         op_obj = ops.DropConstraintOp(
-            "uq_test_code", "test", type_="unique",
-            schema=SomeObj("CamelSchema")
+            "uq_test_code",
+            "test",
+            type_="unique",
+            schema=SomeObj("CamelSchema"),
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.drop_constraint('uq_test_code', 'test', "
-            "schema=foo.camel_schema, type_='unique')"
+            "schema=foo.camel_schema, type_='unique')",
         )
 
     def test_add_fk_constraint(self):
         m = MetaData()
-        Table('a', m, Column('id', Integer, primary_key=True))
-        b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id')))
-        fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id')
+        Table("a", m, Column("id", Integer, primary_key=True))
+        b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id")))
+        fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id")
         b.append_constraint(fk)
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])"
+            "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])",
         )
 
     def test_add_fk_constraint_batch(self):
         m = MetaData()
-        Table('a', m, Column('id', Integer, primary_key=True))
-        b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id')))
-        fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id')
+        Table("a", m, Column("id", Integer, primary_key=True))
+        b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id")))
+        fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id")
         b.append_constraint(fk)
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
-                "batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id'])"
+                "batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id'])",
             )
 
     def test_add_fk_constraint_kwarg(self):
         m = MetaData()
-        t1 = Table('t', m, Column('c', Integer))
-        t2 = Table('t2', m, Column('c_rem', Integer))
+        t1 = Table("t", m, Column("c", Integer))
+        t2 = Table("t2", m, Column("c_rem", Integer))
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE")
 
@@ -417,11 +470,12 @@ class AutogenRenderTest(TestBase):
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render_op_text(self.autogen_context, op_obj),
             ),
             "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
-            "onupdate='CASCADE')"
+            "onupdate='CASCADE')",
         )
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], ondelete="CASCADE")
@@ -429,54 +483,62 @@ class AutogenRenderTest(TestBase):
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
-                autogenerate.render_op_text(self.autogen_context, op_obj)
+                r"u'",
+                "'",
+                autogenerate.render_op_text(self.autogen_context, op_obj),
             ),
             "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
-            "ondelete='CASCADE')"
+            "ondelete='CASCADE')",
         )
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], deferrable=True)
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
-                autogenerate.render_op_text(self.autogen_context, op_obj)
+                r"u'",
+                "'",
+                autogenerate.render_op_text(self.autogen_context, op_obj),
             ),
             "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
-            "deferrable=True)"
+            "deferrable=True)",
         )
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], initially="XYZ")
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render_op_text(self.autogen_context, op_obj),
             ),
             "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
-            "initially='XYZ')"
+            "initially='XYZ')",
         )
 
         fk = ForeignKeyConstraint(
-            [t1.c.c], [t2.c.c_rem],
-            initially="XYZ", ondelete="CASCADE", deferrable=True)
+            [t1.c.c],
+            [t2.c.c_rem],
+            initially="XYZ",
+            ondelete="CASCADE",
+            deferrable=True,
+        )
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
-                autogenerate.render_op_text(self.autogen_context, op_obj)
+                r"u'",
+                "'",
+                autogenerate.render_op_text(self.autogen_context, op_obj),
             ),
             "op.create_foreign_key(None, 't', 't2', ['c'], ['c_rem'], "
-            "ondelete='CASCADE', initially='XYZ', deferrable=True)"
+            "ondelete='CASCADE', initially='XYZ', deferrable=True)",
         )
 
     def test_add_fk_constraint_inline_colkeys(self):
         m = MetaData()
-        Table('a', m, Column('id', Integer, key='aid', primary_key=True))
+        Table("a", m, Column("id", Integer, key="aid", primary_key=True))
         b = Table(
-            'b', m,
-            Column('a_id', Integer, ForeignKey('a.aid'), key='baid'))
+            "b", m, Column("a_id", Integer, ForeignKey("a.aid"), key="baid")
+        )
 
         op_obj = ops.CreateTableOp.from_table(b)
         py_code = autogenerate.render_op_text(self.autogen_context, op_obj)
@@ -485,20 +547,21 @@ class AutogenRenderTest(TestBase):
             py_code,
             "op.create_table('b',"
             "sa.Column('a_id', sa.Integer(), nullable=True),"
-            "sa.ForeignKeyConstraint(['a_id'], ['a.id'], ))"
+            "sa.ForeignKeyConstraint(['a_id'], ['a.id'], ))",
         )
 
         context = op_fixture()
         eval(py_code)
         context.assert_(
             "CREATE TABLE b (a_id INTEGER, "
-            "FOREIGN KEY(a_id) REFERENCES a (id))")
+            "FOREIGN KEY(a_id) REFERENCES a (id))"
+        )
 
     def test_add_fk_constraint_separate_colkeys(self):
         m = MetaData()
-        Table('a', m, Column('id', Integer, key='aid', primary_key=True))
-        b = Table('b', m, Column('a_id', Integer, key='baid'))
-        fk = ForeignKeyConstraint(['baid'], ['a.aid'], name='fk_a_id')
+        Table("a", m, Column("id", Integer, key="aid", primary_key=True))
+        b = Table("b", m, Column("a_id", Integer, key="baid"))
+        fk = ForeignKeyConstraint(["baid"], ["a.aid"], name="fk_a_id")
         b.append_constraint(fk)
 
         op_obj = ops.CreateTableOp.from_table(b)
@@ -508,14 +571,15 @@ class AutogenRenderTest(TestBase):
             py_code,
             "op.create_table('b',"
             "sa.Column('a_id', sa.Integer(), nullable=True),"
-            "sa.ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id'))"
+            "sa.ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id'))",
         )
 
         context = op_fixture()
         eval(py_code)
         context.assert_(
             "CREATE TABLE b (a_id INTEGER, CONSTRAINT "
-            "fk_a_id FOREIGN KEY(a_id) REFERENCES a (id))")
+            "fk_a_id FOREIGN KEY(a_id) REFERENCES a (id))"
+        )
 
         context = op_fixture()
 
@@ -523,7 +587,7 @@ class AutogenRenderTest(TestBase):
 
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])"
+            "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'])",
         )
 
         py_code = autogenerate.render_op_text(self.autogen_context, op_obj)
@@ -531,124 +595,151 @@ class AutogenRenderTest(TestBase):
         eval(py_code)
         context.assert_(
             "ALTER TABLE b ADD CONSTRAINT fk_a_id "
-            "FOREIGN KEY(a_id) REFERENCES a (id)")
+            "FOREIGN KEY(a_id) REFERENCES a (id)"
+        )
 
     def test_add_fk_constraint_schema(self):
         m = MetaData()
         Table(
-            'a', m, Column('id', Integer, primary_key=True),
-            schema="CamelSchemaTwo")
+            "a",
+            m,
+            Column("id", Integer, primary_key=True),
+            schema="CamelSchemaTwo",
+        )
         b = Table(
-            'b', m, Column('a_id', Integer, ForeignKey('a.id')),
-            schema="CamelSchemaOne")
+            "b",
+            m,
+            Column("a_id", Integer, ForeignKey("a.id")),
+            schema="CamelSchemaOne",
+        )
         fk = ForeignKeyConstraint(
-            ["a_id"],
-            ["CamelSchemaTwo.a.id"], name='fk_a_id')
+            ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id"
+        )
         b.append_constraint(fk)
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_foreign_key('fk_a_id', 'b', 'a', ['a_id'], ['id'],"
             " source_schema='CamelSchemaOne', "
-            "referent_schema='CamelSchemaTwo')"
+            "referent_schema='CamelSchemaTwo')",
         )
 
     def test_add_fk_constraint_schema_batch(self):
         m = MetaData()
         Table(
-            'a', m, Column('id', Integer, primary_key=True),
-            schema="CamelSchemaTwo")
+            "a",
+            m,
+            Column("id", Integer, primary_key=True),
+            schema="CamelSchemaTwo",
+        )
         b = Table(
-            'b', m, Column('a_id', Integer, ForeignKey('a.id')),
-            schema="CamelSchemaOne")
+            "b",
+            m,
+            Column("a_id", Integer, ForeignKey("a.id")),
+            schema="CamelSchemaOne",
+        )
         fk = ForeignKeyConstraint(
-            ["a_id"],
-            ["CamelSchemaTwo.a.id"], name='fk_a_id')
+            ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id"
+        )
         b.append_constraint(fk)
         op_obj = ops.AddConstraintOp.from_constraint(fk)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
                 "batch_op.create_foreign_key('fk_a_id', 'a', ['a_id'], ['id'],"
-                " referent_schema='CamelSchemaTwo')"
+                " referent_schema='CamelSchemaTwo')",
             )
 
     def test_drop_fk_constraint(self):
         m = MetaData()
-        Table('a', m, Column('id', Integer, primary_key=True))
-        b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id')))
-        fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id')
+        Table("a", m, Column("id", Integer, primary_key=True))
+        b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id")))
+        fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id")
         b.append_constraint(fk)
         op_obj = ops.DropConstraintOp.from_constraint(fk)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_constraint('fk_a_id', 'b', type_='foreignkey')"
+            "op.drop_constraint('fk_a_id', 'b', type_='foreignkey')",
         )
 
     def test_drop_fk_constraint_batch(self):
         m = MetaData()
-        Table('a', m, Column('id', Integer, primary_key=True))
-        b = Table('b', m, Column('a_id', Integer, ForeignKey('a.id')))
-        fk = ForeignKeyConstraint(['a_id'], ['a.id'], name='fk_a_id')
+        Table("a", m, Column("id", Integer, primary_key=True))
+        b = Table("b", m, Column("a_id", Integer, ForeignKey("a.id")))
+        fk = ForeignKeyConstraint(["a_id"], ["a.id"], name="fk_a_id")
         b.append_constraint(fk)
         op_obj = ops.DropConstraintOp.from_constraint(fk)
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
-                "batch_op.drop_constraint('fk_a_id', type_='foreignkey')"
+                "batch_op.drop_constraint('fk_a_id', type_='foreignkey')",
             )
 
     def test_drop_fk_constraint_schema(self):
         m = MetaData()
         Table(
-            'a', m, Column('id', Integer, primary_key=True),
-            schema="CamelSchemaTwo")
+            "a",
+            m,
+            Column("id", Integer, primary_key=True),
+            schema="CamelSchemaTwo",
+        )
         b = Table(
-            'b', m, Column('a_id', Integer, ForeignKey('a.id')),
-            schema="CamelSchemaOne")
+            "b",
+            m,
+            Column("a_id", Integer, ForeignKey("a.id")),
+            schema="CamelSchemaOne",
+        )
         fk = ForeignKeyConstraint(
-            ["a_id"],
-            ["CamelSchemaTwo.a.id"], name='fk_a_id')
+            ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id"
+        )
         b.append_constraint(fk)
         op_obj = ops.DropConstraintOp.from_constraint(fk)
 
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.drop_constraint('fk_a_id', 'b', schema='CamelSchemaOne', "
-            "type_='foreignkey')"
+            "type_='foreignkey')",
         )
 
     def test_drop_fk_constraint_batch_schema(self):
         m = MetaData()
         Table(
-            'a', m, Column('id', Integer, primary_key=True),
-            schema="CamelSchemaTwo")
+            "a",
+            m,
+            Column("id", Integer, primary_key=True),
+            schema="CamelSchemaTwo",
+        )
         b = Table(
-            'b', m, Column('a_id', Integer, ForeignKey('a.id')),
-            schema="CamelSchemaOne")
+            "b",
+            m,
+            Column("a_id", Integer, ForeignKey("a.id")),
+            schema="CamelSchemaOne",
+        )
         fk = ForeignKeyConstraint(
-            ["a_id"],
-            ["CamelSchemaTwo.a.id"], name='fk_a_id')
+            ["a_id"], ["CamelSchemaTwo.a.id"], name="fk_a_id"
+        )
         b.append_constraint(fk)
         op_obj = ops.DropConstraintOp.from_constraint(fk)
 
         with self.autogen_context._within_batch():
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
-                "batch_op.drop_constraint('fk_a_id', type_='foreignkey')"
+                "batch_op.drop_constraint('fk_a_id', type_='foreignkey')",
             )
 
     def test_render_table_upgrade(self):
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('name', Unicode(255)),
-                  Column("address_id", Integer, ForeignKey("address.id")),
-                  Column("timestamp", DATETIME, server_default="NOW()"),
-                  Column("amount", Numeric(5, 2)),
-                  UniqueConstraint("name", name="uq_name"),
-                  UniqueConstraint("timestamp"),
-                  )
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("name", Unicode(255)),
+            Column("address_id", Integer, ForeignKey("address.id")),
+            Column("timestamp", DATETIME, server_default="NOW()"),
+            Column("amount", Numeric(5, 2)),
+            UniqueConstraint("name", name="uq_name"),
+            UniqueConstraint("timestamp"),
+        )
 
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
@@ -666,16 +757,18 @@ class AutogenRenderTest(TestBase):
             "sa.PrimaryKeyConstraint('id'),"
             "sa.UniqueConstraint('name', name='uq_name'),"
             "sa.UniqueConstraint('timestamp')"
-            ")"
+            ")",
         )
 
     def test_render_table_w_schema(self):
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('q', Integer, ForeignKey('address.id')),
-                  schema='foo'
-                  )
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer, ForeignKey("address.id")),
+            schema="foo",
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -685,84 +778,89 @@ class AutogenRenderTest(TestBase):
             "sa.ForeignKeyConstraint(['q'], ['address.id'], ),"
             "sa.PrimaryKeyConstraint('id'),"
             "schema='foo'"
-            ")"
+            ")",
         )
 
     def test_render_table_w_system(self):
         m = MetaData()
-        t = Table('sometable', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('xmin', Integer, system=True, nullable=False)
-                  )
+        t = Table(
+            "sometable",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("xmin", Integer, system=True, nullable=False),
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('sometable',"
             "sa.Column('id', sa.Integer(), nullable=False),"
             "sa.Column('xmin', sa.Integer(), nullable=False, system=True),"
-            "sa.PrimaryKeyConstraint('id'))"
+            "sa.PrimaryKeyConstraint('id'))",
         )
 
     def test_render_table_w_unicode_name(self):
         m = MetaData()
-        t = Table(compat.ue('\u0411\u0435\u0437'), m,
-                  Column('id', Integer, primary_key=True),
-                  )
+        t = Table(
+            compat.ue("\u0411\u0435\u0437"),
+            m,
+            Column("id", Integer, primary_key=True),
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table(%r,"
             "sa.Column('id', sa.Integer(), nullable=False),"
-            "sa.PrimaryKeyConstraint('id'))" % compat.ue('\u0411\u0435\u0437')
+            "sa.PrimaryKeyConstraint('id'))" % compat.ue("\u0411\u0435\u0437"),
         )
 
     def test_render_table_w_unicode_schema(self):
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  schema=compat.ue('\u0411\u0435\u0437')
-                  )
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            schema=compat.ue("\u0411\u0435\u0437"),
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('test',"
             "sa.Column('id', sa.Integer(), nullable=False),"
             "sa.PrimaryKeyConstraint('id'),"
-            "schema=%r)" % compat.ue('\u0411\u0435\u0437')
+            "schema=%r)" % compat.ue("\u0411\u0435\u0437"),
         )
 
     def test_render_table_w_unsupported_constraint(self):
         from sqlalchemy.sql.schema import ColumnCollectionConstraint
 
         class SomeCustomConstraint(ColumnCollectionConstraint):
-            __visit_name__ = 'some_custom'
+            __visit_name__ = "some_custom"
 
         m = MetaData()
 
-        t = Table(
-            't', m, Column('id', Integer),
-            SomeCustomConstraint('id'),
-        )
+        t = Table("t", m, Column("id", Integer), SomeCustomConstraint("id"))
         op_obj = ops.CreateTableOp.from_table(t)
         with assertions.expect_warnings(
-                "No renderer is established for object SomeCustomConstraint"):
+            "No renderer is established for object SomeCustomConstraint"
+        ):
             eq_ignore_whitespace(
                 autogenerate.render_op_text(self.autogen_context, op_obj),
                 "op.create_table('t',"
                 "sa.Column('id', sa.Integer(), nullable=True),"
                 "[Unknown Python object "
-                "SomeCustomConstraint(Column('id', Integer(), table=<t>))])"
+                "SomeCustomConstraint(Column('id', Integer(), table=<t>))])",
             )
 
     @patch("alembic.autogenerate.render.MAX_PYTHON_ARGS", 3)
     def test_render_table_max_cols(self):
         m = MetaData()
         t = Table(
-            'test', m,
-            Column('a', Integer),
-            Column('b', Integer),
-            Column('c', Integer),
-            Column('d', Integer),
+            "test",
+            m,
+            Column("a", Integer),
+            Column("b", Integer),
+            Column("c", Integer),
+            Column("d", Integer),
         )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
@@ -771,14 +869,15 @@ class AutogenRenderTest(TestBase):
             "*[sa.Column('a', sa.Integer(), nullable=True),"
             "sa.Column('b', sa.Integer(), nullable=True),"
             "sa.Column('c', sa.Integer(), nullable=True),"
-            "sa.Column('d', sa.Integer(), nullable=True)])"
+            "sa.Column('d', sa.Integer(), nullable=True)])",
         )
 
         t2 = Table(
-            'test2', m,
-            Column('a', Integer),
-            Column('b', Integer),
-            Column('c', Integer),
+            "test2",
+            m,
+            Column("a", Integer),
+            Column("b", Integer),
+            Column("c", Integer),
         )
         op_obj = ops.CreateTableOp.from_table(t2)
 
@@ -787,15 +886,17 @@ class AutogenRenderTest(TestBase):
             "op.create_table('test2',"
             "sa.Column('a', sa.Integer(), nullable=True),"
             "sa.Column('b', sa.Integer(), nullable=True),"
-            "sa.Column('c', sa.Integer(), nullable=True))"
+            "sa.Column('c', sa.Integer(), nullable=True))",
         )
 
     def test_render_table_w_fk_schema(self):
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('q', Integer, ForeignKey('foo.address.id')),
-                  )
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer, ForeignKey("foo.address.id")),
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -804,20 +905,23 @@ class AutogenRenderTest(TestBase):
             "sa.Column('q', sa.Integer(), nullable=True),"
             "sa.ForeignKeyConstraint(['q'], ['foo.address.id'], ),"
             "sa.PrimaryKeyConstraint('id')"
-            ")"
+            ")",
         )
 
     def test_render_table_w_metadata_schema(self):
         m = MetaData(schema="foo")
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('q', Integer, ForeignKey('address.id')),
-                  )
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer, ForeignKey("address.id")),
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
-                autogenerate.render_op_text(self.autogen_context, op_obj)
+                r"u'",
+                "'",
+                autogenerate.render_op_text(self.autogen_context, op_obj),
             ),
             "op.create_table('test',"
             "sa.Column('id', sa.Integer(), nullable=False),"
@@ -825,15 +929,17 @@ class AutogenRenderTest(TestBase):
             "sa.ForeignKeyConstraint(['q'], ['foo.address.id'], ),"
             "sa.PrimaryKeyConstraint('id'),"
             "schema='foo'"
-            ")"
+            ")",
         )
 
     def test_render_table_w_metadata_schema_override(self):
         m = MetaData(schema="foo")
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('q', Integer, ForeignKey('bar.address.id')),
-                  )
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer, ForeignKey("bar.address.id")),
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -843,16 +949,19 @@ class AutogenRenderTest(TestBase):
             "sa.ForeignKeyConstraint(['q'], ['bar.address.id'], ),"
             "sa.PrimaryKeyConstraint('id'),"
             "schema='foo'"
-            ")"
+            ")",
         )
 
     def test_render_addtl_args(self):
         m = MetaData()
-        t = Table('test', m,
-                  Column('id', Integer, primary_key=True),
-                  Column('q', Integer, ForeignKey('bar.address.id')),
-                  sqlite_autoincrement=True, mysql_engine="InnoDB"
-                  )
+        t = Table(
+            "test",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer, ForeignKey("bar.address.id")),
+            sqlite_autoincrement=True,
+            mysql_engine="InnoDB",
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -861,60 +970,58 @@ class AutogenRenderTest(TestBase):
             "sa.Column('q', sa.Integer(), nullable=True),"
             "sa.ForeignKeyConstraint(['q'], ['bar.address.id'], ),"
             "sa.PrimaryKeyConstraint('id'),"
-            "mysql_engine='InnoDB',sqlite_autoincrement=True)"
+            "mysql_engine='InnoDB',sqlite_autoincrement=True)",
         )
 
     def test_render_drop_table(self):
-        op_obj = ops.DropTableOp.from_table(
-            Table("sometable", MetaData())
-        )
+        op_obj = ops.DropTableOp.from_table(Table("sometable", MetaData()))
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_table('sometable')"
+            "op.drop_table('sometable')",
         )
 
     def test_render_drop_table_w_schema(self):
         op_obj = ops.DropTableOp.from_table(
-            Table("sometable", MetaData(), schema='foo')
+            Table("sometable", MetaData(), schema="foo")
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_table('sometable', schema='foo')"
+            "op.drop_table('sometable', schema='foo')",
         )
 
     def test_render_table_no_implicit_check(self):
         m = MetaData()
-        t = Table('test', m, Column('x', Boolean()))
+        t = Table("test", m, Column("x", Boolean()))
 
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('test',"
-            "sa.Column('x', sa.Boolean(), nullable=True))"
+            "sa.Column('x', sa.Boolean(), nullable=True))",
         )
 
     def test_render_pk_with_col_name_vs_col_key(self):
         m = MetaData()
-        t1 = Table('t1', m, Column('x', Integer, key='y', primary_key=True))
+        t1 = Table("t1", m, Column("x", Integer, key="y", primary_key=True))
 
         op_obj = ops.CreateTableOp.from_table(t1)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('t1',"
             "sa.Column('x', sa.Integer(), nullable=False),"
-            "sa.PrimaryKeyConstraint('x'))"
+            "sa.PrimaryKeyConstraint('x'))",
         )
 
     def test_render_empty_pk_vs_nonempty_pk(self):
         m = MetaData()
-        t1 = Table('t1', m, Column('x', Integer))
-        t2 = Table('t2', m, Column('x', Integer, primary_key=True))
+        t1 = Table("t1", m, Column("x", Integer))
+        t2 = Table("t2", m, Column("x", Integer, primary_key=True))
 
         op_obj = ops.CreateTableOp.from_table(t1)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('t1',"
-            "sa.Column('x', sa.Integer(), nullable=True))"
+            "sa.Column('x', sa.Integer(), nullable=True))",
         )
 
         op_obj = ops.CreateTableOp.from_table(t2)
@@ -922,16 +1029,18 @@ class AutogenRenderTest(TestBase):
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('t2',"
             "sa.Column('x', sa.Integer(), nullable=False),"
-            "sa.PrimaryKeyConstraint('x'))"
+            "sa.PrimaryKeyConstraint('x'))",
         )
 
     @config.requirements.fail_before_sqla_110
     def test_render_table_w_autoincrement(self):
         m = MetaData()
         t = Table(
-            'test', m,
-            Column('id1', Integer, primary_key=True),
-            Column('id2', Integer, primary_key=True, autoincrement=True))
+            "test",
+            m,
+            Column("id1", Integer, primary_key=True),
+            Column("id2", Integer, primary_key=True, autoincrement=True),
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
@@ -940,111 +1049,109 @@ class AutogenRenderTest(TestBase):
             "sa.Column('id2', sa.Integer(), autoincrement=True, "
             "nullable=False),"
             "sa.PrimaryKeyConstraint('id1', 'id2')"
-            ")"
+            ")",
         )
 
     def test_render_add_column(self):
         op_obj = ops.AddColumnOp(
-            "foo", Column("x", Integer, server_default="5"))
+            "foo", Column("x", Integer, server_default="5")
+        )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.add_column('foo', sa.Column('x', sa.Integer(), "
-            "server_default='5', nullable=True))"
+            "server_default='5', nullable=True))",
         )
 
     def test_render_add_column_system(self):
         # this would never actually happen since "system" columns
         # can't be added in any case. However it will render as
         # part of op.CreateTableOp.
-        op_obj = ops.AddColumnOp(
-            "foo", Column("xmin", Integer, system=True))
+        op_obj = ops.AddColumnOp("foo", Column("xmin", Integer, system=True))
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.add_column('foo', sa.Column('xmin', sa.Integer(), "
-            "nullable=True, system=True))"
+            "nullable=True, system=True))",
         )
 
     def test_render_add_column_w_schema(self):
         op_obj = ops.AddColumnOp(
-            "bar", Column("x", Integer, server_default="5"),
-            schema="foo")
+            "bar", Column("x", Integer, server_default="5"), schema="foo"
+        )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.add_column('bar', sa.Column('x', sa.Integer(), "
-            "server_default='5', nullable=True), schema='foo')"
+            "server_default='5', nullable=True), schema='foo')",
         )
 
     def test_render_drop_column(self):
         op_obj = ops.DropColumnOp.from_column_and_tablename(
-            None, "foo", Column("x", Integer, server_default="5"))
+            None, "foo", Column("x", Integer, server_default="5")
+        )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_column('foo', 'x')"
+            "op.drop_column('foo', 'x')",
         )
 
     def test_render_drop_column_w_schema(self):
         op_obj = ops.DropColumnOp.from_column_and_tablename(
-            "foo", "bar", Column("x", Integer, server_default="5"))
+            "foo", "bar", Column("x", Integer, server_default="5")
+        )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_column('bar', 'x', schema='foo')"
+            "op.drop_column('bar', 'x', schema='foo')",
         )
 
     def test_render_quoted_server_default(self):
         eq_(
             autogenerate.render._render_server_default(
                 "nextval('group_to_perm_group_to_perm_id_seq'::regclass)",
-                self.autogen_context),
-            '"nextval(\'group_to_perm_group_to_perm_id_seq\'::regclass)"'
+                self.autogen_context,
+            ),
+            "\"nextval('group_to_perm_group_to_perm_id_seq'::regclass)\"",
         )
 
     def test_render_unicode_server_default(self):
         default = compat.ue(
-            '\u0411\u0435\u0437 '
-            '\u043d\u0430\u0437\u0432\u0430\u043d\u0438\u044f'
+            "\u0411\u0435\u0437 "
+            "\u043d\u0430\u0437\u0432\u0430\u043d\u0438\u044f"
         )
 
-        c = Column(
-            'x', Unicode,
-            server_default=text(default)
-        )
+        c = Column("x", Unicode, server_default=text(default))
 
         eq_ignore_whitespace(
             autogenerate.render._render_server_default(
                 c.server_default, self.autogen_context
             ),
-            "sa.text(%r)" % default
+            "sa.text(%r)" % default,
         )
 
     def test_render_col_with_server_default(self):
-        c = Column('updated_at', TIMESTAMP(),
-                   server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)',
-                   nullable=False)
-        result = autogenerate.render._render_column(
-            c, self.autogen_context
+        c = Column(
+            "updated_at",
+            TIMESTAMP(),
+            server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)',
+            nullable=False,
         )
+        result = autogenerate.render._render_column(c, self.autogen_context)
         eq_ignore_whitespace(
             result,
-            'sa.Column(\'updated_at\', sa.TIMESTAMP(), '
-            'server_default=\'TIMEZONE("utc", CURRENT_TIMESTAMP)\', '
-            'nullable=False)'
+            "sa.Column('updated_at', sa.TIMESTAMP(), "
+            "server_default='TIMEZONE(\"utc\", CURRENT_TIMESTAMP)', "
+            "nullable=False)",
         )
 
     def test_render_col_autoinc_false_mysql(self):
-        c = Column('some_key', Integer, primary_key=True, autoincrement=False)
-        Table('some_table', MetaData(), c)
-        result = autogenerate.render._render_column(
-            c, self.autogen_context
-        )
+        c = Column("some_key", Integer, primary_key=True, autoincrement=False)
+        Table("some_table", MetaData(), c)
+        result = autogenerate.render._render_column(c, self.autogen_context)
         eq_ignore_whitespace(
             result,
-            'sa.Column(\'some_key\', sa.Integer(), '
-            'autoincrement=False, '
-            'nullable=False)'
+            "sa.Column('some_key', sa.Integer(), "
+            "autoincrement=False, "
+            "nullable=False)",
         )
 
     def test_render_custom(self):
-
         class MySpecialType(Integer):
             pass
 
@@ -1065,17 +1172,18 @@ class AutogenRenderTest(TestBase):
             return "render:%s" % type_
 
         self.autogen_context.opts.update(
-            render_item=render,
-            alembic_module_prefix='sa.'
+            render_item=render, alembic_module_prefix="sa."
         )
 
-        t = Table('t', MetaData(),
-                  Column('x', Integer),
-                  Column('y', Integer),
-                  Column('q', MySpecialType()),
-                  PrimaryKeyConstraint('x'),
-                  ForeignKeyConstraint(['x'], ['y'])
-                  )
+        t = Table(
+            "t",
+            MetaData(),
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("q", MySpecialType()),
+            PrimaryKeyConstraint("x"),
+            ForeignKeyConstraint(["x"], ["y"]),
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         result = autogenerate.render_op_text(self.autogen_context, op_obj)
         eq_ignore_whitespace(
@@ -1083,89 +1191,97 @@ class AutogenRenderTest(TestBase):
             "sa.create_table('t',"
             "col(x),"
             "sa.Column('q', MySpecialType(), nullable=True),"
-            "render:primary_key)"
+            "render:primary_key)",
         )
         eq_(
             self.autogen_context.imports,
-            set(['from mypackage import MySpecialType'])
+            set(["from mypackage import MySpecialType"]),
         )
 
     def test_render_modify_type(self):
         op_obj = ops.AlterColumnOp(
-            "sometable", "somecolumn",
-            modify_type=CHAR(10), existing_type=CHAR(20)
+            "sometable",
+            "somecolumn",
+            modify_type=CHAR(10),
+            existing_type=CHAR(20),
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.alter_column('sometable', 'somecolumn', "
-            "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))"
+            "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))",
         )
 
     def test_render_modify_type_w_schema(self):
         op_obj = ops.AlterColumnOp(
-            "sometable", "somecolumn",
-            modify_type=CHAR(10), existing_type=CHAR(20),
-            schema='foo'
+            "sometable",
+            "somecolumn",
+            modify_type=CHAR(10),
+            existing_type=CHAR(20),
+            schema="foo",
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.alter_column('sometable', 'somecolumn', "
             "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10), "
-            "schema='foo')"
+            "schema='foo')",
         )
 
     def test_render_modify_nullable(self):
         op_obj = ops.AlterColumnOp(
-            "sometable", "somecolumn",
+            "sometable",
+            "somecolumn",
             existing_type=Integer(),
-            modify_nullable=True
+            modify_nullable=True,
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.alter_column('sometable', 'somecolumn', "
-            "existing_type=sa.Integer(), nullable=True)"
+            "existing_type=sa.Integer(), nullable=True)",
         )
 
     def test_render_modify_nullable_no_existing_type(self):
         op_obj = ops.AlterColumnOp(
-            "sometable", "somecolumn",
-            modify_nullable=True
+            "sometable", "somecolumn", modify_nullable=True
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.alter_column('sometable', 'somecolumn', nullable=True)"
+            "op.alter_column('sometable', 'somecolumn', nullable=True)",
         )
 
     def test_render_modify_nullable_w_schema(self):
         op_obj = ops.AlterColumnOp(
-            "sometable", "somecolumn",
+            "sometable",
+            "somecolumn",
             existing_type=Integer(),
-            modify_nullable=True, schema='foo'
+            modify_nullable=True,
+            schema="foo",
         )
 
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.alter_column('sometable', 'somecolumn', "
-            "existing_type=sa.Integer(), nullable=True, schema='foo')"
+            "existing_type=sa.Integer(), nullable=True, schema='foo')",
         )
 
     def test_render_modify_type_w_autoincrement(self):
         op_obj = ops.AlterColumnOp(
-            "sometable", "somecolumn",
-            modify_type=Integer(), existing_type=BigInteger(),
-            autoincrement=True
+            "sometable",
+            "somecolumn",
+            modify_type=Integer(),
+            existing_type=BigInteger(),
+            autoincrement=True,
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.alter_column('sometable', 'somecolumn', "
             "existing_type=sa.BigInteger(), type_=sa.Integer(), "
-            "autoincrement=True)"
+            "autoincrement=True)",
         )
 
     def test_render_fk_constraint_kwarg(self):
         m = MetaData()
-        t1 = Table('t', m, Column('c', Integer))
-        t2 = Table('t2', m, Column('c_rem', Integer))
+        t1 = Table("t", m, Column("c", Integer))
+        t2 = Table("t2", m, Column("c_rem", Integer))
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE")
 
@@ -1174,239 +1290,274 @@ class AutogenRenderTest(TestBase):
 
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)),
-            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], onupdate='CASCADE')"
+                    fk, self.autogen_context
+                ),
+            ),
+            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], onupdate='CASCADE')",
         )
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], ondelete="CASCADE")
 
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)),
-            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], ondelete='CASCADE')"
+                    fk, self.autogen_context
+                ),
+            ),
+            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], ondelete='CASCADE')",
         )
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], deferrable=True)
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context),
+                    fk, self.autogen_context
+                ),
             ),
-            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], deferrable=True)"
+            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], deferrable=True)",
         )
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], initially="XYZ")
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)
+                    fk, self.autogen_context
+                ),
             ),
-            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], initially='XYZ')"
+            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], initially='XYZ')",
         )
 
         fk = ForeignKeyConstraint(
-            [t1.c.c], [t2.c.c_rem],
-            initially="XYZ", ondelete="CASCADE", deferrable=True)
+            [t1.c.c],
+            [t2.c.c_rem],
+            initially="XYZ",
+            ondelete="CASCADE",
+            deferrable=True,
+        )
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)
+                    fk, self.autogen_context
+                ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], "
-            "ondelete='CASCADE', initially='XYZ', deferrable=True)"
+            "ondelete='CASCADE', initially='XYZ', deferrable=True)",
         )
 
     def test_render_fk_constraint_resolve_key(self):
         m = MetaData()
-        t1 = Table('t', m, Column('c', Integer))
-        t2 = Table('t2', m, Column('c_rem', Integer, key='c_remkey'))
+        t1 = Table("t", m, Column("c", Integer))
+        t2 = Table("t2", m, Column("c_rem", Integer, key="c_remkey"))
 
-        fk = ForeignKeyConstraint(['c'], ['t2.c_remkey'])
+        fk = ForeignKeyConstraint(["c"], ["t2.c_remkey"])
         t1.append_constraint(fk)
 
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)),
-            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )"
+                    fk, self.autogen_context
+                ),
+            ),
+            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
         )
 
     def test_render_fk_constraint_bad_table_resolve(self):
         m = MetaData()
-        t1 = Table('t', m, Column('c', Integer))
-        t2 = Table('t2', m, Column('c_rem', Integer))
+        t1 = Table("t", m, Column("c", Integer))
+        t2 = Table("t2", m, Column("c_rem", Integer))
 
-        fk = ForeignKeyConstraint(['c'], ['t2.nonexistent'])
+        fk = ForeignKeyConstraint(["c"], ["t2.nonexistent"])
         t1.append_constraint(fk)
 
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)),
-            "sa.ForeignKeyConstraint(['c'], ['t2.nonexistent'], )"
+                    fk, self.autogen_context
+                ),
+            ),
+            "sa.ForeignKeyConstraint(['c'], ['t2.nonexistent'], )",
         )
 
     def test_render_fk_constraint_bad_table_resolve_dont_get_confused(self):
         m = MetaData()
-        t1 = Table('t', m, Column('c', Integer))
+        t1 = Table("t", m, Column("c", Integer))
         t2 = Table(
-            't2', m,
-            Column('c_rem', Integer, key='cr_key'),
-            Column('c_rem_2', Integer, key='c_rem')
-
+            "t2",
+            m,
+            Column("c_rem", Integer, key="cr_key"),
+            Column("c_rem_2", Integer, key="c_rem"),
         )
 
-        fk = ForeignKeyConstraint(['c'], ['t2.c_rem'], link_to_name=True)
+        fk = ForeignKeyConstraint(["c"], ["t2.c_rem"], link_to_name=True)
         t1.append_constraint(fk)
 
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)),
-            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )"
+                    fk, self.autogen_context
+                ),
+            ),
+            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
         )
 
     def test_render_fk_constraint_link_to_name(self):
         m = MetaData()
-        t1 = Table('t', m, Column('c', Integer))
-        t2 = Table('t2', m, Column('c_rem', Integer, key='c_remkey'))
+        t1 = Table("t", m, Column("c", Integer))
+        t2 = Table("t2", m, Column("c_rem", Integer, key="c_remkey"))
 
-        fk = ForeignKeyConstraint(['c'], ['t2.c_rem'], link_to_name=True)
+        fk = ForeignKeyConstraint(["c"], ["t2.c_rem"], link_to_name=True)
         t1.append_constraint(fk)
 
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)),
-            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )"
+                    fk, self.autogen_context
+                ),
+            ),
+            "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
         )
 
     def test_render_fk_constraint_use_alter(self):
         m = MetaData()
-        Table('t', m, Column('c', Integer))
+        Table("t", m, Column("c", Integer))
         t2 = Table(
-            't2', m,
+            "t2",
+            m,
             Column(
-                'c_rem', Integer,
-                ForeignKey('t.c', name="fk1", use_alter=True)))
+                "c_rem", Integer, ForeignKey("t.c", name="fk1", use_alter=True)
+            ),
+        )
         const = list(t2.foreign_keys)[0].constraint
 
         eq_ignore_whitespace(
             autogenerate.render._render_constraint(
-                const, self.autogen_context),
+                const, self.autogen_context
+            ),
             "sa.ForeignKeyConstraint(['c_rem'], ['t.c'], "
-            "name='fk1', use_alter=True)"
+            "name='fk1', use_alter=True)",
         )
 
     def test_render_fk_constraint_w_metadata_schema(self):
         m = MetaData(schema="foo")
-        t1 = Table('t', m, Column('c', Integer))
-        t2 = Table('t2', m, Column('c_rem', Integer))
+        t1 = Table("t", m, Column("c", Integer))
+        t2 = Table("t2", m, Column("c_rem", Integer))
 
         fk = ForeignKeyConstraint([t1.c.c], [t2.c.c_rem], onupdate="CASCADE")
 
         eq_ignore_whitespace(
             re.sub(
-                r"u'", "'",
+                r"u'",
+                "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context)
+                    fk, self.autogen_context
+                ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['foo.t2.c_rem'], "
-            "onupdate='CASCADE')"
+            "onupdate='CASCADE')",
         )
 
     def test_render_check_constraint_literal(self):
         eq_ignore_whitespace(
             autogenerate.render._render_check_constraint(
-                CheckConstraint("im a constraint", name='cc1'),
-                self.autogen_context
+                CheckConstraint("im a constraint", name="cc1"),
+                self.autogen_context,
             ),
-            "sa.CheckConstraint(!U'im a constraint', name='cc1')"
+            "sa.CheckConstraint(!U'im a constraint', name='cc1')",
         )
 
     def test_render_check_constraint_sqlexpr(self):
-        c = column('c')
-        five = literal_column('5')
-        ten = literal_column('10')
+        c = column("c")
+        five = literal_column("5")
+        ten = literal_column("10")
         eq_ignore_whitespace(
             autogenerate.render._render_check_constraint(
-                CheckConstraint(and_(c > five, c < ten)),
-                self.autogen_context
+                CheckConstraint(and_(c > five, c < ten)), self.autogen_context
             ),
-            "sa.CheckConstraint(!U'c > 5 AND c < 10')"
+            "sa.CheckConstraint(!U'c > 5 AND c < 10')",
         )
 
     def test_render_check_constraint_literal_binds(self):
-        c = column('c')
+        c = column("c")
         eq_ignore_whitespace(
             autogenerate.render._render_check_constraint(
-                CheckConstraint(and_(c > 5, c < 10)),
-                self.autogen_context
+                CheckConstraint(and_(c > 5, c < 10)), self.autogen_context
             ),
-            "sa.CheckConstraint(!U'c > 5 AND c < 10')"
+            "sa.CheckConstraint(!U'c > 5 AND c < 10')",
         )
 
     def test_render_unique_constraint_opts(self):
         m = MetaData()
-        t = Table('t', m, Column('c', Integer))
+        t = Table("t", m, Column("c", Integer))
         eq_ignore_whitespace(
             autogenerate.render._render_unique_constraint(
-                UniqueConstraint(t.c.c, name='uq_1', deferrable='XYZ'),
-                self.autogen_context
+                UniqueConstraint(t.c.c, name="uq_1", deferrable="XYZ"),
+                self.autogen_context,
             ),
-            "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1')"
+            "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1')",
         )
 
     def test_add_unique_constraint_unicode_schema(self):
         m = MetaData()
         t = Table(
-            't', m, Column('c', Integer),
-            schema=compat.ue('\u0411\u0435\u0437')
+            "t",
+            m,
+            Column("c", Integer),
+            schema=compat.ue("\u0411\u0435\u0437"),
         )
         op_obj = ops.AddConstraintOp.from_constraint(UniqueConstraint(t.c.c))
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_unique_constraint(None, 't', ['c'], "
-            "schema=%r)" % compat.ue('\u0411\u0435\u0437')
+            "schema=%r)" % compat.ue("\u0411\u0435\u0437"),
         )
 
     def test_render_modify_nullable_w_default(self):
         op_obj = ops.AlterColumnOp(
-            "sometable", "somecolumn",
+            "sometable",
+            "somecolumn",
             existing_type=Integer(),
             existing_server_default="5",
-            modify_nullable=True
+            modify_nullable=True,
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.alter_column('sometable', 'somecolumn', "
             "existing_type=sa.Integer(), nullable=True, "
-            "existing_server_default='5')"
+            "existing_server_default='5')",
         )
 
     def test_render_enum(self):
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
                 Enum("one", "two", "three", name="myenum"),
-                self.autogen_context),
-            "sa.Enum('one', 'two', 'three', name='myenum')"
+                self.autogen_context,
+            ),
+            "sa.Enum('one', 'two', 'three', name='myenum')",
         )
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                Enum("one", "two", "three"),
-                self.autogen_context),
-            "sa.Enum('one', 'two', 'three')"
+                Enum("one", "two", "three"), self.autogen_context
+            ),
+            "sa.Enum('one', 'two', 'three')",
         )
 
     @config.requirements.sqlalchemy_099
@@ -1414,15 +1565,16 @@ class AutogenRenderTest(TestBase):
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
                 Enum("one", "two", "three", native_enum=False),
-                self.autogen_context),
-            "sa.Enum('one', 'two', 'three', native_enum=False)"
+                self.autogen_context,
+            ),
+            "sa.Enum('one', 'two', 'three', native_enum=False)",
         )
 
     def test_repr_plain_sqla_type(self):
         type_ = Integer()
         eq_ignore_whitespace(
             autogenerate.render._repr_type(type_, self.autogen_context),
-            "sa.Integer()"
+            "sa.Integer()",
         )
 
     @config.requirements.sqlalchemy_110
@@ -1430,24 +1582,27 @@ class AutogenRenderTest(TestBase):
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                types.ARRAY(Integer), self.autogen_context),
-            "sa.ARRAY(sa.Integer())"
+                types.ARRAY(Integer), self.autogen_context
+            ),
+            "sa.ARRAY(sa.Integer())",
         )
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                types.ARRAY(DateTime(timezone=True)), self.autogen_context),
-            "sa.ARRAY(sa.DateTime(timezone=True))"
+                types.ARRAY(DateTime(timezone=True)), self.autogen_context
+            ),
+            "sa.ARRAY(sa.DateTime(timezone=True))",
         )
 
     @config.requirements.sqlalchemy_110
     def test_render_array_no_context(self):
-        uo = ops.UpgradeOps(ops=[
-            ops.CreateTableOp(
-                "sometable",
-                [Column('x', types.ARRAY(Integer))]
-            )
-        ])
+        uo = ops.UpgradeOps(
+            ops=[
+                ops.CreateTableOp(
+                    "sometable", [Column("x", types.ARRAY(Integer))]
+                )
+            ]
+        )
 
         eq_(
             autogenerate.render_python_code(uo),
@@ -1455,11 +1610,11 @@ class AutogenRenderTest(TestBase):
             "    op.create_table('sometable',\n"
             "    sa.Column('x', sa.ARRAY(sa.Integer()), nullable=True)\n"
             "    )\n"
-            "    # ### end Alembic commands ###"
+            "    # ### end Alembic commands ###",
         )
 
     def test_repr_custom_type_w_sqla_prefix(self):
-        self.autogen_context.opts['user_module_prefix'] = None
+        self.autogen_context.opts["user_module_prefix"] = None
 
         class MyType(UserDefinedType):
             pass
@@ -1470,283 +1625,280 @@ class AutogenRenderTest(TestBase):
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(type_, self.autogen_context),
-            "sqlalchemy_util.types.MyType()"
+            "sqlalchemy_util.types.MyType()",
         )
 
     def test_repr_user_type_user_prefix_None(self):
         class MyType(UserDefinedType):
-
             def get_col_spec(self):
                 return "MYTYPE"
 
         type_ = MyType()
-        self.autogen_context.opts['user_module_prefix'] = None
+        self.autogen_context.opts["user_module_prefix"] = None
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(type_, self.autogen_context),
-            "tests.test_autogen_render.MyType()"
+            "tests.test_autogen_render.MyType()",
         )
 
     def test_repr_user_type_user_prefix_present(self):
         from sqlalchemy.types import UserDefinedType
 
         class MyType(UserDefinedType):
-
             def get_col_spec(self):
                 return "MYTYPE"
 
         type_ = MyType()
 
-        self.autogen_context.opts['user_module_prefix'] = 'user.'
+        self.autogen_context.opts["user_module_prefix"] = "user."
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(type_, self.autogen_context),
-            "user.MyType()"
+            "user.MyType()",
         )
 
     def test_repr_dialect_type(self):
         from sqlalchemy.dialects.mysql import VARCHAR
 
-        type_ = VARCHAR(20, charset='utf8', national=True)
+        type_ = VARCHAR(20, charset="utf8", national=True)
 
-        self.autogen_context.opts['user_module_prefix'] = None
+        self.autogen_context.opts["user_module_prefix"] = None
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(type_, self.autogen_context),
-            "mysql.VARCHAR(charset='utf8', national=True, length=20)"
+            "mysql.VARCHAR(charset='utf8', national=True, length=20)",
+        )
+        eq_(
+            self.autogen_context.imports,
+            set(["from sqlalchemy.dialects import mysql"]),
         )
-        eq_(self.autogen_context.imports,
-            set(['from sqlalchemy.dialects import mysql'])
-            )
 
     def test_render_server_default_text(self):
         c = Column(
-            'updated_at', TIMESTAMP(),
-            server_default=text('now()'),
-            nullable=False)
-        result = autogenerate.render._render_column(
-            c, self.autogen_context
+            "updated_at",
+            TIMESTAMP(),
+            server_default=text("now()"),
+            nullable=False,
         )
+        result = autogenerate.render._render_column(c, self.autogen_context)
         eq_ignore_whitespace(
             result,
-            'sa.Column(\'updated_at\', sa.TIMESTAMP(), '
-            'server_default=sa.text(!U\'now()\'), '
-            'nullable=False)'
+            "sa.Column('updated_at', sa.TIMESTAMP(), "
+            "server_default=sa.text(!U'now()'), "
+            "nullable=False)",
         )
 
     def test_render_server_default_non_native_boolean(self):
         c = Column(
-            'updated_at', Boolean(),
-            server_default=false(),
-            nullable=False)
-
-        result = autogenerate.render._render_column(
-            c, self.autogen_context
+            "updated_at", Boolean(), server_default=false(), nullable=False
         )
+
+        result = autogenerate.render._render_column(c, self.autogen_context)
         eq_ignore_whitespace(
             result,
-            'sa.Column(\'updated_at\', sa.Boolean(), '
-            'server_default=sa.text(!U\'0\'), '
-            'nullable=False)'
+            "sa.Column('updated_at', sa.Boolean(), "
+            "server_default=sa.text(!U'0'), "
+            "nullable=False)",
         )
 
     def test_render_server_default_func(self):
         c = Column(
-            'updated_at', TIMESTAMP(),
+            "updated_at",
+            TIMESTAMP(),
             server_default=func.now(),
-            nullable=False)
-        result = autogenerate.render._render_column(
-            c, self.autogen_context
+            nullable=False,
         )
+        result = autogenerate.render._render_column(c, self.autogen_context)
         eq_ignore_whitespace(
             result,
-            'sa.Column(\'updated_at\', sa.TIMESTAMP(), '
-            'server_default=sa.text(!U\'now()\'), '
-            'nullable=False)'
+            "sa.Column('updated_at', sa.TIMESTAMP(), "
+            "server_default=sa.text(!U'now()'), "
+            "nullable=False)",
         )
 
     def test_render_server_default_int(self):
-        c = Column(
-            'value', Integer,
-            server_default="0")
-        result = autogenerate.render._render_column(
-            c, self.autogen_context
-        )
+        c = Column("value", Integer, server_default="0")
+        result = autogenerate.render._render_column(c, self.autogen_context)
         eq_(
             result,
             "sa.Column('value', sa.Integer(), "
-            "server_default='0', nullable=True)"
+            "server_default='0', nullable=True)",
         )
 
     def test_render_modify_reflected_int_server_default(self):
         op_obj = ops.AlterColumnOp(
-            "sometable", "somecolumn",
+            "sometable",
+            "somecolumn",
             existing_type=Integer(),
             existing_server_default=DefaultClause(text("5")),
-            modify_nullable=True
+            modify_nullable=True,
         )
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.alter_column('sometable', 'somecolumn', "
             "existing_type=sa.Integer(), nullable=True, "
-            "existing_server_default=sa.text(!U'5'))"
+            "existing_server_default=sa.text(!U'5'))",
         )
 
     def test_render_executesql_plaintext(self):
         op_obj = ops.ExecuteSQLOp("drop table foo")
         eq_(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.execute('drop table foo')"
+            "op.execute('drop table foo')",
         )
 
     def test_render_executesql_sqlexpr_notimplemented(self):
-        sql = table('x', column('q')).insert()
+        sql = table("x", column("q")).insert()
         op_obj = ops.ExecuteSQLOp(sql)
         assert_raises(
             NotImplementedError,
-            autogenerate.render_op_text, self.autogen_context, op_obj
+            autogenerate.render_op_text,
+            self.autogen_context,
+            op_obj,
         )
 
 
 class RenderNamingConventionTest(TestBase):
-    __requires__ = ('sqlalchemy_094',)
+    __requires__ = ("sqlalchemy_094",)
 
     def setUp(self):
 
         convention = {
-            "ix": 'ix_%(custom)s_%(column_0_label)s',
+            "ix": "ix_%(custom)s_%(column_0_label)s",
             "uq": "uq_%(custom)s_%(table_name)s_%(column_0_name)s",
             "ck": "ck_%(custom)s_%(table_name)s",
             "fk": "fk_%(custom)s_%(table_name)s_"
             "%(column_0_name)s_%(referred_table_name)s",
             "pk": "pk_%(custom)s_%(table_name)s",
-            "custom": lambda const, table: "ct"
+            "custom": lambda const, table: "ct",
         }
 
-        self.metadata = MetaData(
-            naming_convention=convention
-        )
+        self.metadata = MetaData(naming_convention=convention)
 
         ctx_opts = {
-            'sqlalchemy_module_prefix': 'sa.',
-            'alembic_module_prefix': 'op.',
-            'target_metadata': MetaData()
+            "sqlalchemy_module_prefix": "sa.",
+            "alembic_module_prefix": "op.",
+            "target_metadata": MetaData(),
         }
         context = MigrationContext.configure(
-            dialect_name="postgresql",
-            opts=ctx_opts
+            dialect_name="postgresql", opts=ctx_opts
         )
         self.autogen_context = api.AutogenContext(context)
 
     def test_schema_type_boolean(self):
-        t = Table('t', self.metadata, Column('c', Boolean(name='xyz')))
+        t = Table("t", self.metadata, Column("c", Boolean(name="xyz")))
         op_obj = ops.AddColumnOp.from_column(t.c.c)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.add_column('t', "
-            "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))"
+            "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))",
         )
 
     def test_explicit_unique_constraint(self):
-        t = Table('t', self.metadata, Column('c', Integer))
+        t = Table("t", self.metadata, Column("c", Integer))
         eq_ignore_whitespace(
             autogenerate.render._render_unique_constraint(
-                UniqueConstraint(t.c.c, deferrable='XYZ'),
-                self.autogen_context
+                UniqueConstraint(t.c.c, deferrable="XYZ"), self.autogen_context
             ),
             "sa.UniqueConstraint('c', deferrable='XYZ', "
-            "name=op.f('uq_ct_t_c'))"
+            "name=op.f('uq_ct_t_c'))",
         )
 
     def test_explicit_named_unique_constraint(self):
-        t = Table('t', self.metadata, Column('c', Integer))
+        t = Table("t", self.metadata, Column("c", Integer))
         eq_ignore_whitespace(
             autogenerate.render._render_unique_constraint(
-                UniqueConstraint(t.c.c, name='q'),
-                self.autogen_context
+                UniqueConstraint(t.c.c, name="q"), self.autogen_context
             ),
-            "sa.UniqueConstraint('c', name='q')"
+            "sa.UniqueConstraint('c', name='q')",
         )
 
     def test_render_add_index(self):
-        t = Table('test', self.metadata,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
+        t = Table(
+            "test",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
         idx = Index(None, t.c.active, t.c.code)
         op_obj = ops.CreateIndexOp.from_index(idx)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index(op.f('ix_ct_test_active'), 'test', "
-            "['active', 'code'], unique=False)"
+            "['active', 'code'], unique=False)",
         )
 
     def test_render_drop_index(self):
-        t = Table('test', self.metadata,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  )
+        t = Table(
+            "test",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+        )
         idx = Index(None, t.c.active, t.c.code)
         op_obj = ops.DropIndexOp.from_index(idx)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
-            "op.drop_index(op.f('ix_ct_test_active'), table_name='test')"
+            "op.drop_index(op.f('ix_ct_test_active'), table_name='test')",
         )
 
     def test_render_add_index_schema(self):
-        t = Table('test', self.metadata,
-                  Column('id', Integer, primary_key=True),
-                  Column('active', Boolean()),
-                  Column('code', String(255)),
-                  schema='CamelSchema'
-                  )
+        t = Table(
+            "test",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("active", Boolean()),
+            Column("code", String(255)),
+            schema="CamelSchema",
+        )
         idx = Index(None, t.c.active, t.c.code)
         op_obj = ops.CreateIndexOp.from_index(idx)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index(op.f('ix_ct_CamelSchema_test_active'), 'test', "
-            "['active', 'code'], unique=False, schema='CamelSchema')"
+            "['active', 'code'], unique=False, schema='CamelSchema')",
         )
 
     def test_implicit_unique_constraint(self):
-        t = Table('t', self.metadata, Column('c', Integer, unique=True))
+        t = Table("t", self.metadata, Column("c", Integer, unique=True))
         uq = [c for c in t.constraints if isinstance(c, UniqueConstraint)][0]
         eq_ignore_whitespace(
-            autogenerate.render._render_unique_constraint(uq,
-                                                          self.autogen_context
-                                                          ),
-            "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))"
+            autogenerate.render._render_unique_constraint(
+                uq, self.autogen_context
+            ),
+            "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))",
         )
 
     def test_inline_pk_constraint(self):
-        t = Table('t', self.metadata, Column('c', Integer, primary_key=True))
+        t = Table("t", self.metadata, Column("c", Integer, primary_key=True))
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('t',sa.Column('c', sa.Integer(), nullable=False),"
-            "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))"
+            "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))",
         )
 
     def test_inline_ck_constraint(self):
         t = Table(
-            't', self.metadata, Column('c', Integer), CheckConstraint("c > 5"))
+            "t", self.metadata, Column("c", Integer), CheckConstraint("c > 5")
+        )
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('t',sa.Column('c', sa.Integer(), nullable=True),"
-            "sa.CheckConstraint(!U'c > 5', name=op.f('ck_ct_t')))"
+            "sa.CheckConstraint(!U'c > 5', name=op.f('ck_ct_t')))",
         )
 
     def test_inline_fk(self):
-        t = Table('t', self.metadata, Column('c', Integer, ForeignKey('q.id')))
+        t = Table("t", self.metadata, Column("c", Integer, ForeignKey("q.id")))
         op_obj = ops.CreateTableOp.from_table(t)
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_table('t',sa.Column('c', sa.Integer(), nullable=True),"
             "sa.ForeignKeyConstraint(['c'], ['q.id'], "
-            "name=op.f('fk_ct_t_c_q')))"
+            "name=op.f('fk_ct_t_c_q')))",
         )
 
     def test_render_check_constraint_renamed(self):
@@ -1760,31 +1912,31 @@ class RenderNamingConventionTest(TestBase):
         used.
 
         """
-        m1 = MetaData(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"})
+        m1 = MetaData(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
         ck = CheckConstraint("im a constraint", name="cc1")
-        Table('t', m1, Column('x'), ck)
+        Table("t", m1, Column("x"), ck)
 
         eq_ignore_whitespace(
             autogenerate.render._render_check_constraint(
-                ck,
-                self.autogen_context
+                ck, self.autogen_context
             ),
-            "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))"
+            "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))",
         )
 
     def test_create_table_plus_add_index_in_modify(self):
-        uo = ops.UpgradeOps(ops=[
-            ops.CreateTableOp(
-                "sometable",
-                [Column('x', Integer), Column('y', Integer)]
-            ),
-            ops.ModifyTableOps(
-                "sometable", ops=[
-                    ops.CreateIndexOp('ix1', 'sometable', ['x', 'y'])
-                ]
-            )
-        ])
+        uo = ops.UpgradeOps(
+            ops=[
+                ops.CreateTableOp(
+                    "sometable", [Column("x", Integer), Column("y", Integer)]
+                ),
+                ops.ModifyTableOps(
+                    "sometable",
+                    ops=[ops.CreateIndexOp("ix1", "sometable", ["x", "y"])],
+                ),
+            ]
+        )
 
         eq_(
             autogenerate.render_python_code(uo, render_as_batch=True),
@@ -1797,5 +1949,5 @@ class RenderNamingConventionTest(TestBase):
             "as batch_op:\n"
             "        batch_op.create_index("
             "'ix1', ['x', 'y'], unique=False)\n\n"
-            "    # ### end Alembic commands ###"
+            "    # ### end Alembic commands ###",
         )
index 99605d0650e0a69f120acb4f915ab5299a449490..2a7c52eff48873d9d5fb478efa0302ef625627b7 100644 (file)
@@ -11,9 +11,22 @@ from alembic.operations.batch import ApplyBatchImpl
 from alembic.runtime.migration import MigrationContext
 
 
-from sqlalchemy import Integer, Table, Column, String, MetaData, ForeignKey, \
-    UniqueConstraint, ForeignKeyConstraint, Index, Boolean, CheckConstraint, \
-    Enum, DateTime, PrimaryKeyConstraint
+from sqlalchemy import (
+    Integer,
+    Table,
+    Column,
+    String,
+    MetaData,
+    ForeignKey,
+    UniqueConstraint,
+    ForeignKeyConstraint,
+    Index,
+    Boolean,
+    CheckConstraint,
+    Enum,
+    DateTime,
+    PrimaryKeyConstraint,
+)
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.sql import column, text, select
 from sqlalchemy.schema import CreateTable, CreateIndex
@@ -21,84 +34,91 @@ from sqlalchemy import exc
 
 
 class BatchApplyTest(TestBase):
-
     def setUp(self):
         self.op = Operations(mock.Mock(opts={}))
 
     def _simple_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('x', String(10)),
-            Column('y', Integer)
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("x", String(10)),
+            Column("y", Integer),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _uq_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('x', String()),
-            Column('y', Integer),
-            UniqueConstraint('y', name='uq1')
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("x", String()),
+            Column("y", Integer),
+            UniqueConstraint("y", name="uq1"),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _ix_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('x', String()),
-            Column('y', Integer),
-            Index('ix1', 'y')
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("x", String()),
+            Column("y", Integer),
+            Index("ix1", "y"),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _pk_fixture(self):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer),
-            Column('x', String()),
-            Column('y', Integer),
-            PrimaryKeyConstraint('id', name="mypk")
+            "tname",
+            m,
+            Column("id", Integer),
+            Column("x", String()),
+            Column("y", Integer),
+            PrimaryKeyConstraint("id", name="mypk"),
         )
         return ApplyBatchImpl(t, (), {}, False)
 
     def _literal_ck_fixture(
-            self, copy_from=None, table_args=(), table_kwargs={}):
+        self, copy_from=None, table_args=(), table_kwargs={}
+    ):
         m = MetaData()
         if copy_from is not None:
             t = copy_from
         else:
             t = Table(
-                'tname', m,
-                Column('id', Integer, primary_key=True),
-                Column('email', String()),
-                CheckConstraint("email LIKE '%@%'")
+                "tname",
+                m,
+                Column("id", Integer, primary_key=True),
+                Column("email", String()),
+                CheckConstraint("email LIKE '%@%'"),
             )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _sql_ck_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('email', String())
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("email", String()),
         )
-        t.append_constraint(CheckConstraint(t.c.email.like('%@%')))
+        t.append_constraint(CheckConstraint(t.c.email.like("%@%")))
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _fk_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('email', String()),
-            Column('user_id', Integer, ForeignKey('user.id'))
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("email", String()),
+            Column("user_id", Integer, ForeignKey("user.id")),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
@@ -110,93 +130,108 @@ class BatchApplyTest(TestBase):
             schemaarg = ""
 
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('email', String()),
-            Column('user_id_1', Integer, ForeignKey('%suser.id' % schemaarg)),
-            Column('user_id_2', Integer, ForeignKey('%suser.id' % schemaarg)),
-            Column('user_id_3', Integer),
-            Column('user_id_version', Integer),
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("email", String()),
+            Column("user_id_1", Integer, ForeignKey("%suser.id" % schemaarg)),
+            Column("user_id_2", Integer, ForeignKey("%suser.id" % schemaarg)),
+            Column("user_id_3", Integer),
+            Column("user_id_version", Integer),
             ForeignKeyConstraint(
-                ['user_id_3', 'user_id_version'],
-                ['%suser.id' % schemaarg, '%suser.id_version' % schemaarg]),
-            schema=schema
+                ["user_id_3", "user_id_version"],
+                ["%suser.id" % schemaarg, "%suser.id_version" % schemaarg],
+            ),
+            schema=schema,
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _named_fk_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('email', String()),
-            Column('user_id', Integer, ForeignKey('user.id', name='ufk'))
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("email", String()),
+            Column("user_id", Integer, ForeignKey("user.id", name="ufk")),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _selfref_fk_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('parent_id', Integer, ForeignKey('tname.id')),
-            Column('data', String)
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("parent_id", Integer, ForeignKey("tname.id")),
+            Column("data", String),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _boolean_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('flag', Boolean)
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("flag", Boolean),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _boolean_no_ck_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('flag', Boolean(create_constraint=False))
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("flag", Boolean(create_constraint=False)),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _enum_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('thing', Enum('a', 'b', 'c'))
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("thing", Enum("a", "b", "c")),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
     def _server_default_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
         t = Table(
-            'tname', m,
-            Column('id', Integer, primary_key=True),
-            Column('thing', String(), server_default='')
+            "tname",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("thing", String(), server_default=""),
         )
         return ApplyBatchImpl(t, table_args, table_kwargs, False)
 
-    def _assert_impl(self, impl, colnames=None,
-                     ddl_contains=None, ddl_not_contains=None,
-                     dialect='default', schema=None):
+    def _assert_impl(
+        self,
+        impl,
+        colnames=None,
+        ddl_contains=None,
+        ddl_not_contains=None,
+        dialect="default",
+        schema=None,
+    ):
         context = op_fixture(dialect=dialect)
 
         impl._create(context.impl)
 
         if colnames is None:
-            colnames = ['id', 'x', 'y']
+            colnames = ["id", "x", "y"]
         eq_(impl.new_table.c.keys(), colnames)
 
         pk_cols = [col for col in impl.new_table.c if col.primary_key]
         eq_(list(impl.new_table.primary_key), pk_cols)
 
         create_stmt = str(
-            CreateTable(impl.new_table).compile(dialect=context.dialect))
-        create_stmt = re.sub(r'[\n\t]', '', create_stmt)
+            CreateTable(impl.new_table).compile(dialect=context.dialect)
+        )
+        create_stmt = re.sub(r"[\n\t]", "", create_stmt)
 
         idx_stmt = ""
         for idx in impl.indexes.values():
@@ -205,17 +240,16 @@ class BatchApplyTest(TestBase):
             impl.new_table.name = impl.table.name
             idx_stmt += str(CreateIndex(idx).compile(dialect=context.dialect))
             impl.new_table.name = ApplyBatchImpl._calc_temp_name(
-                impl.table.name)
-        idx_stmt = re.sub(r'[\n\t]', '', idx_stmt)
+                impl.table.name
+            )
+        idx_stmt = re.sub(r"[\n\t]", "", idx_stmt)
 
         if ddl_contains:
             assert ddl_contains in create_stmt + idx_stmt
         if ddl_not_contains:
             assert ddl_not_contains not in create_stmt + idx_stmt
 
-        expected = [
-            create_stmt,
-        ]
+        expected = [create_stmt]
 
         if schema:
             args = {"schema": "%s." % schema}
@@ -224,32 +258,40 @@ class BatchApplyTest(TestBase):
 
         args["temp_name"] = impl.new_table.name
 
-        args['colnames'] = ", ".join([
-            impl.new_table.c[name].name
-            for name in colnames
-            if name in impl.table.c])
+        args["colnames"] = ", ".join(
+            [
+                impl.new_table.c[name].name
+                for name in colnames
+                if name in impl.table.c
+            ]
+        )
 
-        args['tname_colnames'] = ", ".join(
-            "CAST(%(schema)stname.%(name)s AS %(type)s) AS anon_1" % {
-                'schema': args['schema'],
-                'name': name,
-                'type': impl.new_table.c[name].type
+        args["tname_colnames"] = ", ".join(
+            "CAST(%(schema)stname.%(name)s AS %(type)s) AS anon_1"
+            % {
+                "schema": args["schema"],
+                "name": name,
+                "type": impl.new_table.c[name].type,
             }
             if (
                 impl.new_table.c[name].type._type_affinity
-                is not impl.table.c[name].type._type_affinity)
-            else "%(schema)stname.%(name)s" % {
-                'schema': args['schema'], 'name': name}
-            for name in colnames if name in impl.table.c
-        )
-
-        expected.extend([
-            'INSERT INTO %(schema)s%(temp_name)s (%(colnames)s) '
-            'SELECT %(tname_colnames)s FROM %(schema)stname' % args,
-            'DROP TABLE %(schema)stname' % args,
-            'ALTER TABLE %(schema)s%(temp_name)s '
-            'RENAME TO %(schema)stname' % args
-        ])
+                is not impl.table.c[name].type._type_affinity
+            )
+            else "%(schema)stname.%(name)s"
+            % {"schema": args["schema"], "name": name}
+            for name in colnames
+            if name in impl.table.c
+        )
+
+        expected.extend(
+            [
+                "INSERT INTO %(schema)s%(temp_name)s (%(colnames)s) "
+                "SELECT %(tname_colnames)s FROM %(schema)stname" % args,
+                "DROP TABLE %(schema)stname" % args,
+                "ALTER TABLE %(schema)s%(temp_name)s "
+                "RENAME TO %(schema)stname" % args,
+            ]
+        )
         if idx_stmt:
             expected.append(idx_stmt)
         context.assert_(*expected)
@@ -257,36 +299,42 @@ class BatchApplyTest(TestBase):
 
     def test_change_type(self):
         impl = self._simple_fixture()
-        impl.alter_column('tname', 'x', type_=String)
+        impl.alter_column("tname", "x", type_=String)
         new_table = self._assert_impl(impl)
         assert new_table.c.x.type._type_affinity is String
 
     def test_rename_col(self):
         impl = self._simple_fixture()
-        impl.alter_column('tname', 'x', name='q')
+        impl.alter_column("tname", "x", name="q")
         new_table = self._assert_impl(impl)
-        eq_(new_table.c.x.name, 'q')
+        eq_(new_table.c.x.name, "q")
 
     def test_rename_col_boolean(self):
         impl = self._boolean_fixture()
-        impl.alter_column('tname', 'flag', name='bflag')
+        impl.alter_column("tname", "flag", name="bflag")
         new_table = self._assert_impl(
-            impl, ddl_contains="CHECK (bflag IN (0, 1)",
-            colnames=["id", "flag"])
-        eq_(new_table.c.flag.name, 'bflag')
+            impl,
+            ddl_contains="CHECK (bflag IN (0, 1)",
+            colnames=["id", "flag"],
+        )
+        eq_(new_table.c.flag.name, "bflag")
         eq_(
-            len([
-                const for const
-                in new_table.constraints
-                if isinstance(const, CheckConstraint)]),
-            1)
+            len(
+                [
+                    const
+                    for const in new_table.constraints
+                    if isinstance(const, CheckConstraint)
+                ]
+            ),
+            1,
+        )
 
     def test_change_type_schematype_to_non(self):
         impl = self._boolean_fixture()
-        impl.alter_column('tname', 'flag', type_=Integer)
+        impl.alter_column("tname", "flag", type_=Integer)
         new_table = self._assert_impl(
-            impl, colnames=['id', 'flag'],
-            ddl_not_contains="CHECK")
+            impl, colnames=["id", "flag"], ddl_not_contains="CHECK"
+        )
         assert new_table.c.flag.type._type_affinity is Integer
 
         # NOTE: we can't do test_change_type_non_to_schematype
@@ -295,254 +343,310 @@ class BatchApplyTest(TestBase):
 
     def test_rename_col_boolean_no_ck(self):
         impl = self._boolean_no_ck_fixture()
-        impl.alter_column('tname', 'flag', name='bflag')
+        impl.alter_column("tname", "flag", name="bflag")
         new_table = self._assert_impl(
-            impl, ddl_not_contains="CHECK",
-            colnames=["id", "flag"])
-        eq_(new_table.c.flag.name, 'bflag')
+            impl, ddl_not_contains="CHECK", colnames=["id", "flag"]
+        )
+        eq_(new_table.c.flag.name, "bflag")
         eq_(
-            len([
-                const for const
-                in new_table.constraints
-                if isinstance(const, CheckConstraint)]),
-            0)
+            len(
+                [
+                    const
+                    for const in new_table.constraints
+                    if isinstance(const, CheckConstraint)
+                ]
+            ),
+            0,
+        )
 
     def test_rename_col_enum(self):
         impl = self._enum_fixture()
-        impl.alter_column('tname', 'thing', name='thang')
+        impl.alter_column("tname", "thing", name="thang")
         new_table = self._assert_impl(
-            impl, ddl_contains="CHECK (thang IN ('a', 'b', 'c')",
-            colnames=["id", "thing"])
-        eq_(new_table.c.thing.name, 'thang')
+            impl,
+            ddl_contains="CHECK (thang IN ('a', 'b', 'c')",
+            colnames=["id", "thing"],
+        )
+        eq_(new_table.c.thing.name, "thang")
         eq_(
-            len([
-                const for const
-                in new_table.constraints
-                if isinstance(const, CheckConstraint)]),
-            1)
+            len(
+                [
+                    const
+                    for const in new_table.constraints
+                    if isinstance(const, CheckConstraint)
+                ]
+            ),
+            1,
+        )
 
     def test_rename_col_literal_ck(self):
         impl = self._literal_ck_fixture()
-        impl.alter_column('tname', 'email', name='emol')
+        impl.alter_column("tname", "email", name="emol")
         new_table = self._assert_impl(
             # note this is wrong, we don't dig into the SQL
-            impl, ddl_contains="CHECK (email LIKE '%@%')",
-            colnames=["id", "email"])
+            impl,
+            ddl_contains="CHECK (email LIKE '%@%')",
+            colnames=["id", "email"],
+        )
         eq_(
-            len([c for c in new_table.constraints
-                if isinstance(c, CheckConstraint)]), 1)
+            len(
+                [
+                    c
+                    for c in new_table.constraints
+                    if isinstance(c, CheckConstraint)
+                ]
+            ),
+            1,
+        )
 
-        eq_(new_table.c.email.name, 'emol')
+        eq_(new_table.c.email.name, "emol")
 
     def test_rename_col_literal_ck_workaround(self):
         impl = self._literal_ck_fixture(
             copy_from=Table(
-                'tname', MetaData(),
-                Column('id', Integer, primary_key=True),
-                Column('email', String),
+                "tname",
+                MetaData(),
+                Column("id", Integer, primary_key=True),
+                Column("email", String),
             ),
-            table_args=[CheckConstraint("emol LIKE '%@%'")])
+            table_args=[CheckConstraint("emol LIKE '%@%'")],
+        )
 
-        impl.alter_column('tname', 'email', name='emol')
+        impl.alter_column("tname", "email", name="emol")
         new_table = self._assert_impl(
-            impl, ddl_contains="CHECK (emol LIKE '%@%')",
-            colnames=["id", "email"])
+            impl,
+            ddl_contains="CHECK (emol LIKE '%@%')",
+            colnames=["id", "email"],
+        )
         eq_(
-            len([c for c in new_table.constraints
-                if isinstance(c, CheckConstraint)]), 1)
-        eq_(new_table.c.email.name, 'emol')
+            len(
+                [
+                    c
+                    for c in new_table.constraints
+                    if isinstance(c, CheckConstraint)
+                ]
+            ),
+            1,
+        )
+        eq_(new_table.c.email.name, "emol")
 
     def test_rename_col_sql_ck(self):
         impl = self._sql_ck_fixture()
 
-        impl.alter_column('tname', 'email', name='emol')
+        impl.alter_column("tname", "email", name="emol")
         new_table = self._assert_impl(
-            impl, ddl_contains="CHECK (emol LIKE '%@%')",
-            colnames=["id", "email"])
+            impl,
+            ddl_contains="CHECK (emol LIKE '%@%')",
+            colnames=["id", "email"],
+        )
         eq_(
-            len([c for c in new_table.constraints
-                if isinstance(c, CheckConstraint)]), 1)
+            len(
+                [
+                    c
+                    for c in new_table.constraints
+                    if isinstance(c, CheckConstraint)
+                ]
+            ),
+            1,
+        )
 
-        eq_(new_table.c.email.name, 'emol')
+        eq_(new_table.c.email.name, "emol")
 
     def test_add_col(self):
         impl = self._simple_fixture()
-        col = Column('g', Integer)
+        col = Column("g", Integer)
         # operations.add_column produces a table
-        t = self.op.schema_obj.table('tname', col)  # noqa
-        impl.add_column('tname', col)
-        new_table = self._assert_impl(impl, colnames=['id', 'x', 'y', 'g'])
-        eq_(new_table.c.g.name, 'g')
+        t = self.op.schema_obj.table("tname", col)  # noqa
+        impl.add_column("tname", col)
+        new_table = self._assert_impl(impl, colnames=["id", "x", "y", "g"])
+        eq_(new_table.c.g.name, "g")
 
     def test_add_server_default(self):
         impl = self._simple_fixture()
-        impl.alter_column('tname', 'y', server_default="10")
-        new_table = self._assert_impl(
-            impl, ddl_contains="DEFAULT '10'")
-        eq_(
-            new_table.c.y.server_default.arg, "10"
-        )
+        impl.alter_column("tname", "y", server_default="10")
+        new_table = self._assert_impl(impl, ddl_contains="DEFAULT '10'")
+        eq_(new_table.c.y.server_default.arg, "10")
 
     def test_drop_server_default(self):
         impl = self._server_default_fixture()
-        impl.alter_column('tname', 'thing', server_default=None)
+        impl.alter_column("tname", "thing", server_default=None)
         new_table = self._assert_impl(
-            impl, colnames=['id', 'thing'], ddl_not_contains="DEFAULT")
+            impl, colnames=["id", "thing"], ddl_not_contains="DEFAULT"
+        )
         eq_(new_table.c.thing.server_default, None)
 
     def test_rename_col_pk(self):
         impl = self._simple_fixture()
-        impl.alter_column('tname', 'id', name='foobar')
+        impl.alter_column("tname", "id", name="foobar")
         new_table = self._assert_impl(
-            impl, ddl_contains="PRIMARY KEY (foobar)")
-        eq_(new_table.c.id.name, 'foobar')
+            impl, ddl_contains="PRIMARY KEY (foobar)"
+        )
+        eq_(new_table.c.id.name, "foobar")
         eq_(list(new_table.primary_key), [new_table.c.id])
 
     def test_rename_col_fk(self):
         impl = self._fk_fixture()
-        impl.alter_column('tname', 'user_id', name='foobar')
+        impl.alter_column("tname", "user_id", name="foobar")
         new_table = self._assert_impl(
-            impl, colnames=['id', 'email', 'user_id'],
-            ddl_contains='FOREIGN KEY(foobar) REFERENCES "user" (id)')
-        eq_(new_table.c.user_id.name, 'foobar')
+            impl,
+            colnames=["id", "email", "user_id"],
+            ddl_contains='FOREIGN KEY(foobar) REFERENCES "user" (id)',
+        )
+        eq_(new_table.c.user_id.name, "foobar")
         eq_(
-            list(new_table.c.user_id.foreign_keys)[0]._get_colspec(),
-            "user.id"
+            list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), "user.id"
         )
 
     def test_regen_multi_fk(self):
         impl = self._multi_fk_fixture()
         self._assert_impl(
-            impl, colnames=[
-                'id', 'email', 'user_id_1', 'user_id_2',
-                'user_id_3', 'user_id_version'],
-            ddl_contains='FOREIGN KEY(user_id_3, user_id_version) '
-            'REFERENCES "user" (id, id_version)')
+            impl,
+            colnames=[
+                "id",
+                "email",
+                "user_id_1",
+                "user_id_2",
+                "user_id_3",
+                "user_id_version",
+            ],
+            ddl_contains="FOREIGN KEY(user_id_3, user_id_version) "
+            'REFERENCES "user" (id, id_version)',
+        )
 
     def test_regen_multi_fk_schema(self):
-        impl = self._multi_fk_fixture(schema='foo_schema')
+        impl = self._multi_fk_fixture(schema="foo_schema")
         self._assert_impl(
-            impl, colnames=[
-                'id', 'email', 'user_id_1', 'user_id_2',
-                'user_id_3', 'user_id_version'],
-            ddl_contains='FOREIGN KEY(user_id_3, user_id_version) '
+            impl,
+            colnames=[
+                "id",
+                "email",
+                "user_id_1",
+                "user_id_2",
+                "user_id_3",
+                "user_id_version",
+            ],
+            ddl_contains="FOREIGN KEY(user_id_3, user_id_version) "
             'REFERENCES foo_schema."user" (id, id_version)',
-            schema='foo_schema')
+            schema="foo_schema",
+        )
 
     def test_drop_col(self):
         impl = self._simple_fixture()
-        impl.drop_column('tname', column('x'))
-        new_table = self._assert_impl(impl, colnames=['id', 'y'])
-        assert 'y' in new_table.c
-        assert 'x' not in new_table.c
+        impl.drop_column("tname", column("x"))
+        new_table = self._assert_impl(impl, colnames=["id", "y"])
+        assert "y" in new_table.c
+        assert "x" not in new_table.c
 
     def test_drop_col_remove_pk(self):
         impl = self._simple_fixture()
-        impl.drop_column('tname', column('id'))
+        impl.drop_column("tname", column("id"))
         new_table = self._assert_impl(
-            impl, colnames=['x', 'y'], ddl_not_contains="PRIMARY KEY")
-        assert 'y' in new_table.c
-        assert 'id' not in new_table.c
+            impl, colnames=["x", "y"], ddl_not_contains="PRIMARY KEY"
+        )
+        assert "y" in new_table.c
+        assert "id" not in new_table.c
         assert not new_table.primary_key
 
     def test_drop_col_remove_fk(self):
         impl = self._fk_fixture()
-        impl.drop_column('tname', column('user_id'))
+        impl.drop_column("tname", column("user_id"))
         new_table = self._assert_impl(
-            impl, colnames=['id', 'email'], ddl_not_contains="FOREIGN KEY")
-        assert 'user_id' not in new_table.c
+            impl, colnames=["id", "email"], ddl_not_contains="FOREIGN KEY"
+        )
+        assert "user_id" not in new_table.c
         assert not new_table.foreign_keys
 
     def test_drop_col_retain_fk(self):
         impl = self._fk_fixture()
-        impl.drop_column('tname', column('email'))
+        impl.drop_column("tname", column("email"))
         new_table = self._assert_impl(
-            impl, colnames=['id', 'user_id'],
-            ddl_contains='FOREIGN KEY(user_id) REFERENCES "user" (id)')
-        assert 'email' not in new_table.c
+            impl,
+            colnames=["id", "user_id"],
+            ddl_contains='FOREIGN KEY(user_id) REFERENCES "user" (id)',
+        )
+        assert "email" not in new_table.c
         assert new_table.c.user_id.foreign_keys
 
     def test_drop_col_retain_fk_selfref(self):
         impl = self._selfref_fk_fixture()
-        impl.drop_column('tname', column('data'))
-        new_table = self._assert_impl(impl, colnames=['id', 'parent_id'])
-        assert 'data' not in new_table.c
+        impl.drop_column("tname", column("data"))
+        new_table = self._assert_impl(impl, colnames=["id", "parent_id"])
+        assert "data" not in new_table.c
         assert new_table.c.parent_id.foreign_keys
 
     def test_add_fk(self):
         impl = self._simple_fixture()
-        impl.add_column('tname', Column('user_id', Integer))
+        impl.add_column("tname", Column("user_id", Integer))
         fk = self.op.schema_obj.foreign_key_constraint(
-            'fk1', 'tname', 'user',
-            ['user_id'], ['id'])
+            "fk1", "tname", "user", ["user_id"], ["id"]
+        )
         impl.add_constraint(fk)
         new_table = self._assert_impl(
-            impl, colnames=['id', 'x', 'y', 'user_id'],
-            ddl_contains='CONSTRAINT fk1 FOREIGN KEY(user_id) '
-            'REFERENCES "user" (id)')
+            impl,
+            colnames=["id", "x", "y", "user_id"],
+            ddl_contains="CONSTRAINT fk1 FOREIGN KEY(user_id) "
+            'REFERENCES "user" (id)',
+        )
         eq_(
-            list(new_table.c.user_id.foreign_keys)[0]._get_colspec(),
-            'user.id'
+            list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), "user.id"
         )
 
     def test_drop_fk(self):
         impl = self._named_fk_fixture()
-        fk = ForeignKeyConstraint([], [], name='ufk')
+        fk = ForeignKeyConstraint([], [], name="ufk")
         impl.drop_constraint(fk)
         new_table = self._assert_impl(
-            impl, colnames=['id', 'email', 'user_id'],
-            ddl_not_contains="CONSTRANT fk1")
-        eq_(
-            list(new_table.foreign_keys),
-            []
+            impl,
+            colnames=["id", "email", "user_id"],
+            ddl_not_contains="CONSTRANT fk1",
         )
+        eq_(list(new_table.foreign_keys), [])
 
     def test_add_uq(self):
         impl = self._simple_fixture()
-        uq = self.op.schema_obj.unique_constraint(
-            'uq1', 'tname', ['y']
-        )
+        uq = self.op.schema_obj.unique_constraint("uq1", "tname", ["y"])
 
         impl.add_constraint(uq)
         self._assert_impl(
-            impl, colnames=['id', 'x', 'y'],
-            ddl_contains="CONSTRAINT uq1 UNIQUE")
+            impl,
+            colnames=["id", "x", "y"],
+            ddl_contains="CONSTRAINT uq1 UNIQUE",
+        )
 
     def test_drop_uq(self):
         impl = self._uq_fixture()
 
-        uq = self.op.schema_obj.unique_constraint(
-            'uq1', 'tname', ['y']
-        )
+        uq = self.op.schema_obj.unique_constraint("uq1", "tname", ["y"])
         impl.drop_constraint(uq)
         self._assert_impl(
-            impl, colnames=['id', 'x', 'y'],
-            ddl_not_contains="CONSTRAINT uq1 UNIQUE")
+            impl,
+            colnames=["id", "x", "y"],
+            ddl_not_contains="CONSTRAINT uq1 UNIQUE",
+        )
 
     def test_create_index(self):
         impl = self._simple_fixture()
-        ix = self.op.schema_obj.index('ix1', 'tname', ['y'])
+        ix = self.op.schema_obj.index("ix1", "tname", ["y"])
 
         impl.create_index(ix)
         self._assert_impl(
-            impl, colnames=['id', 'x', 'y'],
-            ddl_contains="CREATE INDEX ix1")
+            impl, colnames=["id", "x", "y"], ddl_contains="CREATE INDEX ix1"
+        )
 
     def test_drop_index(self):
         impl = self._ix_fixture()
 
-        ix = self.op.schema_obj.index('ix1', 'tname', ['y'])
+        ix = self.op.schema_obj.index("ix1", "tname", ["y"])
         impl.drop_index(ix)
         self._assert_impl(
-            impl, colnames=['id', 'x', 'y'],
-            ddl_not_contains="CONSTRAINT uq1 UNIQUE")
+            impl,
+            colnames=["id", "x", "y"],
+            ddl_not_contains="CONSTRAINT uq1 UNIQUE",
+        )
 
     def test_add_table_opts(self):
-        impl = self._simple_fixture(table_kwargs={'mysql_engine': 'InnoDB'})
-        self._assert_impl(
-            impl, ddl_contains="ENGINE=InnoDB",
-            dialect='mysql'
-        )
+        impl = self._simple_fixture(table_kwargs={"mysql_engine": "InnoDB"})
+        self._assert_impl(impl, ddl_contains="ENGINE=InnoDB", dialect="mysql")
 
     def test_drop_pk(self):
         impl = self._pk_fixture()
@@ -554,14 +658,15 @@ class BatchApplyTest(TestBase):
 
 
 class BatchAPITest(TestBase):
-
     @contextmanager
     def _fixture(self, schema=None):
         migration_context = mock.Mock(
-            opts={}, impl=mock.MagicMock(__dialect__='sqlite'))
+            opts={}, impl=mock.MagicMock(__dialect__="sqlite")
+        )
         op = Operations(migration_context)
         batch = op.batch_alter_table(
-            'tname', recreate='never', schema=schema).__enter__()
+            "tname", recreate="never", schema=schema
+        ).__enter__()
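+        # recreate="never" keeps the batch in pass-through mode, so each
+        # operation should land directly on the mocked impl asserted below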
 
         mock_schema = mock.MagicMock()
         with mock.patch("alembic.operations.schemaobj.sa_schema", mock_schema):
@@ -571,105 +676,131 @@ class BatchAPITest(TestBase):
 
     def test_drop_col(self):
         with self._fixture() as batch:
-            batch.drop_column('q')
+            batch.drop_column("q")
 
         eq_(
             batch.impl.operations.impl.mock_calls,
-            [mock.call.drop_column(
-                'tname', self.mock_schema.Column(), schema=None)]
+            [
+                mock.call.drop_column(
+                    "tname", self.mock_schema.Column(), schema=None
+                )
+            ],
         )
 
     def test_add_col(self):
-        column = Column('w', String(50))
+        column = Column("w", String(50))
 
         with self._fixture() as batch:
             batch.add_column(column)
 
         eq_(
             batch.impl.operations.impl.mock_calls,
-            [mock.call.add_column(
-                'tname', column, schema=None)]
+            [mock.call.add_column("tname", column, schema=None)],
         )
 
     def test_create_fk(self):
         with self._fixture() as batch:
-            batch.create_foreign_key('myfk', 'user', ['x'], ['y'])
+            batch.create_foreign_key("myfk", "user", ["x"], ["y"])
 
         eq_(
             self.mock_schema.ForeignKeyConstraint.mock_calls,
             [
                 mock.call(
-                    ['x'], ['user.y'],
-                    onupdate=None, ondelete=None, name='myfk',
-                    initially=None, deferrable=None, match=None)
-            ]
+                    ["x"],
+                    ["user.y"],
+                    onupdate=None,
+                    ondelete=None,
+                    name="myfk",
+                    initially=None,
+                    deferrable=None,
+                    match=None,
+                )
+            ],
         )
         eq_(
             self.mock_schema.Table.mock_calls,
             [
                 mock.call(
-                    'user', self.mock_schema.MetaData(),
+                    "user",
+                    self.mock_schema.MetaData(),
                     self.mock_schema.Column(),
-                    schema=None
+                    schema=None,
                 ),
                 mock.call(
-                    'tname', self.mock_schema.MetaData(),
+                    "tname",
+                    self.mock_schema.MetaData(),
                     self.mock_schema.Column(),
-                    schema=None
+                    schema=None,
                 ),
                 mock.call().append_constraint(
-                    self.mock_schema.ForeignKeyConstraint())
-            ]
+                    self.mock_schema.ForeignKeyConstraint()
+                ),
+            ],
         )
         eq_(
             batch.impl.operations.impl.mock_calls,
-            [mock.call.add_constraint(
-                self.mock_schema.ForeignKeyConstraint())]
+            [
+                mock.call.add_constraint(
+                    self.mock_schema.ForeignKeyConstraint()
+                )
+            ],
         )
 
     def test_create_fk_schema(self):
-        with self._fixture(schema='foo') as batch:
-            batch.create_foreign_key('myfk', 'user', ['x'], ['y'])
+        with self._fixture(schema="foo") as batch:
+            batch.create_foreign_key("myfk", "user", ["x"], ["y"])
 
         eq_(
             self.mock_schema.ForeignKeyConstraint.mock_calls,
             [
                 mock.call(
-                    ['x'], ['user.y'],
-                    onupdate=None, ondelete=None, name='myfk',
-                    initially=None, deferrable=None, match=None)
-            ]
+                    ["x"],
+                    ["user.y"],
+                    onupdate=None,
+                    ondelete=None,
+                    name="myfk",
+                    initially=None,
+                    deferrable=None,
+                    match=None,
+                )
+            ],
         )
         eq_(
             self.mock_schema.Table.mock_calls,
             [
                 mock.call(
-                    'user', self.mock_schema.MetaData(),
+                    "user",
+                    self.mock_schema.MetaData(),
                     self.mock_schema.Column(),
-                    schema=None
+                    schema=None,
                 ),
                 mock.call(
-                    'tname', self.mock_schema.MetaData(),
+                    "tname",
+                    self.mock_schema.MetaData(),
                     self.mock_schema.Column(),
-                    schema='foo'
+                    schema="foo",
                 ),
                 mock.call().append_constraint(
-                    self.mock_schema.ForeignKeyConstraint())
-            ]
+                    self.mock_schema.ForeignKeyConstraint()
+                ),
+            ],
         )
         eq_(
             batch.impl.operations.impl.mock_calls,
-            [mock.call.add_constraint(
-                self.mock_schema.ForeignKeyConstraint())]
+            [
+                mock.call.add_constraint(
+                    self.mock_schema.ForeignKeyConstraint()
+                )
+            ],
         )
 
     def test_create_uq(self):
         with self._fixture() as batch:
-            batch.create_unique_constraint('uq1', ['a', 'b'])
+            batch.create_unique_constraint("uq1", ["a", "b"])
 
         eq_(
             self.mock_schema.Table().c.__getitem__.mock_calls,
-            [mock.call('a'), mock.call('b')]
+            [mock.call("a"), mock.call("b")],
         )
 
         eq_(
@@ -678,23 +809,22 @@ class BatchAPITest(TestBase):
                 mock.call(
                     self.mock_schema.Table().c.__getitem__(),
                     self.mock_schema.Table().c.__getitem__(),
-                    name='uq1'
+                    name="uq1",
                 )
-            ]
+            ],
         )
         eq_(
             batch.impl.operations.impl.mock_calls,
-            [mock.call.add_constraint(
-                self.mock_schema.UniqueConstraint())]
+            [mock.call.add_constraint(self.mock_schema.UniqueConstraint())],
         )
 
     def test_create_pk(self):
         with self._fixture() as batch:
-            batch.create_primary_key('pk1', ['a', 'b'])
+            batch.create_primary_key("pk1", ["a", "b"])
 
         eq_(
             self.mock_schema.Table().c.__getitem__.mock_calls,
-            [mock.call('a'), mock.call('b')]
+            [mock.call("a"), mock.call("b")],
         )
 
         eq_(
@@ -703,60 +833,53 @@ class BatchAPITest(TestBase):
                 mock.call(
                     self.mock_schema.Table().c.__getitem__(),
                     self.mock_schema.Table().c.__getitem__(),
-                    name='pk1'
+                    name="pk1",
                 )
-            ]
+            ],
         )
         eq_(
             batch.impl.operations.impl.mock_calls,
-            [mock.call.add_constraint(
-                self.mock_schema.PrimaryKeyConstraint())]
+            [
+                mock.call.add_constraint(
+                    self.mock_schema.PrimaryKeyConstraint()
+                )
+            ],
         )
 
     def test_create_check(self):
         expr = text("a > b")
         with self._fixture() as batch:
-            batch.create_check_constraint('ck1', expr)
+            batch.create_check_constraint("ck1", expr)
 
         eq_(
             self.mock_schema.CheckConstraint.mock_calls,
-            [
-                mock.call(
-                    expr, name="ck1"
-                )
-            ]
+            [mock.call(expr, name="ck1")],
         )
         eq_(
             batch.impl.operations.impl.mock_calls,
-            [mock.call.add_constraint(
-                self.mock_schema.CheckConstraint())]
+            [mock.call.add_constraint(self.mock_schema.CheckConstraint())],
         )
 
     def test_drop_constraint(self):
         with self._fixture() as batch:
-            batch.drop_constraint('uq1')
+            batch.drop_constraint("uq1")
 
-        eq_(
-            self.mock_schema.Constraint.mock_calls,
-            [
-                mock.call(name='uq1')
-            ]
-        )
+        eq_(self.mock_schema.Constraint.mock_calls, [mock.call(name="uq1")])
         eq_(
             batch.impl.operations.impl.mock_calls,
-            [mock.call.drop_constraint(self.mock_schema.Constraint())]
+            [mock.call.drop_constraint(self.mock_schema.Constraint())],
         )
 
 
 class CopyFromTest(TestBase):
-
     def _fixture(self):
         self.metadata = MetaData()
         self.table = Table(
-            'foo', self.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('data', String(50)),
-            Column('x', Integer),
+            "foo",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+            Column("x", Integer),
         )
 
         context = op_fixture(dialect="sqlite", as_sql=True)
@@ -766,148 +889,151 @@ class CopyFromTest(TestBase):
     def test_change_type(self):
         context = self._fixture()
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table) as batch_op:
-            batch_op.alter_column('data', type_=Integer)
+            "foo", copy_from=self.table
+        ) as batch_op:
+            batch_op.alter_column("data", type_=Integer)
         context.assert_(
-            'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
-            'data INTEGER, x INTEGER, PRIMARY KEY (id))',
-            'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, '
-            'CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo',
-            'DROP TABLE foo',
-            'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+            "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+            "data INTEGER, x INTEGER, PRIMARY KEY (id))",
+            "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, "
+            "CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo",
+            "DROP TABLE foo",
+            "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
         )
 
     def test_change_type_from_schematype(self):
         context = self._fixture()
         self.table.append_column(
-            Column('y', Boolean(
-                create_constraint=True, name="ck1")))
+            Column("y", Boolean(create_constraint=True, name="ck1"))
+        )
 
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table) as batch_op:
+            "foo", copy_from=self.table
+        ) as batch_op:
             batch_op.alter_column(
-                'y', type_=Integer,
-                existing_type=Boolean(
-                    create_constraint=True, name="ck1"))
+                "y",
+                type_=Integer,
+                existing_type=Boolean(create_constraint=True, name="ck1"),
+            )
         context.assert_(
-            'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
-            'data VARCHAR(50), x INTEGER, y INTEGER, PRIMARY KEY (id))',
-            'INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, '
-            'foo.data, foo.x, CAST(foo.y AS INTEGER) AS anon_1 FROM foo',
-            'DROP TABLE foo',
-            'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+            "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+            "data VARCHAR(50), x INTEGER, y INTEGER, PRIMARY KEY (id))",
+            "INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, "
+            "foo.data, foo.x, CAST(foo.y AS INTEGER) AS anon_1 FROM foo",
+            "DROP TABLE foo",
+            "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
         )
 
     def test_change_type_to_schematype(self):
         context = self._fixture()
-        self.table.append_column(
-            Column('y', Integer))
+        self.table.append_column(Column("y", Integer))
 
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table) as batch_op:
+            "foo", copy_from=self.table
+        ) as batch_op:
             batch_op.alter_column(
-                'y', existing_type=Integer,
-                type_=Boolean(
-                    create_constraint=True, name="ck1"))
+                "y",
+                existing_type=Integer,
+                type_=Boolean(create_constraint=True, name="ck1"),
+            )
         context.assert_(
-            'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
-            'data VARCHAR(50), x INTEGER, y BOOLEAN, PRIMARY KEY (id), '
-            'CONSTRAINT ck1 CHECK (y IN (0, 1)))',
-            'INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, '
-            'foo.data, foo.x, CAST(foo.y AS BOOLEAN) AS anon_1 FROM foo',
-            'DROP TABLE foo',
-            'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+            "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+            "data VARCHAR(50), x INTEGER, y BOOLEAN, PRIMARY KEY (id), "
+            "CONSTRAINT ck1 CHECK (y IN (0, 1)))",
+            "INSERT INTO _alembic_tmp_foo (id, data, x, y) SELECT foo.id, "
+            "foo.data, foo.x, CAST(foo.y AS BOOLEAN) AS anon_1 FROM foo",
+            "DROP TABLE foo",
+            "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
         )
 
     def test_create_drop_index_w_always(self):
         context = self._fixture()
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table, recreate='always') as batch_op:
-            batch_op.create_index(
-                'ix_data', ['data'], unique=True)
+            "foo", copy_from=self.table, recreate="always"
+        ) as batch_op:
+            batch_op.create_index("ix_data", ["data"], unique=True)
 
         context.assert_(
-            'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
-            'data VARCHAR(50), '
-            'x INTEGER, PRIMARY KEY (id))',
-            'INSERT INTO _alembic_tmp_foo (id, data, x) '
-            'SELECT foo.id, foo.data, foo.x FROM foo',
-            'DROP TABLE foo',
-            'ALTER TABLE _alembic_tmp_foo RENAME TO foo',
-            'CREATE UNIQUE INDEX ix_data ON foo (data)',
+            "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+            "data VARCHAR(50), "
+            "x INTEGER, PRIMARY KEY (id))",
+            "INSERT INTO _alembic_tmp_foo (id, data, x) "
+            "SELECT foo.id, foo.data, foo.x FROM foo",
+            "DROP TABLE foo",
+            "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
+            "CREATE UNIQUE INDEX ix_data ON foo (data)",
         )
 
         context.clear_assertions()
 
-        Index('ix_data', self.table.c.data, unique=True)
+        Index("ix_data", self.table.c.data, unique=True)
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table, recreate='always') as batch_op:
-            batch_op.drop_index('ix_data')
+            "foo", copy_from=self.table, recreate="always"
+        ) as batch_op:
+            batch_op.drop_index("ix_data")
 
         context.assert_(
-            'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
-            'data VARCHAR(50), x INTEGER, PRIMARY KEY (id))',
-            'INSERT INTO _alembic_tmp_foo (id, data, x) '
-            'SELECT foo.id, foo.data, foo.x FROM foo',
-            'DROP TABLE foo',
-            'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+            "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+            "data VARCHAR(50), x INTEGER, PRIMARY KEY (id))",
+            "INSERT INTO _alembic_tmp_foo (id, data, x) "
+            "SELECT foo.id, foo.data, foo.x FROM foo",
+            "DROP TABLE foo",
+            "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
         )
 
     def test_create_drop_index_wo_always(self):
         context = self._fixture()
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table) as batch_op:
-            batch_op.create_index(
-                'ix_data', ['data'], unique=True)
+            "foo", copy_from=self.table
+        ) as batch_op:
+            batch_op.create_index("ix_data", ["data"], unique=True)
 
-        context.assert_(
-            'CREATE UNIQUE INDEX ix_data ON foo (data)'
-        )
+        context.assert_("CREATE UNIQUE INDEX ix_data ON foo (data)")
 
         context.clear_assertions()
 
-        Index('ix_data', self.table.c.data, unique=True)
+        Index("ix_data", self.table.c.data, unique=True)
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table) as batch_op:
-            batch_op.drop_index('ix_data')
+            "foo", copy_from=self.table
+        ) as batch_op:
+            batch_op.drop_index("ix_data")
 
-        context.assert_(
-            'DROP INDEX ix_data'
-        )
+        context.assert_("DROP INDEX ix_data")
 
     def test_create_drop_index_w_other_ops(self):
         context = self._fixture()
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table) as batch_op:
-            batch_op.alter_column('data', type_=Integer)
-            batch_op.create_index(
-                'ix_data', ['data'], unique=True)
+            "foo", copy_from=self.table
+        ) as batch_op:
+            batch_op.alter_column("data", type_=Integer)
+            batch_op.create_index("ix_data", ["data"], unique=True)
 
         context.assert_(
-            'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
-            'data INTEGER, x INTEGER, PRIMARY KEY (id))',
-            'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, '
-            'CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo',
-            'DROP TABLE foo',
-            'ALTER TABLE _alembic_tmp_foo RENAME TO foo',
-            'CREATE UNIQUE INDEX ix_data ON foo (data)',
+            "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+            "data INTEGER, x INTEGER, PRIMARY KEY (id))",
+            "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, "
+            "CAST(foo.data AS INTEGER) AS anon_1, foo.x FROM foo",
+            "DROP TABLE foo",
+            "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
+            "CREATE UNIQUE INDEX ix_data ON foo (data)",
         )
 
         context.clear_assertions()
 
-        Index('ix_data', self.table.c.data, unique=True)
+        Index("ix_data", self.table.c.data, unique=True)
         with self.op.batch_alter_table(
-                "foo", copy_from=self.table) as batch_op:
-            batch_op.drop_index('ix_data')
-            batch_op.alter_column('data', type_=String)
+            "foo", copy_from=self.table
+        ) as batch_op:
+            batch_op.drop_index("ix_data")
+            batch_op.alter_column("data", type_=String)
 
         context.assert_(
-            'CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, '
-            'data VARCHAR, x INTEGER, PRIMARY KEY (id))',
-            'INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, '
-            'foo.data, foo.x FROM foo',
-            'DROP TABLE foo',
-            'ALTER TABLE _alembic_tmp_foo RENAME TO foo'
+            "CREATE TABLE _alembic_tmp_foo (id INTEGER NOT NULL, "
+            "data VARCHAR, x INTEGER, PRIMARY KEY (id))",
+            "INSERT INTO _alembic_tmp_foo (id, data, x) SELECT foo.id, "
+            "foo.data, foo.x FROM foo",
+            "DROP TABLE foo",
+            "ALTER TABLE _alembic_tmp_foo RENAME TO foo",
         )
 
 
@@ -918,11 +1044,12 @@ class BatchRoundTripTest(TestBase):
         self.conn = config.db.connect()
         self.metadata = MetaData()
         t1 = Table(
-            'foo', self.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('data', String(50)),
-            Column('x', Integer),
-            mysql_engine='InnoDB'
+            "foo",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+            Column("x", Integer),
+            mysql_engine="InnoDB",
         )
         t1.create(self.conn)
 
@@ -933,8 +1060,8 @@ class BatchRoundTripTest(TestBase):
                 {"id": 2, "data": "22", "x": 6},
                 {"id": 3, "data": "8.5", "x": 7},
                 {"id": 4, "data": "9.46", "x": 8},
-                {"id": 5, "data": "d5", "x": 9}
-            ]
+                {"id": 5, "data": "d5", "x": 9},
+            ],
         )
         context = MigrationContext.configure(self.conn)
         self.op = Operations(context)
@@ -949,80 +1076,75 @@ class BatchRoundTripTest(TestBase):
 
     def _no_pk_fixture(self):
         nopk = Table(
-            'nopk', self.metadata,
-            Column('a', Integer),
-            Column('b', Integer),
-            Column('c', Integer),
-            mysql_engine='InnoDB'
+            "nopk",
+            self.metadata,
+            Column("a", Integer),
+            Column("b", Integer),
+            Column("c", Integer),
+            mysql_engine="InnoDB",
         )
         nopk.create(self.conn)
         self.conn.execute(
-            nopk.insert(),
-            [
-                {"a": 1, "b": 2, "c": 3},
-                {"a": 2, "b": 4, "c": 5},
-            ]
-
+            nopk.insert(), [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}]
         )
         return nopk
 
     def _table_w_index_fixture(self):
         t = Table(
-            't_w_ix', self.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('thing', Integer),
-            Column('data', String(20)),
+            "t_w_ix",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("thing", Integer),
+            Column("data", String(20)),
         )
-        Index('ix_thing', t.c.thing)
+        Index("ix_thing", t.c.thing)
         t.create(self.conn)
         return t
 
     def _boolean_fixture(self):
         t = Table(
-            'hasbool', self.metadata,
-            Column('x', Boolean(create_constraint=True, name='ck1')),
-            Column('y', Integer)
+            "hasbool",
+            self.metadata,
+            Column("x", Boolean(create_constraint=True, name="ck1")),
+            Column("y", Integer),
         )
         t.create(self.conn)
 
     def _timestamp_fixture(self):
-        t = Table(
-            'hasts', self.metadata,
-            Column('x', DateTime()),
-        )
+        t = Table("hasts", self.metadata, Column("x", DateTime()))
         t.create(self.conn)
         return t
 
     def _int_to_boolean_fixture(self):
-        t = Table(
-            'hasbool', self.metadata,
-            Column('x', Integer)
-        )
+        t = Table("hasbool", self.metadata, Column("x", Integer))
         t.create(self.conn)
 
     def test_change_type_boolean_to_int(self):
         self._boolean_fixture()
-        with self.op.batch_alter_table(
-                "hasbool"
-        ) as batch_op:
+        with self.op.batch_alter_table("hasbool") as batch_op:
             batch_op.alter_column(
-                'x', type_=Integer, existing_type=Boolean(
-                    create_constraint=True, name='ck1'))
+                "x",
+                type_=Integer,
+                existing_type=Boolean(create_constraint=True, name="ck1"),
+            )
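+        # reflect the table and confirm the column's generic type
+        # affinity is now Integer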
         insp = Inspector.from_engine(config.db)
 
         eq_(
-            [c['type']._type_affinity for c in insp.get_columns('hasbool')
-             if c['name'] == 'x'],
-            [Integer]
+            [
+                c["type"]._type_affinity
+                for c in insp.get_columns("hasbool")
+                if c["name"] == "x"
+            ],
+            [Integer],
         )
 
     def test_no_net_change_timestamp(self):
         t = self._timestamp_fixture()
 
         import datetime
+
         self.conn.execute(
-            t.insert(),
-            {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)}
+            t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)}
         )
 
         with self.op.batch_alter_table("hasts") as batch_op:
@@ -1030,69 +1152,71 @@ class BatchRoundTripTest(TestBase):
 
         eq_(
             self.conn.execute(select([t.c.x])).fetchall(),
-            [(datetime.datetime(2012, 5, 18, 15, 32, 5),)]
+            [(datetime.datetime(2012, 5, 18, 15, 32, 5),)],
         )
 
     def test_drop_col_schematype(self):
         self._boolean_fixture()
-        with self.op.batch_alter_table(
-                "hasbool"
-        ) as batch_op:
-            batch_op.drop_column('x')
+        with self.op.batch_alter_table("hasbool") as batch_op:
+            batch_op.drop_column("x")
         insp = Inspector.from_engine(config.db)
 
-        assert 'x' not in (c['name'] for c in insp.get_columns('hasbool'))
+        assert "x" not in (c["name"] for c in insp.get_columns("hasbool"))
 
     def test_change_type_int_to_boolean(self):
         self._int_to_boolean_fixture()
-        with self.op.batch_alter_table(
-                "hasbool"
-        ) as batch_op:
+        with self.op.batch_alter_table("hasbool") as batch_op:
             batch_op.alter_column(
-                'x', type_=Boolean(create_constraint=True, name='ck1'))
+                "x", type_=Boolean(create_constraint=True, name="ck1")
+            )
         insp = Inspector.from_engine(config.db)
 
         if exclusions.against(config, "sqlite"):
             eq_(
-                [c['type']._type_affinity for
-                 c in insp.get_columns('hasbool') if c['name'] == 'x'],
-                [Boolean]
+                [
+                    c["type"]._type_affinity
+                    for c in insp.get_columns("hasbool")
+                    if c["name"] == "x"
+                ],
+                [Boolean],
             )
         elif exclusions.against(config, "mysql"):
             eq_(
-                [c['type']._type_affinity for
-                 c in insp.get_columns('hasbool') if c['name'] == 'x'],
-                [Integer]
+                [
+                    c["type"]._type_affinity
+                    for c in insp.get_columns("hasbool")
+                    if c["name"] == "x"
+                ],
+                [Integer],
             )
 
     def tearDown(self):
         self.metadata.drop_all(self.conn)
         self.conn.close()
 
-    def _assert_data(self, data, tablename='foo'):
+    def _assert_data(self, data, tablename="foo"):
         eq_(
-            [dict(row) for row
-             in self.conn.execute("select * from %s" % tablename)],
-            data
+            [
+                dict(row)
+                for row in self.conn.execute("select * from %s" % tablename)
+            ],
+            data,
         )
 
     def test_ix_existing(self):
         self._table_w_index_fixture()
 
         with self.op.batch_alter_table("t_w_ix") as batch_op:
-            batch_op.alter_column('data', type_=String(30))
+            batch_op.alter_column("data", type_=String(30))
             batch_op.create_index("ix_data", ["data"])
 
         insp = Inspector.from_engine(config.db)
         eq_(
             set(
-                (ix['name'], tuple(ix['column_names'])) for ix in
-                insp.get_indexes('t_w_ix')
+                (ix["name"], tuple(ix["column_names"]))
+                for ix in insp.get_indexes("t_w_ix")
             ),
-            set([
-                ('ix_data', ('data',)),
-                ('ix_thing', ('thing', ))
-            ])
+            set([("ix_data", ("data",)), ("ix_thing", ("thing",))]),
         )
 
     def test_fk_points_to_me_auto(self):
@@ -1108,31 +1232,39 @@ class BatchRoundTripTest(TestBase):
     @exclusions.only_on("sqlite")
     @exclusions.fails(
         "intentionally asserting that this "
-        "doesn't work w/ pragma foreign keys")
+        "doesn't work w/ pragma foreign keys"
+    )
     def test_fk_points_to_me_sqlite_refinteg(self):
         with self._sqlite_referential_integrity():
             self._test_fk_points_to_me("auto")
 
     def _test_fk_points_to_me(self, recreate):
         bar = Table(
-            'bar', self.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('foo_id', Integer, ForeignKey('foo.id')),
-            mysql_engine='InnoDB'
+            "bar",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("foo_id", Integer, ForeignKey("foo.id")),
+            mysql_engine="InnoDB",
         )
         bar.create(self.conn)
-        self.conn.execute(bar.insert(), {'id': 1, 'foo_id': 3})
+        self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
 
         with self.op.batch_alter_table("foo", recreate=recreate) as batch_op:
             batch_op.alter_column(
-                'data', new_column_name='newdata', existing_type=String(50))
+                "data", new_column_name="newdata", existing_type=String(50)
+            )
 
         insp = Inspector.from_engine(self.conn)
         eq_(
-            [(key['referred_table'],
-             key['referred_columns'], key['constrained_columns'])
-             for key in insp.get_foreign_keys('bar')],
-            [('foo', ['id'], ['foo_id'])]
+            [
+                (
+                    key["referred_table"],
+                    key["referred_columns"],
+                    key["constrained_columns"],
+                )
+                for key in insp.get_foreign_keys("bar")
+            ],
+            [("foo", ["id"], ["foo_id"])],
         )
 
     def test_selfref_fk_auto(self):
@@ -1145,100 +1277,112 @@ class BatchRoundTripTest(TestBase):
     @exclusions.only_on("sqlite")
     @exclusions.fails(
         "intentionally asserting that this "
-        "doesn't work w/ pragma foreign keys")
+        "doesn't work w/ pragma foreign keys"
+    )
     def test_selfref_fk_sqlite_refinteg(self):
         with self._sqlite_referential_integrity():
             self._test_selfref_fk("auto")
 
     def _test_selfref_fk(self, recreate):
         bar = Table(
-            'bar', self.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('bar_id', Integer, ForeignKey('bar.id')),
-            Column('data', String(50)),
-            mysql_engine='InnoDB'
+            "bar",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("bar_id", Integer, ForeignKey("bar.id")),
+            Column("data", String(50)),
+            mysql_engine="InnoDB",
         )
         bar.create(self.conn)
-        self.conn.execute(bar.insert(), {'id': 1, 'data': 'x', 'bar_id': None})
-        self.conn.execute(bar.insert(), {'id': 2, 'data': 'y', 'bar_id': 1})
+        self.conn.execute(bar.insert(), {"id": 1, "data": "x", "bar_id": None})
+        self.conn.execute(bar.insert(), {"id": 2, "data": "y", "bar_id": 1})
 
         with self.op.batch_alter_table("bar", recreate=recreate) as batch_op:
             batch_op.alter_column(
-                'data', new_column_name='newdata', existing_type=String(50))
+                "data", new_column_name="newdata", existing_type=String(50)
+            )
 
         insp = Inspector.from_engine(self.conn)
         eq_(
-            [(key['referred_table'],
-             key['referred_columns'], key['constrained_columns'])
-             for key in insp.get_foreign_keys('bar')],
-            [('bar', ['id'], ['bar_id'])]
+            [
+                (
+                    key["referred_table"],
+                    key["referred_columns"],
+                    key["constrained_columns"],
+                )
+                for key in insp.get_foreign_keys("bar")
+            ],
+            [("bar", ["id"], ["bar_id"])],
         )
 
     def test_change_type(self):
         with self.op.batch_alter_table("foo") as batch_op:
-            batch_op.alter_column('data', type_=Integer)
+            batch_op.alter_column("data", type_=Integer)
 
-        self._assert_data([
-            {"id": 1, "data": 0, "x": 5},
-            {"id": 2, "data": 22, "x": 6},
-            {"id": 3, "data": 8, "x": 7},
-            {"id": 4, "data": 9, "x": 8},
-            {"id": 5, "data": 0, "x": 9}
-        ])
+        self._assert_data(
+            [
+                {"id": 1, "data": 0, "x": 5},
+                {"id": 2, "data": 22, "x": 6},
+                {"id": 3, "data": 8, "x": 7},
+                {"id": 4, "data": 9, "x": 8},
+                {"id": 5, "data": 0, "x": 9},
+            ]
+        )
 
     def test_drop_column(self):
         with self.op.batch_alter_table("foo") as batch_op:
-            batch_op.drop_column('data')
+            batch_op.drop_column("data")
 
-        self._assert_data([
-            {"id": 1, "x": 5},
-            {"id": 2, "x": 6},
-            {"id": 3, "x": 7},
-            {"id": 4, "x": 8},
-            {"id": 5, "x": 9}
-        ])
+        self._assert_data(
+            [
+                {"id": 1, "x": 5},
+                {"id": 2, "x": 6},
+                {"id": 3, "x": 7},
+                {"id": 4, "x": 8},
+                {"id": 5, "x": 9},
+            ]
+        )
 
     def test_drop_pk_col_readd_col(self):
         # drop a column, add it back without primary_key=True, should no
         # longer be in the constraint
         with self.op.batch_alter_table("foo") as batch_op:
-            batch_op.drop_column('id')
-            batch_op.add_column(Column('id', Integer))
+            batch_op.drop_column("id")
+            batch_op.add_column(Column("id", Integer))
 
-        pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo')
-        eq_(pk_const['constrained_columns'], [])
+        pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo")
+        eq_(pk_const["constrained_columns"], [])
 
     def test_drop_pk_col_readd_pk_col(self):
         # drop a column, add it back with primary_key=True, should remain
         with self.op.batch_alter_table("foo") as batch_op:
-            batch_op.drop_column('id')
-            batch_op.add_column(Column('id', Integer, primary_key=True))
+            batch_op.drop_column("id")
+            batch_op.add_column(Column("id", Integer, primary_key=True))
 
-        pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo')
-        eq_(pk_const['constrained_columns'], ['id'])
+        pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo")
+        eq_(pk_const["constrained_columns"], ["id"])
 
     def test_drop_pk_col_readd_col_also_pk_const(self):
         # drop a column, add it back without primary_key=True, but then
         # also make a new PK constraint that includes it, should remain
         with self.op.batch_alter_table("foo") as batch_op:
-            batch_op.drop_column('id')
-            batch_op.add_column(Column('id', Integer))
-            batch_op.create_primary_key('newpk', ['id'])
+            batch_op.drop_column("id")
+            batch_op.add_column(Column("id", Integer))
+            batch_op.create_primary_key("newpk", ["id"])
 
-        pk_const = Inspector.from_engine(self.conn).get_pk_constraint('foo')
-        eq_(pk_const['constrained_columns'], ['id'])
+        pk_const = Inspector.from_engine(self.conn).get_pk_constraint("foo")
+        eq_(pk_const["constrained_columns"], ["id"])
 
     def test_add_pk_constraint(self):
         self._no_pk_fixture()
         with self.op.batch_alter_table("nopk", recreate="always") as batch_op:
-            batch_op.create_primary_key('newpk', ['a', 'b'])
+            batch_op.create_primary_key("newpk", ["a", "b"])
 
-        pk_const = Inspector.from_engine(self.conn).get_pk_constraint('nopk')
+        pk_const = Inspector.from_engine(self.conn).get_pk_constraint("nopk")
         with config.requirements.reflects_pk_names.fail_if():
-            eq_(pk_const['name'], 'newpk')
-        eq_(pk_const['constrained_columns'], ['a', 'b'])
+            eq_(pk_const["name"], "newpk")
+        eq_(pk_const["constrained_columns"], ["a", "b"])
 
     @config.requirements.check_constraints_w_enforcement
     def test_add_ck_constraint(self):
@@ -1247,203 +1391,219 @@ class BatchRoundTripTest(TestBase):
 
         # we don't support reflection of CHECK constraints,
         # so test this by inserting data that should violate it
-        foo = self.metadata.tables['foo']
+        foo = self.metadata.tables["foo"]
 
         assert_raises_message(
             exc.IntegrityError,
             "newck",
             self.conn.execute,
-            foo.insert(), {"id": 6, "data": 5, "x": -2}
+            foo.insert(),
+            {"id": 6, "data": 5, "x": -2},
         )
 
     @config.requirements.sqlalchemy_094
     @config.requirements.unnamed_constraints
     def test_drop_foreign_key(self):
         bar = Table(
-            'bar', self.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('foo_id', Integer, ForeignKey('foo.id')),
-            mysql_engine='InnoDB'
+            "bar",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("foo_id", Integer, ForeignKey("foo.id")),
+            mysql_engine="InnoDB",
         )
         bar.create(self.conn)
-        self.conn.execute(bar.insert(), {'id': 1, 'foo_id': 3})
+        self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
 
         naming_convention = {
-            "fk":
-            "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s"
         }
         with self.op.batch_alter_table(
-                "bar", naming_convention=naming_convention) as batch_op:
-            batch_op.drop_constraint(
-                "fk_bar_foo_id_foo", type_="foreignkey")
-        eq_(
-            Inspector.from_engine(self.conn).get_foreign_keys('bar'),
-            []
-        )
+            "bar", naming_convention=naming_convention
+        ) as batch_op:
+            batch_op.drop_constraint("fk_bar_foo_id_foo", type_="foreignkey")
+        eq_(Inspector.from_engine(self.conn).get_foreign_keys("bar"), [])
 
     def test_drop_column_fk_recreate(self):
-        with self.op.batch_alter_table("foo", recreate='always') as batch_op:
-            batch_op.drop_column('data')
+        with self.op.batch_alter_table("foo", recreate="always") as batch_op:
+            batch_op.drop_column("data")
 
-        self._assert_data([
-            {"id": 1, "x": 5},
-            {"id": 2, "x": 6},
-            {"id": 3, "x": 7},
-            {"id": 4, "x": 8},
-            {"id": 5, "x": 9}
-        ])
+        self._assert_data(
+            [
+                {"id": 1, "x": 5},
+                {"id": 2, "x": 6},
+                {"id": 3, "x": 7},
+                {"id": 4, "x": 8},
+                {"id": 5, "x": 9},
+            ]
+        )
 
     def test_rename_column(self):
         with self.op.batch_alter_table("foo") as batch_op:
-            batch_op.alter_column('x', new_column_name='y')
+            batch_op.alter_column("x", new_column_name="y")
 
-        self._assert_data([
-            {"id": 1, "data": "d1", "y": 5},
-            {"id": 2, "data": "22", "y": 6},
-            {"id": 3, "data": "8.5", "y": 7},
-            {"id": 4, "data": "9.46", "y": 8},
-            {"id": 5, "data": "d5", "y": 9}
-        ])
+        self._assert_data(
+            [
+                {"id": 1, "data": "d1", "y": 5},
+                {"id": 2, "data": "22", "y": 6},
+                {"id": 3, "data": "8.5", "y": 7},
+                {"id": 4, "data": "9.46", "y": 8},
+                {"id": 5, "data": "d5", "y": 9},
+            ]
+        )
 
     def test_rename_column_boolean(self):
         bar = Table(
-            'bar', self.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('flag', Boolean()),
-            mysql_engine='InnoDB'
+            "bar",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("flag", Boolean()),
+            mysql_engine="InnoDB",
         )
         bar.create(self.conn)
-        self.conn.execute(bar.insert(), {'id': 1, 'flag': True})
-        self.conn.execute(bar.insert(), {'id': 2, 'flag': False})
+        self.conn.execute(bar.insert(), {"id": 1, "flag": True})
+        self.conn.execute(bar.insert(), {"id": 2, "flag": False})
 
-        with self.op.batch_alter_table(
-            "bar"
-        ) as batch_op:
+        with self.op.batch_alter_table("bar") as batch_op:
             batch_op.alter_column(
-                'flag', new_column_name='bflag', existing_type=Boolean)
+                "flag", new_column_name="bflag", existing_type=Boolean
+            )
 
-        self._assert_data([
-            {"id": 1, 'bflag': True},
-            {"id": 2, 'bflag': False},
-        ], 'bar')
+        self._assert_data(
+            [{"id": 1, "bflag": True}, {"id": 2, "bflag": False}], "bar"
+        )
 
     @config.requirements.non_native_boolean
     def test_rename_column_non_native_boolean_no_ck(self):
         bar = Table(
-            'bar', self.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('flag', Boolean(create_constraint=False)),
-            mysql_engine='InnoDB'
+            "bar",
+            self.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("flag", Boolean(create_constraint=False)),
+            mysql_engine="InnoDB",
         )
         bar.create(self.conn)
-        self.conn.execute(bar.insert(), {'id': 1, 'flag': True})
-        self.conn.execute(bar.insert(), {'id': 2, 'flag': False})
+        self.conn.execute(bar.insert(), {"id": 1, "flag": True})
+        self.conn.execute(bar.insert(), {"id": 2, "flag": False})
         self.conn.execute(
             # use raw SQL to bypass the Boolean type, which as of
             # SQLAlchemy 1.1 coerces numeric values to 1/0
             text("insert into bar (id, flag) values (:id, :flag)"),
-            {'id': 3, 'flag': 5})
+            {"id": 3, "flag": 5},
+        )
 
         with self.op.batch_alter_table(
             "bar",
-            reflect_args=[Column('flag', Boolean(create_constraint=False))]
+            reflect_args=[Column("flag", Boolean(create_constraint=False))],
         ) as batch_op:
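+            # reflect_args overrides the reflected "flag" column so the
+            # rebuilt table carries no CHECK constraint for the boolean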
             batch_op.alter_column(
-                'flag', new_column_name='bflag', existing_type=Boolean)
+                "flag", new_column_name="bflag", existing_type=Boolean
+            )
 
-        self._assert_data([
-            {"id": 1, 'bflag': True},
-            {"id": 2, 'bflag': False},
-            {'id': 3, 'bflag': 5}
-        ], 'bar')
+        self._assert_data(
+            [
+                {"id": 1, "bflag": True},
+                {"id": 2, "bflag": False},
+                {"id": 3, "bflag": 5},
+            ],
+            "bar",
+        )
 
     def test_drop_column_pk(self):
         with self.op.batch_alter_table("foo") as batch_op:
-            batch_op.drop_column('id')
+            batch_op.drop_column("id")
 
-        self._assert_data([
-            {"data": "d1", "x": 5},
-            {"data": "22", "x": 6},
-            {"data": "8.5", "x": 7},
-            {"data": "9.46", "x": 8},
-            {"data": "d5", "x": 9}
-        ])
+        self._assert_data(
+            [
+                {"data": "d1", "x": 5},
+                {"data": "22", "x": 6},
+                {"data": "8.5", "x": 7},
+                {"data": "9.46", "x": 8},
+                {"data": "d5", "x": 9},
+            ]
+        )
 
     def test_rename_column_pk(self):
         with self.op.batch_alter_table("foo") as batch_op:
-            batch_op.alter_column('id', new_column_name='ident')
+            batch_op.alter_column("id", new_column_name="ident")
 
-        self._assert_data([
-            {"ident": 1, "data": "d1", "x": 5},
-            {"ident": 2, "data": "22", "x": 6},
-            {"ident": 3, "data": "8.5", "x": 7},
-            {"ident": 4, "data": "9.46", "x": 8},
-            {"ident": 5, "data": "d5", "x": 9}
-        ])
+        self._assert_data(
+            [
+                {"ident": 1, "data": "d1", "x": 5},
+                {"ident": 2, "data": "22", "x": 6},
+                {"ident": 3, "data": "8.5", "x": 7},
+                {"ident": 4, "data": "9.46", "x": 8},
+                {"ident": 5, "data": "d5", "x": 9},
+            ]
+        )
 
     def test_add_column_auto(self):
         # note this uses ALTER
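+        # (SQLite supports ADD COLUMN natively, so no table recreate here)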
         with self.op.batch_alter_table("foo") as batch_op:
             batch_op.add_column(
-                Column('data2', String(50), server_default='hi'))
+                Column("data2", String(50), server_default="hi")
+            )
 
-        self._assert_data([
-            {"id": 1, "data": "d1", "x": 5, 'data2': 'hi'},
-            {"id": 2, "data": "22", "x": 6, 'data2': 'hi'},
-            {"id": 3, "data": "8.5", "x": 7, 'data2': 'hi'},
-            {"id": 4, "data": "9.46", "x": 8, 'data2': 'hi'},
-            {"id": 5, "data": "d5", "x": 9, 'data2': 'hi'}
-        ])
+        self._assert_data(
+            [
+                {"id": 1, "data": "d1", "x": 5, "data2": "hi"},
+                {"id": 2, "data": "22", "x": 6, "data2": "hi"},
+                {"id": 3, "data": "8.5", "x": 7, "data2": "hi"},
+                {"id": 4, "data": "9.46", "x": 8, "data2": "hi"},
+                {"id": 5, "data": "d5", "x": 9, "data2": "hi"},
+            ]
+        )
 
     def test_add_column_recreate(self):
-        with self.op.batch_alter_table("foo", recreate='always') as batch_op:
+        with self.op.batch_alter_table("foo", recreate="always") as batch_op:
             batch_op.add_column(
-                Column('data2', String(50), server_default='hi'))
+                Column("data2", String(50), server_default="hi")
+            )
 
-        self._assert_data([
-            {"id": 1, "data": "d1", "x": 5, 'data2': 'hi'},
-            {"id": 2, "data": "22", "x": 6, 'data2': 'hi'},
-            {"id": 3, "data": "8.5", "x": 7, 'data2': 'hi'},
-            {"id": 4, "data": "9.46", "x": 8, 'data2': 'hi'},
-            {"id": 5, "data": "d5", "x": 9, 'data2': 'hi'}
-        ])
+        self._assert_data(
+            [
+                {"id": 1, "data": "d1", "x": 5, "data2": "hi"},
+                {"id": 2, "data": "22", "x": 6, "data2": "hi"},
+                {"id": 3, "data": "8.5", "x": 7, "data2": "hi"},
+                {"id": 4, "data": "9.46", "x": 8, "data2": "hi"},
+                {"id": 5, "data": "d5", "x": 9, "data2": "hi"},
+            ]
+        )
 
     def test_create_drop_index(self):
         insp = Inspector.from_engine(config.db)
-        eq_(
-            insp.get_indexes('foo'), []
-        )
+        eq_(insp.get_indexes("foo"), [])
 
-        with self.op.batch_alter_table("foo", recreate='always') as batch_op:
-            batch_op.create_index(
-                'ix_data', ['data'], unique=True)
+        with self.op.batch_alter_table("foo", recreate="always") as batch_op:
+            batch_op.create_index("ix_data", ["data"], unique=True)
 
-        self._assert_data([
-            {"id": 1, "data": "d1", "x": 5},
-            {"id": 2, "data": "22", "x": 6},
-            {"id": 3, "data": "8.5", "x": 7},
-            {"id": 4, "data": "9.46", "x": 8},
-            {"id": 5, "data": "d5", "x": 9}
-        ])
+        self._assert_data(
+            [
+                {"id": 1, "data": "d1", "x": 5},
+                {"id": 2, "data": "22", "x": 6},
+                {"id": 3, "data": "8.5", "x": 7},
+                {"id": 4, "data": "9.46", "x": 8},
+                {"id": 5, "data": "d5", "x": 9},
+            ]
+        )
 
         insp = Inspector.from_engine(config.db)
         eq_(
             [
-                dict(unique=ix['unique'],
-                     name=ix['name'],
-                     column_names=ix['column_names'])
-                for ix in insp.get_indexes('foo')
+                dict(
+                    unique=ix["unique"],
+                    name=ix["name"],
+                    column_names=ix["column_names"],
+                )
+                for ix in insp.get_indexes("foo")
             ],
-            [{'unique': True, 'name': 'ix_data', 'column_names': ['data']}]
+            [{"unique": True, "name": "ix_data", "column_names": ["data"]}],
         )
 
-        with self.op.batch_alter_table("foo", recreate='always') as batch_op:
-            batch_op.drop_index('ix_data')
+        with self.op.batch_alter_table("foo", recreate="always") as batch_op:
+            batch_op.drop_index("ix_data")
 
         insp = Inspector.from_engine(config.db)
-        eq_(
-            insp.get_indexes('foo'), []
-        )
+        eq_(insp.get_indexes("foo"), [])
 
 
 class BatchRoundTripMySQLTest(BatchRoundTripTest):
@@ -1496,7 +1656,8 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
     @exclusions.fails()
     def test_drop_pk_col_readd_pk_col(self):
         super(
-            BatchRoundTripPostgresqlTest, self).test_drop_pk_col_readd_pk_col()
+            BatchRoundTripPostgresqlTest, self
+        ).test_drop_pk_col_readd_pk_col()
 
     @exclusions.fails()
     def test_drop_pk_col_readd_col_also_pk_const(self):
@@ -1513,10 +1674,12 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
 
     @exclusions.fails()
     def test_change_type_int_to_boolean(self):
-        super(BatchRoundTripPostgresqlTest, self).\
-            test_change_type_int_to_boolean()
+        super(
+            BatchRoundTripPostgresqlTest, self
+        ).test_change_type_int_to_boolean()
 
     @exclusions.fails()
     def test_change_type_boolean_to_int(self):
-        super(BatchRoundTripPostgresqlTest, self).\
-            test_change_type_boolean_to_int()
+        super(
+            BatchRoundTripPostgresqlTest, self
+        ).test_change_type_boolean_to_int()
index 26556302d3c9ad7dd337be342070c651a5236258..220719a601358d761a5e29770dacc7384a58f656 100644 (file)
@@ -13,127 +13,131 @@ from alembic.testing import eq_, assert_raises_message, config
 class BulkInsertTest(TestBase):
     def _table_fixture(self, dialect, as_sql):
         context = op_fixture(dialect, as_sql)
-        t1 = table("ins_table",
-                   column('id', Integer),
-                   column('v1', String()),
-                   column('v2', String()),
-                   )
+        t1 = table(
+            "ins_table",
+            column("id", Integer),
+            column("v1", String()),
+            column("v2", String()),
+        )
         return context, t1
 
     def _big_t_table_fixture(self, dialect, as_sql):
         context = op_fixture(dialect, as_sql)
-        t1 = Table("ins_table", MetaData(),
-                   Column('id', Integer, primary_key=True),
-                   Column('v1', String()),
-                   Column('v2', String()),
-                   )
+        t1 = Table(
+            "ins_table",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("v1", String()),
+            Column("v2", String()),
+        )
         return context, t1
 
     def _test_bulk_insert(self, dialect, as_sql):
         context, t1 = self._table_fixture(dialect, as_sql)
 
-        op.bulk_insert(t1, [
-            {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
-            {'id': 2, 'v1': 'row v2', 'v2': 'row v6'},
-            {'id': 3, 'v1': 'row v3', 'v2': 'row v7'},
-            {'id': 4, 'v1': 'row v4', 'v2': 'row v8'},
-        ])
+        op.bulk_insert(
+            t1,
+            [
+                {"id": 1, "v1": "row v1", "v2": "row v5"},
+                {"id": 2, "v1": "row v2", "v2": "row v6"},
+                {"id": 3, "v1": "row v3", "v2": "row v7"},
+                {"id": 4, "v1": "row v4", "v2": "row v8"},
+            ],
+        )
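+        # in offline (as_sql) mode each row renders as a literal INSERT;
+        # online, one parameterized statement is executed for all rows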
         return context
 
     def _test_bulk_insert_single(self, dialect, as_sql):
         context, t1 = self._table_fixture(dialect, as_sql)
 
-        op.bulk_insert(t1, [
-            {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
-        ])
+        op.bulk_insert(t1, [{"id": 1, "v1": "row v1", "v2": "row v5"}])
         return context
 
     def _test_bulk_insert_single_bigt(self, dialect, as_sql):
         context, t1 = self._big_t_table_fixture(dialect, as_sql)
 
-        op.bulk_insert(t1, [
-            {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
-        ])
+        op.bulk_insert(t1, [{"id": 1, "v1": "row v1", "v2": "row v5"}])
         return context
 
     def test_bulk_insert(self):
-        context = self._test_bulk_insert('default', False)
+        context = self._test_bulk_insert("default", False)
         context.assert_(
-            'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)'
+            "INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)"
         )
 
     def test_bulk_insert_wrong_cols(self):
-        context = op_fixture('postgresql')
-        t1 = table("ins_table",
-                   column('id', Integer),
-                   column('v1', String()),
-                   column('v2', String()),
-                   )
-        op.bulk_insert(t1, [
-            {'v1': 'row v1', },
-        ])
+        context = op_fixture("postgresql")
+        t1 = table(
+            "ins_table",
+            column("id", Integer),
+            column("v1", String()),
+            column("v2", String()),
+        )
+        op.bulk_insert(t1, [{"v1": "row v1"}])
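+        # rows missing keys are still compiled against the full column list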
         context.assert_(
-            'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)'
+            "INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)"
         )
 
     def test_bulk_insert_no_rows(self):
-        context, t1 = self._table_fixture('default', False)
+        context, t1 = self._table_fixture("default", False)
 
         op.bulk_insert(t1, [])
         context.assert_()
 
     def test_bulk_insert_pg(self):
-        context = self._test_bulk_insert('postgresql', False)
+        context = self._test_bulk_insert("postgresql", False)
         context.assert_(
-            'INSERT INTO ins_table (id, v1, v2) '
-            'VALUES (%(id)s, %(v1)s, %(v2)s)'
+            "INSERT INTO ins_table (id, v1, v2) "
+            "VALUES (%(id)s, %(v1)s, %(v2)s)"
         )
 
     def test_bulk_insert_pg_single(self):
-        context = self._test_bulk_insert_single('postgresql', False)
+        context = self._test_bulk_insert_single("postgresql", False)
         context.assert_(
-            'INSERT INTO ins_table (id, v1, v2) '
-            'VALUES (%(id)s, %(v1)s, %(v2)s)'
+            "INSERT INTO ins_table (id, v1, v2) "
+            "VALUES (%(id)s, %(v1)s, %(v2)s)"
         )
 
     def test_bulk_insert_pg_single_as_sql(self):
-        context = self._test_bulk_insert_single('postgresql', True)
+        context = self._test_bulk_insert_single("postgresql", True)
         context.assert_(
             "INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')"
         )
 
     def test_bulk_insert_pg_single_big_t_as_sql(self):
-        context = self._test_bulk_insert_single_bigt('postgresql', True)
+        context = self._test_bulk_insert_single_bigt("postgresql", True)
         context.assert_(
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (1, 'row v1', 'row v5')"
         )
 
     def test_bulk_insert_mssql(self):
-        context = self._test_bulk_insert('mssql', False)
+        context = self._test_bulk_insert("mssql", False)
         context.assert_(
-            'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)'
+            "INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)"
         )
 
     def test_bulk_insert_inline_literal_as_sql(self):
-        context = op_fixture('postgresql', True)
+        context = op_fixture("postgresql", True)
 
         class MyType(TypeEngine):
             pass
 
-        t1 = table('t', column('id', Integer), column('data', MyType()))
+        t1 = table("t", column("id", Integer), column("data", MyType()))
 
-        op.bulk_insert(t1, [
-            {'id': 1, 'data': op.inline_literal('d1')},
-            {'id': 2, 'data': op.inline_literal('d2')},
-        ])
+        op.bulk_insert(
+            t1,
+            [
+                {"id": 1, "data": op.inline_literal("d1")},
+                {"id": 2, "data": op.inline_literal("d2")},
+            ],
+        )
         context.assert_(
             "INSERT INTO t (id, data) VALUES (1, 'd1')",
-            "INSERT INTO t (id, data) VALUES (2, 'd2')"
+            "INSERT INTO t (id, data) VALUES (2, 'd2')",
         )
 
     def test_bulk_insert_as_sql(self):
-        context = self._test_bulk_insert('default', True)
+        context = self._test_bulk_insert("default", True)
         context.assert_(
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (1, 'row v1', 'row v5')",
@@ -142,11 +146,11 @@ class BulkInsertTest(TestBase):
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (3, 'row v3', 'row v7')",
             "INSERT INTO ins_table (id, v1, v2) "
-            "VALUES (4, 'row v4', 'row v8')"
+            "VALUES (4, 'row v4', 'row v8')",
         )
 
     def test_bulk_insert_as_sql_pg(self):
-        context = self._test_bulk_insert('postgresql', True)
+        context = self._test_bulk_insert("postgresql", True)
         context.assert_(
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (1, 'row v1', 'row v5')",
@@ -155,65 +159,68 @@ class BulkInsertTest(TestBase):
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (3, 'row v3', 'row v7')",
             "INSERT INTO ins_table (id, v1, v2) "
-            "VALUES (4, 'row v4', 'row v8')"
+            "VALUES (4, 'row v4', 'row v8')",
         )
 
     def test_bulk_insert_as_sql_mssql(self):
-        context = self._test_bulk_insert('mssql', True)
+        context = self._test_bulk_insert("mssql", True)
         # SQL Server requires IDENTITY_INSERT
         # TODO: figure out if this is safe to enable for a table that
         # doesn't have an IDENTITY column
         context.assert_(
-            'SET IDENTITY_INSERT ins_table ON',
-            'GO',
+            "SET IDENTITY_INSERT ins_table ON",
+            "GO",
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (1, 'row v1', 'row v5')",
-            'GO',
+            "GO",
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (2, 'row v2', 'row v6')",
-            'GO',
+            "GO",
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (3, 'row v3', 'row v7')",
-            'GO',
+            "GO",
             "INSERT INTO ins_table (id, v1, v2) "
             "VALUES (4, 'row v4', 'row v8')",
-            'GO',
-            'SET IDENTITY_INSERT ins_table OFF',
-            'GO',
+            "GO",
+            "SET IDENTITY_INSERT ins_table OFF",
+            "GO",
         )
 
     def test_bulk_insert_from_new_table(self):
         context = op_fixture("postgresql", True)
         t1 = op.create_table(
             "ins_table",
-            Column('id', Integer),
-            Column('v1', String()),
-            Column('v2', String()),
+            Column("id", Integer),
+            Column("v1", String()),
+            Column("v2", String()),
+        )
+        op.bulk_insert(
+            t1,
+            [
+                {"id": 1, "v1": "row v1", "v2": "row v5"},
+                {"id": 2, "v1": "row v2", "v2": "row v6"},
+            ],
         )
-        op.bulk_insert(t1, [
-            {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
-            {'id': 2, 'v1': 'row v2', 'v2': 'row v6'},
-        ])
         context.assert_(
-            'CREATE TABLE ins_table (id INTEGER, v1 VARCHAR, v2 VARCHAR)',
+            "CREATE TABLE ins_table (id INTEGER, v1 VARCHAR, v2 VARCHAR)",
             "INSERT INTO ins_table (id, v1, v2) VALUES "
             "(1, 'row v1', 'row v5')",
             "INSERT INTO ins_table (id, v1, v2) VALUES "
-            "(2, 'row v2', 'row v6')"
+            "(2, 'row v2', 'row v6')",
         )
 
     def test_invalid_format(self):
         context, t1 = self._table_fixture("sqlite", False)
         assert_raises_message(
-            TypeError,
-            "List expected",
-            op.bulk_insert, t1, {"id": 5}
+            TypeError, "List expected", op.bulk_insert, t1, {"id": 5}
         )
 
         assert_raises_message(
             TypeError,
             "List of dictionaries expected",
-            op.bulk_insert, t1, [(5, )]
+            op.bulk_insert,
+            t1,
+            [(5,)],
         )
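
The bound-parameter vs. literal-SQL split running through the tests above is the core of bulk_insert's two modes: against a live connection it issues a single executemany with bound parameters, while under --sql (the as_sql=True fixtures) it must render one literal INSERT per row, since no DBAPI is present to receive parameters; on mssql the offline output is additionally wrapped in SET IDENTITY_INSERT ... ON/OFF with GO batch separators. A minimal sketch of the invocation both modes share, assuming an active migration context as the fixtures configure; the table and row values are illustrative:

from sqlalchemy import Integer, String
from sqlalchemy.sql import table, column
from alembic import op

# bulk_insert needs only a lightweight table()/column() description,
# not a full Table bound to a MetaData
t = table(
    "ins_table",
    column("id", Integer),
    column("v1", String()),
    column("v2", String()),
)

# online:  INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)
#          executed once, with the list of rows as parameters
# offline: INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')
#          one literal statement per row
op.bulk_insert(t, [{"id": 1, "v1": "row v1", "v2": "row v5"}])
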
 
 
@@ -223,86 +230,85 @@ class RoundTripTest(TestBase):
     def setUp(self):
         from sqlalchemy import create_engine
         from alembic.migration import MigrationContext
+
         self.conn = config.db.connect()
-        self.conn.execute("""
+        self.conn.execute(
+            """
             create table foo(
                 id integer primary key,
                 data varchar(50),
                 x integer
             )
-        """)
+        """
+        )
         context = MigrationContext.configure(self.conn)
         self.op = op.Operations(context)
-        self.t1 = table('foo',
-                        column('id'),
-                        column('data'),
-                        column('x')
-                        )
+        self.t1 = table("foo", column("id"), column("data"), column("x"))
 
     def tearDown(self):
         self.conn.execute("drop table foo")
         self.conn.close()
 
     def test_single_insert_round_trip(self):
-        self.op.bulk_insert(self.t1,
-                            [{'data': "d1", "x": "x1"}]
-                            )
+        self.op.bulk_insert(self.t1, [{"data": "d1", "x": "x1"}])
 
         eq_(
             self.conn.execute("select id, data, x from foo").fetchall(),
-            [
-                (1, "d1", "x1"),
-            ]
+            [(1, "d1", "x1")],
         )
 
     def test_bulk_insert_round_trip(self):
-        self.op.bulk_insert(self.t1, [
-            {'data': "d1", "x": "x1"},
-            {'data': "d2", "x": "x2"},
-            {'data': "d3", "x": "x3"},
-        ])
+        self.op.bulk_insert(
+            self.t1,
+            [
+                {"data": "d1", "x": "x1"},
+                {"data": "d2", "x": "x2"},
+                {"data": "d3", "x": "x3"},
+            ],
+        )
 
         eq_(
             self.conn.execute("select id, data, x from foo").fetchall(),
-            [
-                (1, "d1", "x1"),
-                (2, "d2", "x2"),
-                (3, "d3", "x3")
-            ]
+            [(1, "d1", "x1"), (2, "d2", "x2"), (3, "d3", "x3")],
         )
 
     def test_bulk_insert_inline_literal(self):
         class MyType(TypeEngine):
             pass
 
-        t1 = table('foo', column('id', Integer), column('data', MyType()))
+        t1 = table("foo", column("id", Integer), column("data", MyType()))
 
-        self.op.bulk_insert(t1, [
-            {'id': 1, 'data': self.op.inline_literal('d1')},
-            {'id': 2, 'data': self.op.inline_literal('d2')},
-        ], multiinsert=False)
+        self.op.bulk_insert(
+            t1,
+            [
+                {"id": 1, "data": self.op.inline_literal("d1")},
+                {"id": 2, "data": self.op.inline_literal("d2")},
+            ],
+            multiinsert=False,
+        )
 
         eq_(
             self.conn.execute("select id, data from foo").fetchall(),
-            [
-                (1, "d1"),
-                (2, "d2"),
-            ]
+            [(1, "d1"), (2, "d2")],
         )
 
     def test_bulk_insert_from_new_table(self):
         t1 = self.op.create_table(
             "ins_table",
-            Column('id', Integer),
-            Column('v1', String()),
-            Column('v2', String()),
+            Column("id", Integer),
+            Column("v1", String()),
+            Column("v2", String()),
+        )
+        self.op.bulk_insert(
+            t1,
+            [
+                {"id": 1, "v1": "row v1", "v2": "row v5"},
+                {"id": 2, "v1": "row v2", "v2": "row v6"},
+            ],
         )
-        self.op.bulk_insert(t1, [
-            {'id': 1, 'v1': 'row v1', 'v2': 'row v5'},
-            {'id': 2, 'v1': 'row v2', 'v2': 'row v6'},
-        ])
         eq_(
             self.conn.execute(
-                "select id, v1, v2 from ins_table order by id").fetchall(),
-            [(1, u'row v1', u'row v5'), (2, u'row v2', u'row v6')]
-        )
\ No newline at end of file
+                "select id, v1, v2 from ins_table order by id"
+            ).fetchall(),
+            [(1, u"row v1", u"row v5"), (2, u"row v2", u"row v6")],
+        )
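
test_bulk_insert_inline_literal is the one round-trip case that passes multiinsert=False, and the reason is visible in the data: op.inline_literal() values are SQL expressions rather than plain Python values, and an executemany batch cannot carry a per-row SQL expression, so each row has to be emitted as its own INSERT. A minimal sketch under that assumption, again presuming an active migration context; MyType stands in for any custom type:

from sqlalchemy import Integer
from sqlalchemy.sql import table, column
from sqlalchemy.types import TypeEngine
from alembic import op

class MyType(TypeEngine):
    """illustrative placeholder for a custom type"""

t = table("foo", column("id", Integer), column("data", MyType()))

# inline_literal() renders its value directly into the statement, so
# multiinsert=False forces row-at-a-time INSERTs instead of executemany
op.bulk_insert(
    t,
    [
        {"id": 1, "data": op.inline_literal("d1")},
        {"id": 2, "data": op.inline_literal("d2")},
    ],
    multiinsert=False,
)
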
index 3f3daf50f7b079ea4abace83363f89dec2678a97..a9f0e5d721cbbe309d25c1e4fbf642615c19aa6c 100644 (file)
@@ -3,9 +3,16 @@ from io import TextIOWrapper, BytesIO
 from alembic.script import ScriptDirectory
 from alembic import config
 from alembic.testing.fixtures import TestBase, capture_context_buffer
-from alembic.testing.env import staging_env, _sqlite_testing_config, \
-    three_rev_fixture, clear_staging_env, _no_sql_testing_config, \
-    _sqlite_file_db, write_script, env_file_fixture
+from alembic.testing.env import (
+    staging_env,
+    _sqlite_testing_config,
+    three_rev_fixture,
+    clear_staging_env,
+    _no_sql_testing_config,
+    _sqlite_file_db,
+    write_script,
+    env_file_fixture,
+)
 from alembic.testing import eq_, assert_raises_message, mock, assert_raises
 from alembic import util
 from contextlib import contextmanager
@@ -18,13 +25,12 @@ class _BufMixin(object):
         # try to simulate how sys.stdout looks - we send it u''
         # but it then tries to encode to some target encoding.
         buf = BytesIO()
-        wrapper = TextIOWrapper(buf, encoding='ascii', line_buffering=True)
+        wrapper = TextIOWrapper(buf, encoding="ascii", line_buffering=True)
         wrapper.getvalue = buf.getvalue
         return wrapper
 
 
 class HistoryTest(_BufMixin, TestBase):
-
     @classmethod
     def setup_class(cls):
         cls.env = staging_env()
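
_buf_fixture above builds a stand-in for sys.stdout: a TextIOWrapper over a BytesIO, so command output passes through a real encode step just as it would on a console stream, while the raw bytes stay inspectable. A minimal sketch of the same trick:

from io import BytesIO, TextIOWrapper

buf = BytesIO()
wrapper = TextIOWrapper(buf, encoding="ascii", line_buffering=True)
wrapper.getvalue = buf.getvalue  # expose the raw bytes, as the fixture does
wrapper.write(u"hello\n")        # line_buffering flushes on the newline
assert wrapper.getvalue() == b"hello\n"
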
@@ -41,7 +47,8 @@ class HistoryTest(_BufMixin, TestBase):
 
     @classmethod
     def _setup_env_file(self):
-        env_file_fixture(r"""
+        env_file_fixture(
+            r"""
 
 from sqlalchemy import MetaData, engine_from_config
 target_metadata = MetaData()
@@ -63,7 +70,8 @@ try:
 finally:
     connection.close()
 
-""")
+"""
+        )
 
     def _eq_cmd_output(self, buf, expected, env_token=False, currents=()):
         script = ScriptDirectory.from_config(self.cfg)
@@ -82,9 +90,11 @@ finally:
             assert_lines.insert(0, "environment included OK")
 
         eq_(
-            buf.getvalue().decode("ascii", 'replace').strip(),
-            "\n".join(assert_lines).
-            encode("ascii", "replace").decode("ascii").strip()
+            buf.getvalue().decode("ascii", "replace").strip(),
+            "\n".join(assert_lines)
+            .encode("ascii", "replace")
+            .decode("ascii")
+            .strip(),
         )
 
     def test_history_full(self):
@@ -163,11 +173,11 @@ finally:
         self.cfg.stdout = buf = self._buf_fixture()
         command.history(self.cfg, indicate_current=True, verbose=True)
         self._eq_cmd_output(
-            buf, [self.c, self.b, self.a], currents=(self.b,), env_token=True)
+            buf, [self.c, self.b, self.a], currents=(self.b,), env_token=True
+        )
 
 
 class CurrentTest(_BufMixin, TestBase):
-
     @classmethod
     def setup_class(cls):
         cls.env = env = staging_env()
@@ -189,11 +199,15 @@ class CurrentTest(_BufMixin, TestBase):
 
         yield
 
-        lines = set([
-            re.match(r'(^.\w)', elem).group(1)
-            for elem in re.split(
-                "\n",
-                buf.getvalue().decode('ascii', 'replace').strip()) if elem])
+        lines = set(
+            [
+                re.match(r"(^.\w)", elem).group(1)
+                for elem in re.split(
+                    "\n", buf.getvalue().decode("ascii", "replace").strip()
+                )
+                if elem
+            ]
+        )
 
         eq_(lines, set(revs))
 
@@ -205,25 +219,25 @@ class CurrentTest(_BufMixin, TestBase):
     def test_plain_current(self):
         command.stamp(self.cfg, ())
         command.stamp(self.cfg, self.a3.revision)
-        with self._assert_lines(['a3']):
+        with self._assert_lines(["a3"]):
             command.current(self.cfg)
 
     def test_two_heads(self):
         command.stamp(self.cfg, ())
         command.stamp(self.cfg, (self.a1.revision, self.b1.revision))
-        with self._assert_lines(['a1', 'b1']):
+        with self._assert_lines(["a1", "b1"]):
             command.current(self.cfg)
 
     def test_heads_one_is_dependent(self):
         command.stamp(self.cfg, ())
-        command.stamp(self.cfg, (self.b2.revision, ))
-        with self._assert_lines(['a2', 'b2']):
+        command.stamp(self.cfg, (self.b2.revision,))
+        with self._assert_lines(["a2", "b2"]):
             command.current(self.cfg)
 
     def test_heads_upg(self):
-        command.stamp(self.cfg, (self.b2.revision, ))
+        command.stamp(self.cfg, (self.b2.revision,))
         command.upgrade(self.cfg, (self.b3.revision))
-        with self._assert_lines(['a2', 'b3']):
+        with self._assert_lines(["a2", "b3"]):
             command.current(self.cfg)
 
 
@@ -236,7 +250,8 @@ class RevisionTest(TestBase):
         clear_staging_env()
 
     def _env_fixture(self, version_table_pk=True):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 
 from sqlalchemy import MetaData, engine_from_config
 target_metadata = MetaData()
@@ -258,7 +273,9 @@ try:
 finally:
     connection.close()
 
-""" % (version_table_pk, ))
+"""
+            % (version_table_pk,)
+        )
 
     def test_create_rev_plain_db_not_up_to_date(self):
         self._env_fixture()
@@ -275,7 +292,9 @@ finally:
         assert_raises_message(
             util.CommandError,
             "Target database is not up to date.",
-            command.revision, self.cfg, autogenerate=True
+            command.revision,
+            self.cfg,
+            autogenerate=True,
         )
 
     def test_create_rev_autogen_db_not_up_to_date_multi_heads(self):
@@ -290,7 +309,9 @@ finally:
         assert_raises_message(
             util.CommandError,
             "Target database is not up to date.",
-            command.revision, self.cfg, autogenerate=True
+            command.revision,
+            self.cfg,
+            autogenerate=True,
         )
 
     def test_create_rev_plain_db_not_up_to_date_multi_heads(self):
@@ -306,7 +327,8 @@ finally:
             util.CommandError,
             "Multiple heads are present; please specify the head revision "
             "on which the new revision should be based, or perform a merge.",
-            command.revision, self.cfg
+            command.revision,
+            self.cfg,
         )
 
     def test_create_rev_autogen_need_to_select_head(self):
@@ -321,7 +343,9 @@ finally:
             util.CommandError,
             "Multiple heads are present; please specify the head revision "
             "on which the new revision should be based, or perform a merge.",
-            command.revision, self.cfg, autogenerate=True
+            command.revision,
+            self.cfg,
+            autogenerate=True,
         )
 
     def test_pk_constraint_normally_prevents_dupe_rows(self):
@@ -333,7 +357,7 @@ finally:
         assert_raises(
             sqla_exc.IntegrityError,
             db.execute,
-            "insert into alembic_version values ('%s')" % r2.revision
+            "insert into alembic_version values ('%s')" % r2.revision,
         )
 
     def test_err_correctly_raised_on_dupe_rows_no_pk(self):
@@ -347,7 +371,9 @@ finally:
             util.CommandError,
             "Online migration expected to match one row when "
             "updating .* in 'alembic_version'; 2 found",
-            command.downgrade, self.cfg, "-1"
+            command.downgrade,
+            self.cfg,
+            "-1",
         )
 
     def test_create_rev_plain_need_to_select_head(self):
@@ -362,7 +388,8 @@ finally:
             util.CommandError,
             "Multiple heads are present; please specify the head revision "
             "on which the new revision should be based, or perform a merge.",
-            command.revision, self.cfg
+            command.revision,
+            self.cfg,
         )
 
     def test_create_rev_plain_post_merge(self):
@@ -389,27 +416,20 @@ finally:
         command.revision(self.cfg)
         rev2 = command.revision(self.cfg)
         rev3 = command.revision(self.cfg, depends_on=rev2.revision)
-        eq_(
-            rev3._resolved_dependencies, (rev2.revision, )
-        )
+        eq_(rev3._resolved_dependencies, (rev2.revision,))
 
         rev4 = command.revision(
-            self.cfg, depends_on=[rev2.revision, rev3.revision])
-        eq_(
-            rev4._resolved_dependencies, (rev2.revision, rev3.revision)
+            self.cfg, depends_on=[rev2.revision, rev3.revision]
         )
+        eq_(rev4._resolved_dependencies, (rev2.revision, rev3.revision))
 
     def test_create_rev_depends_on_branch_label(self):
         self._env_fixture()
         command.revision(self.cfg)
-        rev2 = command.revision(self.cfg, branch_label='foobar')
-        rev3 = command.revision(self.cfg, depends_on='foobar')
-        eq_(
-            rev3.dependencies, 'foobar'
-        )
-        eq_(
-            rev3._resolved_dependencies, (rev2.revision, )
-        )
+        rev2 = command.revision(self.cfg, branch_label="foobar")
+        rev3 = command.revision(self.cfg, depends_on="foobar")
+        eq_(rev3.dependencies, "foobar")
+        eq_(rev3._resolved_dependencies, (rev2.revision,))
 
     def test_create_rev_depends_on_partial_revid(self):
         self._env_fixture()
@@ -417,12 +437,8 @@ finally:
         rev2 = command.revision(self.cfg)
         assert len(rev2.revision) > 7
         rev3 = command.revision(self.cfg, depends_on=rev2.revision[0:4])
-        eq_(
-            rev3.dependencies, rev2.revision
-        )
-        eq_(
-            rev3._resolved_dependencies, (rev2.revision, )
-        )
+        eq_(rev3.dependencies, rev2.revision)
+        eq_(rev3._resolved_dependencies, (rev2.revision,))
 
     def test_create_rev_invalid_depends_on(self):
         self._env_fixture()
@@ -430,7 +446,9 @@ finally:
         assert_raises_message(
             util.CommandError,
             "Can't locate revision identified by 'invalid'",
-            command.revision, self.cfg, depends_on='invalid'
+            command.revision,
+            self.cfg,
+            depends_on="invalid",
         )
 
     def test_create_rev_autogenerate_db_not_up_to_date_post_merge(self):
@@ -444,7 +462,9 @@ finally:
         assert_raises_message(
             util.CommandError,
             "Target database is not up to date.",
-            command.revision, self.cfg, autogenerate=True
+            command.revision,
+            self.cfg,
+            autogenerate=True,
         )
 
     def test_nonsensical_sql_mode_autogen(self):
@@ -452,7 +472,10 @@ finally:
         assert_raises_message(
             util.CommandError,
             "Using --sql with --autogenerate does not make any sense",
-            command.revision, self.cfg, autogenerate=True, sql=True
+            command.revision,
+            self.cfg,
+            autogenerate=True,
+            sql=True,
         )
 
     def test_nonsensical_sql_no_env(self):
@@ -461,7 +484,9 @@ finally:
             util.CommandError,
             "Using --sql with the revision command when revision_environment "
             "is not configured does not make any sense",
-            command.revision, self.cfg, sql=True
+            command.revision,
+            self.cfg,
+            sql=True,
         )
 
     def test_sensical_sql_w_env(self):
@@ -471,12 +496,11 @@ finally:
 
 
 class UpgradeDowngradeStampTest(TestBase):
-
     def setUp(self):
         self.env = staging_env()
         self.cfg = cfg = _no_sql_testing_config()
-        cfg.set_main_option('dialect_name', 'sqlite')
-        cfg.remove_main_option('url')
+        cfg.set_main_option("dialect_name", "sqlite")
+        cfg.remove_main_option("url")
 
         self.a, self.b, self.c = three_rev_fixture(cfg)
 
@@ -559,7 +583,7 @@ class UpgradeDowngradeStampTest(TestBase):
 
 
 class LiveStampTest(TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     def setUp(self):
         self.bind = _sqlite_file_db()
@@ -569,15 +593,25 @@ class LiveStampTest(TestBase):
         self.b = b = util.rev_id()
         script = ScriptDirectory.from_config(self.cfg)
         script.generate_revision(a, None, refresh=True)
-        write_script(script, a, """
+        write_script(
+            script,
+            a,
+            """
 revision = '%s'
 down_revision = None
-""" % a)
+"""
+            % a,
+        )
         script.generate_revision(b, None, refresh=True)
-        write_script(script, b, """
+        write_script(
+            script,
+            b,
+            """
 revision = '%s'
 down_revision = '%s'
-""" % (b, a))
+"""
+            % (b, a),
+        )
 
     def tearDown(self):
         clear_staging_env()
@@ -585,29 +619,25 @@ down_revision = '%s'
     def test_stamp_creates_table(self):
         command.stamp(self.cfg, "head")
         eq_(
-            self.bind.scalar("select version_num from alembic_version"),
-            self.b
+            self.bind.scalar("select version_num from alembic_version"), self.b
         )
 
     def test_stamp_existing_upgrade(self):
         command.stamp(self.cfg, self.a)
         command.stamp(self.cfg, self.b)
         eq_(
-            self.bind.scalar("select version_num from alembic_version"),
-            self.b
+            self.bind.scalar("select version_num from alembic_version"), self.b
         )
 
     def test_stamp_existing_downgrade(self):
         command.stamp(self.cfg, self.b)
         command.stamp(self.cfg, self.a)
         eq_(
-            self.bind.scalar("select version_num from alembic_version"),
-            self.a
+            self.bind.scalar("select version_num from alembic_version"), self.a
         )
 
 
 class EditTest(TestBase):
-
     @classmethod
     def setup_class(cls):
         cls.env = staging_env()
@@ -622,56 +652,61 @@ class EditTest(TestBase):
         command.stamp(self.cfg, "base")
 
     def test_edit_head(self):
-        expected_call_arg = '%s/scripts/versions/%s_revision_c.py' % (
-            EditTest.cfg.config_args['here'],
-            EditTest.c
+        expected_call_arg = "%s/scripts/versions/%s_revision_c.py" % (
+            EditTest.cfg.config_args["here"],
+            EditTest.c,
         )
 
-        with mock.patch('alembic.util.edit') as edit:
+        with mock.patch("alembic.util.edit") as edit:
             command.edit(self.cfg, "head")
             edit.assert_called_with(expected_call_arg)
 
     def test_edit_b(self):
-        expected_call_arg = '%s/scripts/versions/%s_revision_b.py' % (
-            EditTest.cfg.config_args['here'],
-            EditTest.b
+        expected_call_arg = "%s/scripts/versions/%s_revision_b.py" % (
+            EditTest.cfg.config_args["here"],
+            EditTest.b,
         )
 
-        with mock.patch('alembic.util.edit') as edit:
+        with mock.patch("alembic.util.edit") as edit:
             command.edit(self.cfg, self.b[0:3])
             edit.assert_called_with(expected_call_arg)
 
     def test_edit_with_missing_editor(self):
-        with mock.patch('editor.edit') as edit_mock:
+        with mock.patch("editor.edit") as edit_mock:
             edit_mock.side_effect = OSError("file not found")
             assert_raises_message(
                 util.CommandError,
-                'file not found',
+                "file not found",
                 util.edit,
-                "/not/a/file.txt")
+                "/not/a/file.txt",
+            )
 
     def test_edit_no_revs(self):
         assert_raises_message(
             util.CommandError,
             "No revision files indicated by symbol 'base'",
             command.edit,
-            self.cfg, "base")
+            self.cfg,
+            "base",
+        )
 
     def test_edit_no_current(self):
         assert_raises_message(
             util.CommandError,
             "No current revisions",
             command.edit,
-            self.cfg, "current")
+            self.cfg,
+            "current",
+        )
 
     def test_edit_current(self):
-        expected_call_arg = '%s/scripts/versions/%s_revision_b.py' % (
-            EditTest.cfg.config_args['here'],
-            EditTest.b
+        expected_call_arg = "%s/scripts/versions/%s_revision_b.py" % (
+            EditTest.cfg.config_args["here"],
+            EditTest.b,
         )
 
         command.stamp(self.cfg, self.b)
-        with mock.patch('alembic.util.edit') as edit:
+        with mock.patch("alembic.util.edit") as edit:
             command.edit(self.cfg, "current")
             edit.assert_called_with(expected_call_arg)
 
@@ -691,16 +726,21 @@ class CommandLineTest(TestBase):
         # the command function has "process_revision_directives",
         # but the ArgumentParser does not; ensure things still work
         def revision(
-            config, message=None, autogenerate=False, sql=False,
-            head="head", splice=False, branch_label=None,
-            version_path=None, rev_id=None, depends_on=None,
-            process_revision_directives=None
+            config,
+            message=None,
+            autogenerate=False,
+            sql=False,
+            head="head",
+            splice=False,
+            branch_label=None,
+            version_path=None,
+            rev_id=None,
+            depends_on=None,
+            process_revision_directives=None,
         ):
-            canary(
-                config, message=message
-            )
+            canary(config, message=message)
 
-        revision.__module__ = 'alembic.command'
+        revision.__module__ = "alembic.command"
 
         # CommandLine() pulls the function into the ArgumentParser
         # and needs the full signature, so we can't patch the "revision"
@@ -712,7 +752,4 @@ class CommandLineTest(TestBase):
             commandline.run_cmd(self.cfg, options)
         finally:
             config.command.revision = orig_revision
-        eq_(
-            canary.mock_calls,
-            [mock.call(self.cfg, message="foo")]
-        )
+        eq_(canary.mock_calls, [mock.call(self.cfg, message="foo")])
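
The depends_on tests above cover the three spellings a dependency may use - a full revision id, a partial id prefix, or a branch label - all of which are normalized into the script's _resolved_dependencies tuple of full ids. A minimal sketch, assuming a staged environment and a cfg object as the fixtures provide:

from alembic import command

command.revision(cfg)                                  # base revision
rev2 = command.revision(cfg, branch_label="foobar")

# by branch label
rev3 = command.revision(cfg, depends_on="foobar")
assert rev3._resolved_dependencies == (rev2.revision,)

# by partial revision id (the prefix must be unambiguous)
rev4 = command.revision(cfg, depends_on=rev3.revision[0:4])
assert rev4._resolved_dependencies == (rev3.revision,)
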
index 50e1b05e35ee3a273f461d28c701eaa749e83431..b1d1ca10b902d7157f763fa382b6027067ecd01d 100644 (file)
@@ -11,30 +11,35 @@ from alembic.testing.mock import Mock, call
 
 from alembic.testing import eq_, assert_raises_message
 from alembic.testing.fixtures import capture_db
-from alembic.testing.env import _no_sql_testing_config, clear_staging_env,\
-    staging_env, _write_config_file
+from alembic.testing.env import (
+    _no_sql_testing_config,
+    clear_staging_env,
+    staging_env,
+    _write_config_file,
+)
 
 
 class FileConfigTest(TestBase):
-
     def test_config_args(self):
-        cfg = _write_config_file("""
+        cfg = _write_config_file(
+            """
 [alembic]
 migrations = %(base_path)s/db/migrations
-""")
+"""
+        )
         test_cfg = config.Config(
             cfg.config_file_name, config_args=dict(base_path="/home/alembic")
         )
         eq_(
             test_cfg.get_section_option("alembic", "migrations"),
-            "/home/alembic/db/migrations")
+            "/home/alembic/db/migrations",
+        )
 
     def tearDown(self):
         clear_staging_env()
 
 
 class ConfigTest(TestBase):
-
     def test_config_no_file_main_option(self):
         cfg = config.Config()
         cfg.set_main_option("url", "postgresql://foo/bar")
@@ -66,12 +71,12 @@ class ConfigTest(TestBase):
         cfg = config.Config()
         cfg.set_section_option("some_section", "foob", "foob_value")
 
-        cfg.set_section_option(
-            "some_section", "bar", "bar with %(foob)s")
+        cfg.set_section_option("some_section", "bar", "bar with %(foob)s")
 
         eq_(
             cfg.get_section_option("some_section", "bar"),
-            "bar with foob_value")
+            "bar with foob_value",
+        )
 
     def test_standalone_op(self):
         eng, buf = capture_db()
@@ -80,71 +85,58 @@ class ConfigTest(TestBase):
         op = Operations(env)
 
         op.alter_column("t", "c", nullable=True)
-        eq_(buf, ['ALTER TABLE t ALTER COLUMN c DROP NOT NULL'])
+        eq_(buf, ["ALTER TABLE t ALTER COLUMN c DROP NOT NULL"])
 
     def test_no_script_error(self):
         cfg = config.Config()
         assert_raises_message(
             util.CommandError,
             "No 'script_location' key found in configuration.",
-            ScriptDirectory.from_config, cfg
+            ScriptDirectory.from_config,
+            cfg,
         )
 
     def test_attributes_attr(self):
         m1 = Mock()
         cfg = config.Config()
-        cfg.attributes['connection'] = m1
-        eq_(
-            cfg.attributes['connection'], m1
-        )
+        cfg.attributes["connection"] = m1
+        eq_(cfg.attributes["connection"], m1)
 
     def test_attributes_constructor(self):
         m1 = Mock()
         m2 = Mock()
-        cfg = config.Config(attributes={'m1': m1})
-        cfg.attributes['connection'] = m2
-        eq_(
-            cfg.attributes, {'m1': m1, 'connection': m2}
-        )
+        cfg = config.Config(attributes={"m1": m1})
+        cfg.attributes["connection"] = m2
+        eq_(cfg.attributes, {"m1": m1, "connection": m2})
 
 
 class StdoutOutputEncodingTest(TestBase):
-
     def test_plain(self):
-        stdout = Mock(encoding='latin-1')
+        stdout = Mock(encoding="latin-1")
         cfg = config.Config(stdout=stdout)
         cfg.print_stdout("test %s %s", "x", "y")
-        eq_(
-            stdout.mock_calls,
-            [call.write('test x y'), call.write('\n')]
-        )
+        eq_(stdout.mock_calls, [call.write("test x y"), call.write("\n")])
 
     def test_utf8_unicode(self):
-        stdout = Mock(encoding='latin-1')
+        stdout = Mock(encoding="latin-1")
         cfg = config.Config(stdout=stdout)
         cfg.print_stdout(compat.u("méil %s %s"), "x", "y")
         eq_(
             stdout.mock_calls,
-            [call.write(compat.u('méil x y')), call.write('\n')]
+            [call.write(compat.u("méil x y")), call.write("\n")],
         )
 
     def test_ascii_unicode(self):
         stdout = Mock(encoding=None)
         cfg = config.Config(stdout=stdout)
         cfg.print_stdout(compat.u("méil %s %s"), "x", "y")
-        eq_(
-            stdout.mock_calls,
-            [call.write('m?il x y'), call.write('\n')]
-        )
+        eq_(stdout.mock_calls, [call.write("m?il x y"), call.write("\n")])
 
     def test_only_formats_output_with_args(self):
         stdout = Mock(encoding=None)
         cfg = config.Config(stdout=stdout)
         cfg.print_stdout(compat.u("test 3%"))
-        eq_(
-            stdout.mock_calls,
-            [call.write('test 3%'), call.write('\n')]
-        )
+        eq_(stdout.mock_calls, [call.write("test 3%"), call.write("\n")])
 
 
 class TemplateOutputEncodingTest(TestBase):
@@ -157,9 +149,9 @@ class TemplateOutputEncodingTest(TestBase):
 
     def test_default(self):
         script = ScriptDirectory.from_config(self.cfg)
-        eq_(script.output_encoding, 'utf-8')
+        eq_(script.output_encoding, "utf-8")
 
     def test_setting(self):
-        self.cfg.set_main_option('output_encoding', 'latin-1')
+        self.cfg.set_main_option("output_encoding", "latin-1")
         script = ScriptDirectory.from_config(self.cfg)
-        eq_(script.output_encoding, 'latin-1')
+        eq_(script.output_encoding, "latin-1")
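
test_config_args above is the whole contract of the config_args parameter: tokens passed to Config() are merged into the ConfigParser defaults, so the .ini file may reference them with %(name)s interpolation. A minimal sketch, assuming an alembic.ini whose [alembic] section contains `migrations = %(base_path)s/db/migrations`:

from alembic import config

cfg = config.Config(
    "alembic.ini",  # illustrative path
    config_args={"base_path": "/home/alembic"},
)
# configparser interpolation expands %(base_path)s when the option is read
assert (
    cfg.get_section_option("alembic", "migrations")
    == "/home/alembic/db/migrations"
)
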
index 42ff328e0d5bd8982854a469bc948b99191d7194..cfa72f6798149491d54ade3a96b86097dd26a6c3 100644 (file)
@@ -5,15 +5,19 @@ from alembic.environment import EnvironmentContext
 from alembic.migration import MigrationContext
 from alembic.testing.fixtures import TestBase
 from alembic.testing.mock import Mock, call, MagicMock
-from alembic.testing.env import _no_sql_testing_config, \
-    staging_env, clear_staging_env, write_script, _sqlite_file_db
+from alembic.testing.env import (
+    _no_sql_testing_config,
+    staging_env,
+    clear_staging_env,
+    write_script,
+    _sqlite_file_db,
+)
 from alembic.testing.assertions import expect_warnings
 
 from alembic.testing import eq_, is_
 
 
 class EnvironmentTest(TestBase):
-
     def setUp(self):
         staging_env()
         self.cfg = _no_sql_testing_config()
@@ -23,49 +27,30 @@ class EnvironmentTest(TestBase):
 
     def _fixture(self, **kw):
         script = ScriptDirectory.from_config(self.cfg)
-        env = EnvironmentContext(
-            self.cfg,
-            script,
-            **kw
-        )
+        env = EnvironmentContext(self.cfg, script, **kw)
         return env
 
     def test_x_arg(self):
         env = self._fixture()
         self.cfg.cmd_opts = Mock(x="y=5")
-        eq_(
-            env.get_x_argument(),
-            "y=5"
-        )
+        eq_(env.get_x_argument(), "y=5")
 
     def test_x_arg_asdict(self):
         env = self._fixture()
         self.cfg.cmd_opts = Mock(x=["y=5"])
-        eq_(
-            env.get_x_argument(as_dictionary=True),
-            {"y": "5"}
-        )
+        eq_(env.get_x_argument(as_dictionary=True), {"y": "5"})
 
     def test_x_arg_no_opts(self):
         env = self._fixture()
-        eq_(
-            env.get_x_argument(),
-            []
-        )
+        eq_(env.get_x_argument(), [])
 
     def test_x_arg_no_opts_asdict(self):
         env = self._fixture()
-        eq_(
-            env.get_x_argument(as_dictionary=True),
-            {}
-        )
+        eq_(env.get_x_argument(as_dictionary=True), {})
 
     def test_tag_arg(self):
         env = self._fixture(tag="x")
-        eq_(
-            env.get_tag_argument(),
-            "x"
-        )
+        eq_(env.get_tag_argument(), "x")
 
     def test_migration_context_has_config(self):
         env = self._fixture()
@@ -81,9 +66,12 @@ class EnvironmentTest(TestBase):
 
         engine = _sqlite_file_db()
 
-        a_rev = 'arev'
+        a_rev = "arev"
         env.script.generate_revision(a_rev, "revision a", refresh=True)
-        write_script(env.script, a_rev, """\
+        write_script(
+            env.script,
+            a_rev,
+            """\
 "Rev A"
 revision = '%s'
 down_revision = None
@@ -98,7 +86,9 @@ def upgrade():
 def downgrade():
     pass
 
-""" % a_rev)
+"""
+            % a_rev,
+        )
         migration_fn = MagicMock()
 
         def upgrade(rev, context):
@@ -106,15 +96,13 @@ def downgrade():
             return env.script._upgrade_revs(a_rev, rev)
 
         with expect_warnings(
-                r"'connection' argument to configure\(\) is "
-                r"expected to be a sqlalchemy.engine.Connection "):
+            r"'connection' argument to configure\(\) is "
+            r"expected to be a sqlalchemy.engine.Connection "
+        ):
             env.configure(
-                connection=engine, fn=upgrade,
-                transactional_ddl=False)
+                connection=engine, fn=upgrade, transactional_ddl=False
+            )
 
         env.run_migrations()
 
-        eq_(
-            migration_fn.mock_calls,
-            [call((), env._migration_context)]
-        )
+        eq_(migration_fn.mock_calls, [call((), env._migration_context)])
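
The get_x_argument tests document the -x channel for ad-hoc options: each `-x key=value` on the alembic command line is collected into cfg.cmd_opts.x, and as_dictionary=True splits the pairs for env.py to consume. A minimal sketch of the consuming side, assuming something like `alembic -x y=5 upgrade head` was run:

from alembic import context

# -x y=5 arrives as the string "y=5"; as_dictionary=True yields {"y": "5"}
x_args = context.get_x_argument(as_dictionary=True)
y_value = int(x_args.get("y", "0"))  # values are strings; convert as needed
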
index dc01b755a9f270c0e036b34346d6fad52efac28d..1c3222dc9dbf277cf6f94ca10b3846250e715dc9 100644 (file)
@@ -15,6 +15,7 @@ from sqlalchemy.engine import default
 class CustomDialect(default.DefaultDialect):
     name = "custom_dialect"
 
+
 try:
     from sqlalchemy.dialects import registry
 except ImportError:
@@ -24,20 +25,22 @@ else:
 
 
 class CustomDialectImpl(impl.DefaultImpl):
-    __dialect__ = 'custom_dialect'
+    __dialect__ = "custom_dialect"
     transactional_ddl = False
 
     def render_type(self, type_, autogen_context):
         if type_.__module__ == __name__:
             autogen_context.imports.add(
-                "from %s import custom_dialect_types" % (__name__, ))
+                "from %s import custom_dialect_types" % (__name__,)
+            )
             is_external = True
         else:
             is_external = False
 
-        if is_external and \
-                hasattr(self, '_render_%s_type' % type_.__visit_name__):
-            meth = getattr(self, '_render_%s_type' % type_.__visit_name__)
+        if is_external and hasattr(
+            self, "_render_%s_type" % type_.__visit_name__
+        ):
+            meth = getattr(self, "_render_%s_type" % type_.__visit_name__)
             return meth(type_, autogen_context)
 
         if is_external:
@@ -47,13 +50,16 @@ class CustomDialectImpl(impl.DefaultImpl):
 
     def _render_EXT_ARRAY_type(self, type_, autogen_context):
         return render._render_type_w_subtype(
-            type_, autogen_context, 'item_type', r'(.+?\()',
-            prefix="custom_dialect_types."
+            type_,
+            autogen_context,
+            "item_type",
+            r"(.+?\()",
+            prefix="custom_dialect_types.",
         )
 
 
 class EXT_ARRAY(sqla_types.TypeEngine):
-    __visit_name__ = 'EXT_ARRAY'
+    __visit_name__ = "EXT_ARRAY"
 
     def __init__(self, item_type):
         if isinstance(item_type, type):
@@ -63,75 +69,78 @@ class EXT_ARRAY(sqla_types.TypeEngine):
 
 
 class FOOBARTYPE(sqla_types.TypeEngine):
-    __visit_name__ = 'FOOBARTYPE'
+    __visit_name__ = "FOOBARTYPE"
 
 
 class ExternalDialectRenderTest(TestBase):
-
     def setUp(self):
         ctx_opts = {
-            'sqlalchemy_module_prefix': 'sa.',
-            'alembic_module_prefix': 'op.',
-            'target_metadata': MetaData(),
-            'user_module_prefix': None
+            "sqlalchemy_module_prefix": "sa.",
+            "alembic_module_prefix": "op.",
+            "target_metadata": MetaData(),
+            "user_module_prefix": None,
         }
         context = MigrationContext.configure(
-            dialect_name="custom_dialect",
-            opts=ctx_opts
+            dialect_name="custom_dialect", opts=ctx_opts
         )
 
         self.autogen_context = api.AutogenContext(context)
 
     def test_render_type(self):
         eq_ignore_whitespace(
-            autogenerate.render._repr_type(
-                FOOBARTYPE(), self.autogen_context),
-            "custom_dialect_types.FOOBARTYPE()"
+            autogenerate.render._repr_type(FOOBARTYPE(), self.autogen_context),
+            "custom_dialect_types.FOOBARTYPE()",
         )
 
         eq_(
             self.autogen_context.imports,
-            set([
-                'from tests.test_external_dialect import custom_dialect_types'
-            ])
+            set(
+                [
+                    "from tests.test_external_dialect import custom_dialect_types"
+                ]
+            ),
         )
 
     def test_external_nested_render_sqla_type(self):
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                EXT_ARRAY(sqla_types.Integer), self.autogen_context),
-            "custom_dialect_types.EXT_ARRAY(sa.Integer())"
+                EXT_ARRAY(sqla_types.Integer), self.autogen_context
+            ),
+            "custom_dialect_types.EXT_ARRAY(sa.Integer())",
         )
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                EXT_ARRAY(
-                    sqla_types.DateTime(timezone=True)
-                ),
-                self.autogen_context),
-            "custom_dialect_types.EXT_ARRAY(sa.DateTime(timezone=True))"
+                EXT_ARRAY(sqla_types.DateTime(timezone=True)),
+                self.autogen_context,
+            ),
+            "custom_dialect_types.EXT_ARRAY(sa.DateTime(timezone=True))",
         )
 
         eq_(
             self.autogen_context.imports,
-            set([
-                'from tests.test_external_dialect import custom_dialect_types'
-            ])
+            set(
+                [
+                    "from tests.test_external_dialect import custom_dialect_types"
+                ]
+            ),
         )
 
     def test_external_nested_render_external_type(self):
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                EXT_ARRAY(FOOBARTYPE),
-                self.autogen_context),
-            "custom_dialect_types.EXT_ARRAY(custom_dialect_types.FOOBARTYPE())"
+                EXT_ARRAY(FOOBARTYPE), self.autogen_context
+            ),
+            "custom_dialect_types.EXT_ARRAY(custom_dialect_types.FOOBARTYPE())",
         )
 
         eq_(
             self.autogen_context.imports,
-            set([
-                'from tests.test_external_dialect import custom_dialect_types'
-            ])
+            set(
+                [
+                    "from tests.test_external_dialect import custom_dialect_types"
+                ]
+            ),
         )
index b092dcf6e2e7a13e54ca653e0c724ba9ac7bc413..1657cccdf357cd9a33765dce2263fa290ada2ee7 100644 (file)
@@ -8,13 +8,16 @@ from alembic import op, command, util
 
 from alembic.testing import eq_, assert_raises_message
 from alembic.testing.fixtures import capture_context_buffer, op_fixture
-from alembic.testing.env import staging_env, _no_sql_testing_config, \
-    three_rev_fixture, clear_staging_env
+from alembic.testing.env import (
+    staging_env,
+    _no_sql_testing_config,
+    three_rev_fixture,
+    clear_staging_env,
+)
 from alembic.testing import config
 
 
 class FullEnvironmentTests(TestBase):
-
     @classmethod
     def setup_class(cls):
         staging_env()
@@ -24,8 +27,7 @@ class FullEnvironmentTests(TestBase):
             directives = ""
         cls.cfg = cfg = _no_sql_testing_config("mssql", directives)
 
-        cls.a, cls.b, cls.c = \
-            three_rev_fixture(cfg)
+        cls.a, cls.b, cls.c = three_rev_fixture(cfg)
 
     @classmethod
     def teardown_class(cls):
@@ -39,7 +41,7 @@ class FullEnvironmentTests(TestBase):
         # ensure ends in COMMIT; GO
         eq_(
             [x for x in buf.getvalue().splitlines() if x][-2:],
-            ['COMMIT;', 'GO']
+            ["COMMIT;", "GO"],
         )
 
     def test_batch_separator_default(self):
@@ -54,242 +56,248 @@ class FullEnvironmentTests(TestBase):
 
 
 class OpTest(TestBase):
-
     def test_add_column(self):
-        context = op_fixture('mssql')
-        op.add_column('t1', Column('c1', Integer, nullable=False))
+        context = op_fixture("mssql")
+        op.add_column("t1", Column("c1", Integer, nullable=False))
         context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
 
     def test_add_column_with_default(self):
         context = op_fixture("mssql")
         op.add_column(
-            't1', Column('c1', Integer, nullable=False, server_default="12"))
+            "t1", Column("c1", Integer, nullable=False, server_default="12")
+        )
         context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL DEFAULT '12'")
 
     def test_alter_column_rename_mssql(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", new_column_name="x")
-        context.assert_(
-            "EXEC sp_rename 't.c', x, 'COLUMN'"
-        )
+        context.assert_("EXEC sp_rename 't.c', x, 'COLUMN'")
 
     def test_alter_column_rename_quoted_mssql(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", new_column_name="SomeFancyName")
-        context.assert_(
-            "EXEC sp_rename 't.c', [SomeFancyName], 'COLUMN'"
-        )
+        context.assert_("EXEC sp_rename 't.c', [SomeFancyName], 'COLUMN'")
 
     def test_alter_column_new_type(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", type_=Integer)
-        context.assert_(
-            'ALTER TABLE t ALTER COLUMN c INTEGER'
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER")
 
     def test_alter_column_dont_touch_constraints(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         from sqlalchemy import Boolean
-        op.alter_column('tests', 'col',
-                        existing_type=Boolean(),
-                        nullable=False)
-        context.assert_('ALTER TABLE tests ALTER COLUMN col BIT NOT NULL')
+
+        op.alter_column(
+            "tests", "col", existing_type=Boolean(), nullable=False
+        )
+        context.assert_("ALTER TABLE tests ALTER COLUMN col BIT NOT NULL")
 
     def test_drop_index(self):
-        context = op_fixture('mssql')
-        op.drop_index('my_idx', 'my_table')
+        context = op_fixture("mssql")
+        op.drop_index("my_idx", "my_table")
         context.assert_contains("DROP INDEX my_idx ON my_table")
 
     def test_drop_column_w_default(self):
-        context = op_fixture('mssql')
-        op.drop_column('t1', 'c1', mssql_drop_default=True)
-        op.drop_column('t1', 'c2', mssql_drop_default=True)
+        context = op_fixture("mssql")
+        op.drop_column("t1", "c1", mssql_drop_default=True)
+        op.drop_column("t1", "c2", mssql_drop_default=True)
         context.assert_contains(
-            "exec('alter table t1 drop constraint ' + @const_name)")
+            "exec('alter table t1 drop constraint ' + @const_name)"
+        )
         context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_drop_column_w_default_in_batch(self):
-        context = op_fixture('mssql')
-        with op.batch_alter_table('t1', schema=None) as batch_op:
-            batch_op.drop_column('c1', mssql_drop_default=True)
-            batch_op.drop_column('c2', mssql_drop_default=True)
+        context = op_fixture("mssql")
+        with op.batch_alter_table("t1", schema=None) as batch_op:
+            batch_op.drop_column("c1", mssql_drop_default=True)
+            batch_op.drop_column("c2", mssql_drop_default=True)
         context.assert_contains(
-            "exec('alter table t1 drop constraint ' + @const_name)")
+            "exec('alter table t1 drop constraint ' + @const_name)"
+        )
         context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_alter_column_drop_default(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", server_default=None)
         context.assert_contains(
-            "exec('alter table t drop constraint ' + @const_name)")
+            "exec('alter table t drop constraint ' + @const_name)"
+        )
 
     def test_alter_column_dont_drop_default(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", server_default=False)
         context.assert_()
 
     def test_drop_column_w_schema(self):
-        context = op_fixture('mssql')
-        op.drop_column('t1', 'c1', schema='xyz')
+        context = op_fixture("mssql")
+        op.drop_column("t1", "c1", schema="xyz")
         context.assert_contains("ALTER TABLE xyz.t1 DROP COLUMN c1")
 
     def test_drop_column_w_check(self):
-        context = op_fixture('mssql')
-        op.drop_column('t1', 'c1', mssql_drop_check=True)
-        op.drop_column('t1', 'c2', mssql_drop_check=True)
+        context = op_fixture("mssql")
+        op.drop_column("t1", "c1", mssql_drop_check=True)
+        op.drop_column("t1", "c2", mssql_drop_check=True)
         context.assert_contains(
-            "exec('alter table t1 drop constraint ' + @const_name)")
+            "exec('alter table t1 drop constraint ' + @const_name)"
+        )
         context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_drop_column_w_check_in_batch(self):
-        context = op_fixture('mssql')
-        with op.batch_alter_table('t1', schema=None) as batch_op:
-            batch_op.drop_column('c1', mssql_drop_check=True)
-            batch_op.drop_column('c2', mssql_drop_check=True)
+        context = op_fixture("mssql")
+        with op.batch_alter_table("t1", schema=None) as batch_op:
+            batch_op.drop_column("c1", mssql_drop_check=True)
+            batch_op.drop_column("c2", mssql_drop_check=True)
         context.assert_contains(
-            "exec('alter table t1 drop constraint ' + @const_name)")
+            "exec('alter table t1 drop constraint ' + @const_name)"
+        )
         context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_drop_column_w_check_quoting(self):
-        context = op_fixture('mssql')
-        op.drop_column('table', 'column', mssql_drop_check=True)
+        context = op_fixture("mssql")
+        op.drop_column("table", "column", mssql_drop_check=True)
         context.assert_contains(
-            "exec('alter table [table] drop constraint ' + @const_name)")
+            "exec('alter table [table] drop constraint ' + @const_name)"
+        )
         context.assert_contains("ALTER TABLE [table] DROP COLUMN [column]")
 
     def test_alter_column_nullable_w_existing_type(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", nullable=True, existing_type=Integer)
-        context.assert_(
-            "ALTER TABLE t ALTER COLUMN c INTEGER NULL"
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NULL")
 
     def test_drop_column_w_fk(self):
-        context = op_fixture('mssql')
-        op.drop_column('t1', 'c1', mssql_drop_foreign_key=True)
+        context = op_fixture("mssql")
+        op.drop_column("t1", "c1", mssql_drop_foreign_key=True)
         context.assert_contains(
-            "exec('alter table t1 drop constraint ' + @const_name)")
+            "exec('alter table t1 drop constraint ' + @const_name)"
+        )
         context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_drop_column_w_fk_in_batch(self):
-        context = op_fixture('mssql')
-        with op.batch_alter_table('t1', schema=None) as batch_op:
-            batch_op.drop_column('c1', mssql_drop_foreign_key=True)
+        context = op_fixture("mssql")
+        with op.batch_alter_table("t1", schema=None) as batch_op:
+            batch_op.drop_column("c1", mssql_drop_foreign_key=True)
         context.assert_contains(
-            "exec('alter table t1 drop constraint ' + @const_name)")
+            "exec('alter table t1 drop constraint ' + @const_name)"
+        )
         context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_alter_column_not_nullable_w_existing_type(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", nullable=False, existing_type=Integer)
-        context.assert_(
-            "ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL"
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL")
 
     def test_alter_column_nullable_w_new_type(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", nullable=True, type_=Integer)
-        context.assert_(
-            "ALTER TABLE t ALTER COLUMN c INTEGER NULL"
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NULL")
 
     def test_alter_column_not_nullable_w_new_type(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", nullable=False, type_=Integer)
-        context.assert_(
-            "ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL"
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c INTEGER NOT NULL")
 
     def test_alter_column_nullable_type_required(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         assert_raises_message(
             util.CommandError,
             "MS-SQL ALTER COLUMN operations with NULL or "
             "NOT NULL require the existing_type or a new "
             "type_ be passed.",
-            op.alter_column, "t", "c", nullable=False
+            op.alter_column,
+            "t",
+            "c",
+            nullable=False,
         )
 
     def test_alter_add_server_default(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", server_default="5")
-        context.assert_(
-            "ALTER TABLE t ADD DEFAULT '5' FOR c"
-        )
+        context.assert_("ALTER TABLE t ADD DEFAULT '5' FOR c")
 
     def test_alter_replace_server_default(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column(
-            "t", "c", server_default="5", existing_server_default="6")
-        context.assert_contains(
-            "exec('alter table t drop constraint ' + @const_name)")
+            "t", "c", server_default="5", existing_server_default="6"
+        )
         context.assert_contains(
-            "ALTER TABLE t ADD DEFAULT '5' FOR c"
+            "exec('alter table t drop constraint ' + @const_name)"
         )
+        context.assert_contains("ALTER TABLE t ADD DEFAULT '5' FOR c")
 
     def test_alter_remove_server_default(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", server_default=None)
         context.assert_contains(
-            "exec('alter table t drop constraint ' + @const_name)")
+            "exec('alter table t drop constraint ' + @const_name)"
+        )
 
     def test_alter_do_everything(self):
-        context = op_fixture('mssql')
-        op.alter_column("t", "c", new_column_name="c2", nullable=True,
-                        type_=Integer, server_default="5")
+        context = op_fixture("mssql")
+        op.alter_column(
+            "t",
+            "c",
+            new_column_name="c2",
+            nullable=True,
+            type_=Integer,
+            server_default="5",
+        )
         context.assert_(
-            'ALTER TABLE t ALTER COLUMN c INTEGER NULL',
+            "ALTER TABLE t ALTER COLUMN c INTEGER NULL",
             "ALTER TABLE t ADD DEFAULT '5' FOR c",
-            "EXEC sp_rename 't.c', c2, 'COLUMN'"
+            "EXEC sp_rename 't.c', c2, 'COLUMN'",
         )
 
     def test_rename_table(self):
-        context = op_fixture('mssql')
-        op.rename_table('t1', 't2')
+        context = op_fixture("mssql")
+        op.rename_table("t1", "t2")
         context.assert_contains("EXEC sp_rename 't1', t2")
 
     def test_rename_table_schema(self):
-        context = op_fixture('mssql')
-        op.rename_table('t1', 't2', schema="foobar")
+        context = op_fixture("mssql")
+        op.rename_table("t1", "t2", schema="foobar")
         context.assert_contains("EXEC sp_rename 'foobar.t1', t2")
 
     def test_rename_table_casesens(self):
-        context = op_fixture('mssql')
-        op.rename_table('TeeOne', 'TeeTwo')
+        context = op_fixture("mssql")
+        op.rename_table("TeeOne", "TeeTwo")
         # yup, ran this in SQL Server 2014; the two levels of quoting
         # seem to be understood.  Can't do the two levels on the
         # target name though!
         context.assert_contains("EXEC sp_rename '[TeeOne]', [TeeTwo]")
 
     def test_rename_table_schema_casesens(self):
-        context = op_fixture('mssql')
-        op.rename_table('TeeOne', 'TeeTwo', schema="FooBar")
+        context = op_fixture("mssql")
+        op.rename_table("TeeOne", "TeeTwo", schema="FooBar")
         # yup, ran this in SQL Server 2014; the two levels of quoting
         # seem to be understood.  Can't do the two levels on the
         # target name, though!
         context.assert_contains("EXEC sp_rename '[FooBar].[TeeOne]', [TeeTwo]")
 
     def test_alter_column_rename_mssql_schema(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.alter_column("t", "c", name="x", schema="y")
-        context.assert_(
-            "EXEC sp_rename 'y.t.c', x, 'COLUMN'"
-        )
+        context.assert_("EXEC sp_rename 'y.t.c', x, 'COLUMN'")
 
     def test_create_index_mssql_include(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.create_index(
-            op.f('ix_mytable_a_b'), 'mytable', ['col_a', 'col_b'],
-            unique=False, mssql_include=['col_c'])
+            op.f("ix_mytable_a_b"),
+            "mytable",
+            ["col_a", "col_b"],
+            unique=False,
+            mssql_include=["col_c"],
+        )
         context.assert_contains(
             "CREATE INDEX ix_mytable_a_b ON mytable "
-            "(col_a, col_b) INCLUDE (col_c)")
+            "(col_a, col_b) INCLUDE (col_c)"
+        )
 
     def test_create_index_mssql_include_is_none(self):
-        context = op_fixture('mssql')
+        context = op_fixture("mssql")
         op.create_index(
-            op.f('ix_mytable_a_b'), 'mytable', ['col_a', 'col_b'],
-            unique=False)
+            op.f("ix_mytable_a_b"), "mytable", ["col_a", "col_b"], unique=False
+        )
         context.assert_contains(
-            "CREATE INDEX ix_mytable_a_b ON mytable "
-            "(col_a, col_b)")
+            "CREATE INDEX ix_mytable_a_b ON mytable " "(col_a, col_b)"
+        )
index dd872f7f5b43ca7764a869c0b9b631c427ccbc5b..68746ba6213e5111b62cd69000ee8cd0799079d1 100644 (file)
@@ -7,53 +7,68 @@ from alembic import op, util
 
 from alembic.testing import eq_, assert_raises_message
 from alembic.testing.fixtures import capture_context_buffer, op_fixture
-from alembic.testing.env import staging_env, _no_sql_testing_config, \
-    three_rev_fixture, clear_staging_env
+from alembic.testing.env import (
+    staging_env,
+    _no_sql_testing_config,
+    three_rev_fixture,
+    clear_staging_env,
+)
 
 from alembic.migration import MigrationContext
 
 
 class MySQLOpTest(TestBase):
-
     def test_rename_column(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.alter_column(
-            't1', 'c1', new_column_name="c2", existing_type=Integer)
-        context.assert_(
-            'ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL'
+            "t1", "c1", new_column_name="c2", existing_type=Integer
         )
+        context.assert_("ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL")
 
     def test_rename_column_quotes_needed_one(self):
-        context = op_fixture('mysql')
-        op.alter_column('MyTable', 'ColumnOne', new_column_name="ColumnTwo",
-                        existing_type=Integer)
+        context = op_fixture("mysql")
+        op.alter_column(
+            "MyTable",
+            "ColumnOne",
+            new_column_name="ColumnTwo",
+            existing_type=Integer,
+        )
         context.assert_(
-            'ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL'
+            "ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL"
         )
 
     def test_rename_column_quotes_needed_two(self):
-        context = op_fixture('mysql')
-        op.alter_column('my table', 'column one', new_column_name="column two",
-                        existing_type=Integer)
+        context = op_fixture("mysql")
+        op.alter_column(
+            "my table",
+            "column one",
+            new_column_name="column two",
+            existing_type=Integer,
+        )
         context.assert_(
-            'ALTER TABLE `my table` CHANGE `column one` '
-            '`column two` INTEGER NULL'
+            "ALTER TABLE `my table` CHANGE `column one` "
+            "`column two` INTEGER NULL"
         )
 
     def test_rename_column_serv_default(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.alter_column(
-            't1', 'c1', new_column_name="c2", existing_type=Integer,
-            existing_server_default="q")
-        context.assert_(
-            "ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'"
+            "t1",
+            "c1",
+            new_column_name="c2",
+            existing_type=Integer,
+            existing_server_default="q",
         )
+        context.assert_("ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'")
 
     def test_rename_column_serv_compiled_default(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.alter_column(
-            't1', 'c1', existing_type=Integer,
-            server_default=func.utc_thing(func.current_timestamp()))
+            "t1",
+            "c1",
+            existing_type=Integer,
+            server_default=func.utc_thing(func.current_timestamp()),
+        )
         # this is not a valid MySQL default, but the point is just to
         # test SQL expression rendering
         context.assert_(
@@ -62,184 +77,183 @@ class MySQLOpTest(TestBase):
         )
 
     def test_rename_column_autoincrement(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.alter_column(
-            't1', 'c1', new_column_name="c2", existing_type=Integer,
-            existing_autoincrement=True)
+            "t1",
+            "c1",
+            new_column_name="c2",
+            existing_type=Integer,
+            existing_autoincrement=True,
+        )
         context.assert_(
-            'ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT'
+            "ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT"
         )
 
     def test_col_add_autoincrement(self):
-        context = op_fixture('mysql')
-        op.alter_column('t1', 'c1', existing_type=Integer,
-                        autoincrement=True)
-        context.assert_(
-            'ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT'
-        )
+        context = op_fixture("mysql")
+        op.alter_column("t1", "c1", existing_type=Integer, autoincrement=True)
+        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT")
 
     def test_col_remove_autoincrement(self):
-        context = op_fixture('mysql')
-        op.alter_column('t1', 'c1', existing_type=Integer,
-                        existing_autoincrement=True,
-                        autoincrement=False)
-        context.assert_(
-            'ALTER TABLE t1 MODIFY c1 INTEGER NULL'
+        context = op_fixture("mysql")
+        op.alter_column(
+            "t1",
+            "c1",
+            existing_type=Integer,
+            existing_autoincrement=True,
+            autoincrement=False,
         )
+        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL")
 
     def test_col_dont_remove_server_default(self):
-        context = op_fixture('mysql')
-        op.alter_column('t1', 'c1', existing_type=Integer,
-                        existing_server_default='1',
-                        server_default=False)
+        context = op_fixture("mysql")
+        op.alter_column(
+            "t1",
+            "c1",
+            existing_type=Integer,
+            existing_server_default="1",
+            server_default=False,
+        )
 
         context.assert_()
 
     def test_alter_column_drop_default(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.alter_column("t", "c", existing_type=Integer, server_default=None)
-        context.assert_(
-            'ALTER TABLE t ALTER COLUMN c DROP DEFAULT'
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c DROP DEFAULT")
 
     def test_alter_column_remove_schematype(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.alter_column(
-            "t", "c",
+            "t",
+            "c",
             type_=Integer,
             existing_type=Boolean(create_constraint=True, name="ck1"),
-            server_default=None)
-        context.assert_(
-            'ALTER TABLE t MODIFY c INTEGER NULL'
+            server_default=None,
         )
+        context.assert_("ALTER TABLE t MODIFY c INTEGER NULL")
 
     def test_alter_column_modify_default(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         # notice we don't need the existing type on this one...
-        op.alter_column("t", "c", server_default='1')
-        context.assert_(
-            "ALTER TABLE t ALTER COLUMN c SET DEFAULT '1'"
-        )
+        op.alter_column("t", "c", server_default="1")
+        context.assert_("ALTER TABLE t ALTER COLUMN c SET DEFAULT '1'")
 
     def test_col_not_nullable(self):
-        context = op_fixture('mysql')
-        op.alter_column('t1', 'c1', nullable=False, existing_type=Integer)
-        context.assert_(
-            'ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL'
-        )
+        context = op_fixture("mysql")
+        op.alter_column("t1", "c1", nullable=False, existing_type=Integer)
+        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL")
 
     def test_col_not_nullable_existing_serv_default(self):
-        context = op_fixture('mysql')
-        op.alter_column('t1', 'c1', nullable=False, existing_type=Integer,
-                        existing_server_default='5')
+        context = op_fixture("mysql")
+        op.alter_column(
+            "t1",
+            "c1",
+            nullable=False,
+            existing_type=Integer,
+            existing_server_default="5",
+        )
         context.assert_(
             "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT '5'"
         )
 
     def test_col_nullable(self):
-        context = op_fixture('mysql')
-        op.alter_column('t1', 'c1', nullable=True, existing_type=Integer)
-        context.assert_(
-            'ALTER TABLE t1 MODIFY c1 INTEGER NULL'
-        )
+        context = op_fixture("mysql")
+        op.alter_column("t1", "c1", nullable=True, existing_type=Integer)
+        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NULL")
 
     def test_col_multi_alter(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.alter_column(
-            't1', 'c1', nullable=False, server_default="q", type_=Integer)
+            "t1", "c1", nullable=False, server_default="q", type_=Integer
+        )
         context.assert_(
             "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT 'q'"
         )
 
     def test_alter_column_multi_alter_w_drop_default(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.alter_column(
-            't1', 'c1', nullable=False, server_default=None, type_=Integer)
-        context.assert_(
-            "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL"
+            "t1", "c1", nullable=False, server_default=None, type_=Integer
         )
+        context.assert_("ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL")
 
     def test_col_alter_type_required(self):
-        op_fixture('mysql')
+        op_fixture("mysql")
         assert_raises_message(
             util.CommandError,
             "MySQL CHANGE/MODIFY COLUMN operations require the existing type.",
-            op.alter_column, 't1', 'c1', nullable=False, server_default="q"
+            op.alter_column,
+            "t1",
+            "c1",
+            nullable=False,
+            server_default="q",
         )
 
     def test_drop_fk(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.drop_constraint("f1", "t1", "foreignkey")
-        context.assert_(
-            "ALTER TABLE t1 DROP FOREIGN KEY f1"
-        )
+        context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1")
 
     def test_drop_fk_quoted(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.drop_constraint("MyFk", "MyTable", "foreignkey")
-        context.assert_(
-            "ALTER TABLE `MyTable` DROP FOREIGN KEY `MyFk`"
-        )
+        context.assert_("ALTER TABLE `MyTable` DROP FOREIGN KEY `MyFk`")
 
     def test_drop_constraint_primary(self):
-        context = op_fixture('mysql')
-        op.drop_constraint('primary', 't1', type_='primary')
-        context.assert_(
-            "ALTER TABLE t1 DROP PRIMARY KEY"
-        )
+        context = op_fixture("mysql")
+        op.drop_constraint("primary", "t1", type_="primary")
+        context.assert_("ALTER TABLE t1 DROP PRIMARY KEY")
 
     def test_drop_unique(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.drop_constraint("f1", "t1", "unique")
-        context.assert_(
-            "ALTER TABLE t1 DROP INDEX f1"
-        )
+        context.assert_("ALTER TABLE t1 DROP INDEX f1")
 
     def test_drop_unique_quoted(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.drop_constraint("MyUnique", "MyTable", "unique")
-        context.assert_(
-            "ALTER TABLE `MyTable` DROP INDEX `MyUnique`"
-        )
+        context.assert_("ALTER TABLE `MyTable` DROP INDEX `MyUnique`")
 
     def test_drop_check(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.drop_constraint("f1", "t1", "check")
-        context.assert_(
-            "ALTER TABLE t1 DROP CONSTRAINT f1"
-        )
+        context.assert_("ALTER TABLE t1 DROP CONSTRAINT f1")
 
     def test_drop_check_quoted(self):
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.drop_constraint("MyCheck", "MyTable", "check")
-        context.assert_(
-            "ALTER TABLE `MyTable` DROP CONSTRAINT `MyCheck`"
-        )
+        context.assert_("ALTER TABLE `MyTable` DROP CONSTRAINT `MyCheck`")
 
     def test_drop_unknown(self):
-        op_fixture('mysql')
+        op_fixture("mysql")
         assert_raises_message(
             TypeError,
             "'type' can be one of 'check', 'foreignkey', "
             "'primary', 'unique', None",
-            op.drop_constraint, "f1", "t1", "typo"
+            op.drop_constraint,
+            "f1",
+            "t1",
+            "typo",
         )
 
     def test_drop_generic_constraint(self):
-        op_fixture('mysql')
+        op_fixture("mysql")
         assert_raises_message(
             NotImplementedError,
             "No generic 'DROP CONSTRAINT' in MySQL - please "
             "specify constraint type",
-            op.drop_constraint, "f1", "t1"
+            op.drop_constraint,
+            "f1",
+            "t1",
         )
 
 
 class MySQLDefaultCompareTest(TestBase):
-    __only_on__ = 'mysql'
+    __only_on__ = "mysql"
     __backend__ = True
 
-    __requires__ = 'mysql_timestamp_reflection',
+    __requires__ = ("mysql_timestamp_reflection",)
 
     @classmethod
     def setup_class(cls):
@@ -247,17 +261,14 @@ class MySQLDefaultCompareTest(TestBase):
         staging_env()
         context = MigrationContext.configure(
             connection=cls.bind.connect(),
-            opts={
-                'compare_type': True,
-                'compare_server_default': True
-            }
+            opts={"compare_type": True, "compare_server_default": True},
         )
         connection = context.bind
         cls.autogen_context = {
-            'imports': set(),
-            'connection': connection,
-            'dialect': connection.dialect,
-            'context': context
+            "imports": set(),
+            "connection": connection,
+            "dialect": connection.dialect,
+            "context": context,
         }
 
     @classmethod
@@ -277,64 +288,46 @@ class MySQLDefaultCompareTest(TestBase):
             alternate = txt
             expected = False
         t = Table(
-            "test", self.metadata,
+            "test",
+            self.metadata,
             Column(
-                "somecol", type_,
-                server_default=text(txt) if txt else None
-            )
+                "somecol", type_, server_default=text(txt) if txt else None
+            ),
+        )
+        t2 = Table(
+            "test",
+            MetaData(),
+            Column("somecol", type_, server_default=text(alternate)),
         )
-        t2 = Table("test", MetaData(),
-                   Column("somecol", type_, server_default=text(alternate))
-                   )
-        assert self._compare_default(
-            t, t2, t2.c.somecol, alternate
-        ) is expected
-
-    def _compare_default(
-        self,
-        t1, t2, col,
-        rendered
-    ):
+        assert (
+            self._compare_default(t, t2, t2.c.somecol, alternate) is expected
+        )
+
+    def _compare_default(self, t1, t2, col, rendered):
         t1.create(self.bind)
         insp = Inspector.from_engine(self.bind)
         cols = insp.get_columns(t1.name)
         refl = Table(t1.name, MetaData())
         insp.reflecttable(refl, None)
-        ctx = self.autogen_context['context']
+        ctx = self.autogen_context["context"]
         return ctx.impl.compare_server_default(
-            refl.c[cols[0]['name']],
-            col,
-            rendered,
-            cols[0]['default'])
+            refl.c[cols[0]["name"]], col, rendered, cols[0]["default"]
+        )
 
     def test_compare_timestamp_current_timestamp(self):
-        self._compare_default_roundtrip(
-            TIMESTAMP(),
-            "CURRENT_TIMESTAMP",
-        )
+        self._compare_default_roundtrip(TIMESTAMP(), "CURRENT_TIMESTAMP")
 
     def test_compare_timestamp_current_timestamp_diff(self):
-        self._compare_default_roundtrip(
-            TIMESTAMP(),
-            None, "CURRENT_TIMESTAMP",
-        )
+        self._compare_default_roundtrip(TIMESTAMP(), None, "CURRENT_TIMESTAMP")
 
     def test_compare_integer_same(self):
-        self._compare_default_roundtrip(
-            Integer(), "5"
-        )
+        self._compare_default_roundtrip(Integer(), "5")
 
     def test_compare_integer_diff(self):
-        self._compare_default_roundtrip(
-            Integer(), "5", "7"
-        )
+        self._compare_default_roundtrip(Integer(), "5", "7")
 
     def test_compare_boolean_same(self):
-        self._compare_default_roundtrip(
-            Boolean(), "1"
-        )
+        self._compare_default_roundtrip(Boolean(), "1")
 
     def test_compare_boolean_diff(self):
-        self._compare_default_roundtrip(
-            Boolean(), "1", "0"
-        )
+        self._compare_default_roundtrip(Boolean(), "1", "0")
index 3920690dc53e69124fe5148f7e7aa2ea38774ef5..fbbbec3f509f991b66e6fb34ce7fdde28ce096ad 100644 (file)
@@ -3,16 +3,20 @@ from alembic.testing.fixtures import TestBase, capture_context_buffer
 from alembic import command, util
 
 from alembic.testing import assert_raises_message
-from alembic.testing.env import staging_env, _no_sql_testing_config, \
-    three_rev_fixture, clear_staging_env, env_file_fixture, \
-    multi_heads_fixture
+from alembic.testing.env import (
+    staging_env,
+    _no_sql_testing_config,
+    three_rev_fixture,
+    clear_staging_env,
+    env_file_fixture,
+    multi_heads_fixture,
+)
 import re
 
 a = b = c = None
 
 
 class OfflineEnvironmentTest(TestBase):
-
     def setUp(self):
         staging_env()
         self.cfg = _no_sql_testing_config()
@@ -24,92 +28,122 @@ class OfflineEnvironmentTest(TestBase):
         clear_staging_env()
 
     def test_not_requires_connection(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert not context.requires_connection()
-""")
+"""
+        )
         command.upgrade(self.cfg, a, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
 
     def test_requires_connection(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.requires_connection()
-""")
+"""
+        )
         command.upgrade(self.cfg, a)
         command.downgrade(self.cfg, a)
 
     def test_starting_rev_post_context(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite', starting_rev='x')
 assert context.get_starting_revision_argument() == 'x'
-""")
+"""
+        )
         command.upgrade(self.cfg, a, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
         command.current(self.cfg)
         command.stamp(self.cfg, a)
 
     def test_starting_rev_pre_context(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_starting_revision_argument() == 'x'
-""")
+"""
+        )
         command.upgrade(self.cfg, "x:y", sql=True)
         command.downgrade(self.cfg, "x:y", sql=True)
 
     def test_starting_rev_pre_context_cmd_w_no_startrev(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_starting_revision_argument() == 'x'
-""")
+"""
+        )
         assert_raises_message(
             util.CommandError,
             "No starting revision argument is available.",
-            command.current, self.cfg)
+            command.current,
+            self.cfg,
+        )
 
     def test_starting_rev_current_pre_context(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_starting_revision_argument() is None
-""")
+"""
+        )
         assert_raises_message(
             util.CommandError,
             "No starting revision argument is available.",
-            command.current, self.cfg
+            command.current,
+            self.cfg,
         )
 
     def test_destination_rev_pre_context(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_revision_argument() == '%s'
-""" % b)
+"""
+            % b
+        )
         command.upgrade(self.cfg, b, sql=True)
         command.stamp(self.cfg, b, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (c, b), sql=True)
 
     def test_destination_rev_pre_context_multihead(self):
         d, e, f = multi_heads_fixture(self.cfg, a, b, c)
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert set(context.get_revision_argument()) == set(('%s', '%s', '%s', ))
-""" % (f, e, c))
-        command.upgrade(self.cfg, 'heads', sql=True)
+"""
+            % (f, e, c)
+        )
+        command.upgrade(self.cfg, "heads", sql=True)
 
     def test_destination_rev_post_context(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
 assert context.get_revision_argument() == '%s'
-""" % b)
+"""
+            % b
+        )
         command.upgrade(self.cfg, b, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (c, b), sql=True)
         command.stamp(self.cfg, b, sql=True)
 
     def test_destination_rev_post_context_multihead(self):
         d, e, f = multi_heads_fixture(self.cfg, a, b, c)
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
 assert set(context.get_revision_argument()) == set(('%s', '%s', '%s', ))
-""" % (f, e, c))
-        command.upgrade(self.cfg, 'heads', sql=True)
+"""
+            % (f, e, c)
+        )
+        command.upgrade(self.cfg, "heads", sql=True)
 
     def test_head_rev_pre_context(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_head_revision() == '%s'
 assert context.get_head_revisions() == ('%s', )
-""" % (c, c))
+"""
+            % (c, c)
+        )
         command.upgrade(self.cfg, b, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
         command.stamp(self.cfg, b, sql=True)
@@ -117,20 +151,26 @@ assert context.get_head_revisions() == ('%s', )
 
     def test_head_rev_pre_context_multihead(self):
         d, e, f = multi_heads_fixture(self.cfg, a, b, c)
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert set(context.get_head_revisions()) == set(('%s', '%s', '%s', ))
-""" % (e, f, c))
+"""
+            % (e, f, c)
+        )
         command.upgrade(self.cfg, e, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (e, b), sql=True)
         command.stamp(self.cfg, c, sql=True)
         command.current(self.cfg)
 
     def test_head_rev_post_context(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
 assert context.get_head_revision() == '%s'
 assert context.get_head_revisions() == ('%s', )
-""" % (c, c))
+"""
+            % (c, c)
+        )
         command.upgrade(self.cfg, b, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
         command.stamp(self.cfg, b, sql=True)
@@ -138,70 +178,89 @@ assert context.get_head_revisions() == ('%s', )
 
     def test_head_rev_post_context_multihead(self):
         d, e, f = multi_heads_fixture(self.cfg, a, b, c)
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
 assert set(context.get_head_revisions()) == set(('%s', '%s', '%s', ))
-""" % (e, f, c))
+"""
+            % (e, f, c)
+        )
         command.upgrade(self.cfg, e, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (e, b), sql=True)
         command.stamp(self.cfg, c, sql=True)
         command.current(self.cfg)
 
     def test_tag_pre_context(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_tag_argument() == 'hi'
-""")
-        command.upgrade(self.cfg, b, sql=True, tag='hi')
-        command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi')
+"""
+        )
+        command.upgrade(self.cfg, b, sql=True, tag="hi")
+        command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi")
 
     def test_tag_pre_context_None(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_tag_argument() is None
-""")
+"""
+        )
         command.upgrade(self.cfg, b, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
 
     def test_tag_cmd_arg(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
 assert context.get_tag_argument() == 'hi'
-""")
-        command.upgrade(self.cfg, b, sql=True, tag='hi')
-        command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi')
+"""
+        )
+        command.upgrade(self.cfg, b, sql=True, tag="hi")
+        command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi")
 
     def test_tag_cfg_arg(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite', tag='there')
 assert context.get_tag_argument() == 'there'
-""")
-        command.upgrade(self.cfg, b, sql=True, tag='hi')
-        command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag='hi')
+"""
+        )
+        command.upgrade(self.cfg, b, sql=True, tag="hi")
+        command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True, tag="hi")
 
     def test_tag_None(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
 assert context.get_tag_argument() is None
-""")
+"""
+        )
         command.upgrade(self.cfg, b, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
 
     def test_downgrade_wo_colon(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
-""")
+"""
+        )
         assert_raises_message(
             util.CommandError,
             "downgrade with --sql requires <fromrev>:<torev>",
             command.downgrade,
-            self.cfg, b, sql=True
+            self.cfg,
+            b,
+            sql=True,
         )
 
     def test_upgrade_with_output_encoding(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 url = config.get_main_option('sqlalchemy.url')
 context.configure(url=url, output_encoding='utf-8')
 assert not context.requires_connection()
-""")
+"""
+        )
         command.upgrade(self.cfg, a, sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b, a), sql=True)
 
@@ -213,37 +272,49 @@ assert not context.requires_connection()
         with capture_context_buffer(transactional_ddl=True) as buf:
             command.upgrade(self.cfg, "%s:%s" % (a, d.revision), sql=True)
 
-        assert not re.match(r".*-- .*and multiline", buf.getvalue(), re.S | re.M)
+        assert not re.match(
+            r".*-- .*and multiline", buf.getvalue(), re.S | re.M
+        )
 
     def test_starting_rev_pre_context_abbreviated(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_starting_revision_argument() == '%s'
-""" % b[0:4])
+"""
+            % b[0:4]
+        )
         command.upgrade(self.cfg, "%s:%s" % (b[0:4], c), sql=True)
         command.stamp(self.cfg, "%s:%s" % (b[0:4], c), sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b[0:4], a), sql=True)
 
     def test_destination_rev_pre_context_abbreviated(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 assert context.get_revision_argument() == '%s'
-""" % b[0:4])
+"""
+            % b[0:4]
+        )
         command.upgrade(self.cfg, "%s:%s" % (a, b[0:4]), sql=True)
         command.stamp(self.cfg, b[0:4], sql=True)
         command.downgrade(self.cfg, "%s:%s" % (c, b[0:4]), sql=True)
 
     def test_starting_rev_context_runs_abbreviated(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
 context.run_migrations()
-""")
+"""
+        )
         command.upgrade(self.cfg, "%s:%s" % (b[0:4], c), sql=True)
         command.downgrade(self.cfg, "%s:%s" % (b[0:4], a), sql=True)
 
     def test_destination_rev_context_runs_abbreviated(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite')
 context.run_migrations()
-""")
+"""
+        )
         command.upgrade(self.cfg, "%s:%s" % (a, b[0:4]), sql=True)
         command.stamp(self.cfg, b[0:4], sql=True)
         command.downgrade(self.cfg, "%s:%s" % (c, b[0:4]), sql=True)
index f9a6c51d2da23abfde81a02a0394269ea4f7ef3b..fb2db5f44285342418ecc61b213cdf2a29cfbb23 100644 (file)
@@ -1,7 +1,6 @@
 """Test against the builders in the op.* module."""
 
-from sqlalchemy import Integer, Column, ForeignKey, \
-    Table, String, Boolean
+from sqlalchemy import Integer, Column, ForeignKey, Table, String, Boolean
 from sqlalchemy.sql import column, func, text
 from sqlalchemy import event
 
@@ -17,19 +16,18 @@ from alembic.operations import schemaobj, ops
 @event.listens_for(Table, "after_parent_attach")
 def _add_cols(table, metadata):
     if table.name == "tbl_with_auto_appended_column":
-        table.append_column(Column('bat', Integer))
+        table.append_column(Column("bat", Integer))
 
 
 class OpTest(TestBase):
-
     def test_rename_table(self):
         context = op_fixture()
-        op.rename_table('t1', 't2')
+        op.rename_table("t1", "t2")
         context.assert_("ALTER TABLE t1 RENAME TO t2")
 
     def test_rename_table_schema(self):
         context = op_fixture()
-        op.rename_table('t1', 't2', schema="foo")
+        op.rename_table("t1", "t2", schema="foo")
         context.assert_("ALTER TABLE foo.t1 RENAME TO foo.t2")
 
     def test_create_index_no_expr_allowed(self):
@@ -37,15 +35,21 @@ class OpTest(TestBase):
         assert_raises_message(
             ValueError,
             r"String or text\(\) construct expected",
-            op.create_index, 'name', 'tname', [func.foo(column('x'))]
+            op.create_index,
+            "name",
+            "tname",
+            [func.foo(column("x"))],
         )
 
     def test_add_column_schema_hard_quoting(self):
         from sqlalchemy.sql.schema import quoted_name
+
         context = op_fixture("postgresql")
         op.add_column(
-            "somename", Column("colname", String),
-            schema=quoted_name("some.schema", quote=True))
+            "somename",
+            Column("colname", String),
+            schema=quoted_name("some.schema", quote=True),
+        )
 
         context.assert_(
             'ALTER TABLE "some.schema".somename ADD COLUMN colname VARCHAR'
@@ -53,68 +57,67 @@ class OpTest(TestBase):
 
     def test_rename_table_schema_hard_quoting(self):
         from sqlalchemy.sql.schema import quoted_name
+
         context = op_fixture("postgresql")
         op.rename_table(
-            't1', 't2',
-            schema=quoted_name("some.schema", quote=True))
-
-        context.assert_(
-            'ALTER TABLE "some.schema".t1 RENAME TO t2'
+            "t1", "t2", schema=quoted_name("some.schema", quote=True)
         )
 
+        context.assert_('ALTER TABLE "some.schema".t1 RENAME TO t2')
+
     def test_add_constraint_schema_hard_quoting(self):
         from sqlalchemy.sql.schema import quoted_name
+
         context = op_fixture("postgresql")
         op.create_check_constraint(
             "ck_user_name_len",
             "user_table",
-            func.len(column('name')) > 5,
-            schema=quoted_name("some.schema", quote=True)
+            func.len(column("name")) > 5,
+            schema=quoted_name("some.schema", quote=True),
         )
         context.assert_(
             'ALTER TABLE "some.schema".user_table ADD '
-            'CONSTRAINT ck_user_name_len CHECK (len(name) > 5)'
+            "CONSTRAINT ck_user_name_len CHECK (len(name) > 5)"
         )
 
     def test_create_index_quoting(self):
         context = op_fixture("postgresql")
-        op.create_index(
-            'geocoded',
-            'locations',
-            ["IShouldBeQuoted"])
+        op.create_index("geocoded", "locations", ["IShouldBeQuoted"])
         context.assert_(
-            'CREATE INDEX geocoded ON locations ("IShouldBeQuoted")')
+            'CREATE INDEX geocoded ON locations ("IShouldBeQuoted")'
+        )
 
     def test_create_index_expressions(self):
         context = op_fixture()
-        op.create_index(
-            'geocoded',
-            'locations',
-            [text('lower(coordinates)')])
+        op.create_index("geocoded", "locations", [text("lower(coordinates)")])
         context.assert_(
-            "CREATE INDEX geocoded ON locations (lower(coordinates))")
+            "CREATE INDEX geocoded ON locations (lower(coordinates))"
+        )
 
     def test_add_column(self):
         context = op_fixture()
-        op.add_column('t1', Column('c1', Integer, nullable=False))
+        op.add_column("t1", Column("c1", Integer, nullable=False))
         context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL")
 
     def test_add_column_schema(self):
         context = op_fixture()
-        op.add_column('t1', Column('c1', Integer, nullable=False), schema="foo")
+        op.add_column(
+            "t1", Column("c1", Integer, nullable=False), schema="foo"
+        )
         context.assert_("ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL")
 
     def test_add_column_with_default(self):
         context = op_fixture()
         op.add_column(
-            't1', Column('c1', Integer, nullable=False, server_default="12"))
+            "t1", Column("c1", Integer, nullable=False, server_default="12")
+        )
         context.assert_(
-            "ALTER TABLE t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL")
+            "ALTER TABLE t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL"
+        )
 
     def test_add_column_with_index(self):
         context = op_fixture()
-        op.add_column(
-            't1', Column('c1', Integer, nullable=False, index=True))
+        op.add_column("t1", Column("c1", Integer, nullable=False, index=True))
         context.assert_(
             "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL",
             "CREATE INDEX ix_t1_c1 ON t1 (c1)",
@@ -122,107 +125,117 @@ class OpTest(TestBase):
 
     def test_add_column_schema_with_default(self):
         context = op_fixture()
-        op.add_column('t1',
-                      Column('c1', Integer, nullable=False, server_default="12"),
-                      schema='foo')
+        op.add_column(
+            "t1",
+            Column("c1", Integer, nullable=False, server_default="12"),
+            schema="foo",
+        )
         context.assert_(
-            "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL")
+            "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL"
+        )
 
     def test_add_column_fk(self):
         context = op_fixture()
         op.add_column(
-            't1', Column('c1', Integer, ForeignKey('c2.id'), nullable=False))
+            "t1", Column("c1", Integer, ForeignKey("c2.id"), nullable=False)
+        )
         context.assert_(
             "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL",
-            "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)"
+            "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)",
         )
 
     def test_add_column_schema_fk(self):
         context = op_fixture()
-        op.add_column('t1',
-                      Column('c1', Integer, ForeignKey('c2.id'), nullable=False),
-                      schema='foo')
+        op.add_column(
+            "t1",
+            Column("c1", Integer, ForeignKey("c2.id"), nullable=False),
+            schema="foo",
+        )
         context.assert_(
             "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL",
-            "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)"
+            "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)",
         )
 
     def test_add_column_schema_type(self):
         """Test that a schema type generates its constraints...."""
         context = op_fixture()
-        op.add_column('t1', Column('c1', Boolean, nullable=False))
+        op.add_column("t1", Column("c1", Boolean, nullable=False))
         context.assert_(
-            'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
-            'ALTER TABLE t1 ADD CHECK (c1 IN (0, 1))'
+            "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL",
+            "ALTER TABLE t1 ADD CHECK (c1 IN (0, 1))",
         )
 
     def test_add_column_schema_schema_type(self):
         """Test that a schema type generates its constraints...."""
         context = op_fixture()
-        op.add_column('t1', Column('c1', Boolean, nullable=False), schema='foo')
+        op.add_column(
+            "t1", Column("c1", Boolean, nullable=False), schema="foo"
+        )
         context.assert_(
-            'ALTER TABLE foo.t1 ADD COLUMN c1 BOOLEAN NOT NULL',
-            'ALTER TABLE foo.t1 ADD CHECK (c1 IN (0, 1))'
+            "ALTER TABLE foo.t1 ADD COLUMN c1 BOOLEAN NOT NULL",
+            "ALTER TABLE foo.t1 ADD CHECK (c1 IN (0, 1))",
         )
 
     def test_add_column_schema_type_checks_rule(self):
         """Test that a schema type doesn't generate a
         constraint based on check rule."""
-        context = op_fixture('postgresql')
-        op.add_column('t1', Column('c1', Boolean, nullable=False))
-        context.assert_(
-            'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
-        )
+        context = op_fixture("postgresql")
+        op.add_column("t1", Column("c1", Boolean, nullable=False))
+        context.assert_("ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL")
 
     def test_add_column_fk_self_referential(self):
         context = op_fixture()
         op.add_column(
-            't1', Column('c1', Integer, ForeignKey('t1.c2'), nullable=False))
+            "t1", Column("c1", Integer, ForeignKey("t1.c2"), nullable=False)
+        )
         context.assert_(
             "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL",
-            "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES t1 (c2)"
+            "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES t1 (c2)",
         )
 
     def test_add_column_schema_fk_self_referential(self):
         context = op_fixture()
         op.add_column(
-            't1',
-            Column('c1', Integer, ForeignKey('foo.t1.c2'), nullable=False),
-            schema='foo')
+            "t1",
+            Column("c1", Integer, ForeignKey("foo.t1.c2"), nullable=False),
+            schema="foo",
+        )
         context.assert_(
             "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL",
-            "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES foo.t1 (c2)"
+            "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES foo.t1 (c2)",
         )
 
     def test_add_column_fk_schema(self):
         context = op_fixture()
         op.add_column(
-            't1',
-            Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False))
+            "t1",
+            Column("c1", Integer, ForeignKey("remote.t2.c2"), nullable=False),
+        )
         context.assert_(
-            'ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL',
-            'ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)'
+            "ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL",
+            "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)",
         )
 
     def test_add_column_schema_fk_schema(self):
         context = op_fixture()
         op.add_column(
-            't1',
-            Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False),
-            schema='foo')
+            "t1",
+            Column("c1", Integer, ForeignKey("remote.t2.c2"), nullable=False),
+            schema="foo",
+        )
         context.assert_(
-            'ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL',
-            'ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)'
+            "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL",
+            "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)",
         )
 
     def test_drop_column(self):
         context = op_fixture()
-        op.drop_column('t1', 'c1')
+        op.drop_column("t1", "c1")
         context.assert_("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_drop_column_schema(self):
         context = op_fixture()
-        op.drop_column('t1', 'c1', schema='foo')
+        op.drop_column("t1", "c1", schema="foo")
         context.assert_("ALTER TABLE foo.t1 DROP COLUMN c1")
 
     def test_alter_column_nullable(self):
@@ -236,7 +249,7 @@ class OpTest(TestBase):
 
     def test_alter_column_schema_nullable(self):
         context = op_fixture()
-        op.alter_column("t", "c", nullable=True, schema='foo')
+        op.alter_column("t", "c", nullable=True, schema="foo")
         context.assert_(
             # TODO: not sure if this is PG only or standard
             # SQL
@@ -254,7 +267,7 @@ class OpTest(TestBase):
 
     def test_alter_column_schema_not_nullable(self):
         context = op_fixture()
-        op.alter_column("t", "c", nullable=False, schema='foo')
+        op.alter_column("t", "c", nullable=False, schema="foo")
         context.assert_(
             # TODO: not sure if this is PG only or standard
             # SQL
@@ -264,58 +277,50 @@ class OpTest(TestBase):
     def test_alter_column_rename(self):
         context = op_fixture()
         op.alter_column("t", "c", new_column_name="x")
-        context.assert_(
-            "ALTER TABLE t RENAME c TO x"
-        )
+        context.assert_("ALTER TABLE t RENAME c TO x")
 
     def test_alter_column_schema_rename(self):
         context = op_fixture()
-        op.alter_column("t", "c", new_column_name="x", schema='foo')
-        context.assert_(
-            "ALTER TABLE foo.t RENAME c TO x"
-        )
+        op.alter_column("t", "c", new_column_name="x", schema="foo")
+        context.assert_("ALTER TABLE foo.t RENAME c TO x")
 
     def test_alter_column_type(self):
         context = op_fixture()
         op.alter_column("t", "c", type_=String(50))
-        context.assert_(
-            'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)'
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)")
 
     def test_alter_column_schema_type(self):
         context = op_fixture()
-        op.alter_column("t", "c", type_=String(50), schema='foo')
-        context.assert_(
-            'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(50)'
-        )
+        op.alter_column("t", "c", type_=String(50), schema="foo")
+        context.assert_("ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(50)")
 
     def test_alter_column_set_default(self):
         context = op_fixture()
         op.alter_column("t", "c", server_default="q")
-        context.assert_(
-            "ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'"
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'")
 
     def test_alter_column_schema_set_default(self):
         context = op_fixture()
-        op.alter_column("t", "c", server_default="q", schema='foo')
-        context.assert_(
-            "ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT 'q'"
-        )
+        op.alter_column("t", "c", server_default="q", schema="foo")
+        context.assert_("ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT 'q'")
 
     def test_alter_column_set_compiled_default(self):
         context = op_fixture()
-        op.alter_column("t", "c",
-                        server_default=func.utc_thing(func.current_timestamp()))
+        op.alter_column(
+            "t", "c", server_default=func.utc_thing(func.current_timestamp())
+        )
         context.assert_(
             "ALTER TABLE t ALTER COLUMN c SET DEFAULT utc_thing(CURRENT_TIMESTAMP)"
         )
 
     def test_alter_column_schema_set_compiled_default(self):
         context = op_fixture()
-        op.alter_column("t", "c",
-                        server_default=func.utc_thing(func.current_timestamp()),
-                        schema='foo')
+        op.alter_column(
+            "t",
+            "c",
+            server_default=func.utc_thing(func.current_timestamp()),
+            schema="foo",
+        )
         context.assert_(
             "ALTER TABLE foo.t ALTER COLUMN c "
             "SET DEFAULT utc_thing(CURRENT_TIMESTAMP)"
@@ -324,101 +329,98 @@ class OpTest(TestBase):
     def test_alter_column_drop_default(self):
         context = op_fixture()
         op.alter_column("t", "c", server_default=None)
-        context.assert_(
-            'ALTER TABLE t ALTER COLUMN c DROP DEFAULT'
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c DROP DEFAULT")
 
     def test_alter_column_schema_drop_default(self):
         context = op_fixture()
-        op.alter_column("t", "c", server_default=None, schema='foo')
-        context.assert_(
-            'ALTER TABLE foo.t ALTER COLUMN c DROP DEFAULT'
-        )
+        op.alter_column("t", "c", server_default=None, schema="foo")
+        context.assert_("ALTER TABLE foo.t ALTER COLUMN c DROP DEFAULT")
 
     def test_alter_column_schema_type_unnamed(self):
-        context = op_fixture('mssql', native_boolean=False)
+        context = op_fixture("mssql", native_boolean=False)
         op.alter_column("t", "c", type_=Boolean())
         context.assert_(
-            'ALTER TABLE t ALTER COLUMN c BIT',
-            'ALTER TABLE t ADD CHECK (c IN (0, 1))'
+            "ALTER TABLE t ALTER COLUMN c BIT",
+            "ALTER TABLE t ADD CHECK (c IN (0, 1))",
         )
 
     def test_alter_column_schema_schema_type_unnamed(self):
-        context = op_fixture('mssql', native_boolean=False)
-        op.alter_column("t", "c", type_=Boolean(), schema='foo')
+        context = op_fixture("mssql", native_boolean=False)
+        op.alter_column("t", "c", type_=Boolean(), schema="foo")
         context.assert_(
-            'ALTER TABLE foo.t ALTER COLUMN c BIT',
-            'ALTER TABLE foo.t ADD CHECK (c IN (0, 1))'
+            "ALTER TABLE foo.t ALTER COLUMN c BIT",
+            "ALTER TABLE foo.t ADD CHECK (c IN (0, 1))",
         )
 
     def test_alter_column_schema_type_named(self):
-        context = op_fixture('mssql', native_boolean=False)
+        context = op_fixture("mssql", native_boolean=False)
         op.alter_column("t", "c", type_=Boolean(name="xyz"))
         context.assert_(
-            'ALTER TABLE t ALTER COLUMN c BIT',
-            'ALTER TABLE t ADD CONSTRAINT xyz CHECK (c IN (0, 1))'
+            "ALTER TABLE t ALTER COLUMN c BIT",
+            "ALTER TABLE t ADD CONSTRAINT xyz CHECK (c IN (0, 1))",
         )
 
     def test_alter_column_schema_schema_type_named(self):
-        context = op_fixture('mssql', native_boolean=False)
-        op.alter_column("t", "c", type_=Boolean(name="xyz"), schema='foo')
+        context = op_fixture("mssql", native_boolean=False)
+        op.alter_column("t", "c", type_=Boolean(name="xyz"), schema="foo")
         context.assert_(
-            'ALTER TABLE foo.t ALTER COLUMN c BIT',
-            'ALTER TABLE foo.t ADD CONSTRAINT xyz CHECK (c IN (0, 1))'
+            "ALTER TABLE foo.t ALTER COLUMN c BIT",
+            "ALTER TABLE foo.t ADD CONSTRAINT xyz CHECK (c IN (0, 1))",
         )
 
     def test_alter_column_schema_type_existing_type(self):
-        context = op_fixture('mssql', native_boolean=False)
+        context = op_fixture("mssql", native_boolean=False)
         op.alter_column(
-            "t", "c", type_=String(10), existing_type=Boolean(name="xyz"))
+            "t", "c", type_=String(10), existing_type=Boolean(name="xyz")
+        )
         context.assert_(
-            'ALTER TABLE t DROP CONSTRAINT xyz',
-            'ALTER TABLE t ALTER COLUMN c VARCHAR(10)'
+            "ALTER TABLE t DROP CONSTRAINT xyz",
+            "ALTER TABLE t ALTER COLUMN c VARCHAR(10)",
         )
 
     def test_alter_column_schema_schema_type_existing_type(self):
-        context = op_fixture('mssql', native_boolean=False)
-        op.alter_column("t", "c", type_=String(10),
-                        existing_type=Boolean(name="xyz"), schema='foo')
+        context = op_fixture("mssql", native_boolean=False)
+        op.alter_column(
+            "t",
+            "c",
+            type_=String(10),
+            existing_type=Boolean(name="xyz"),
+            schema="foo",
+        )
         context.assert_(
-            'ALTER TABLE foo.t DROP CONSTRAINT xyz',
-            'ALTER TABLE foo.t ALTER COLUMN c VARCHAR(10)'
+            "ALTER TABLE foo.t DROP CONSTRAINT xyz",
+            "ALTER TABLE foo.t ALTER COLUMN c VARCHAR(10)",
         )
 
     def test_alter_column_schema_type_existing_type_no_const(self):
-        context = op_fixture('postgresql')
+        context = op_fixture("postgresql")
         op.alter_column("t", "c", type_=String(10), existing_type=Boolean())
-        context.assert_(
-            'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(10)'
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(10)")
 
     def test_alter_column_schema_schema_type_existing_type_no_const(self):
-        context = op_fixture('postgresql')
-        op.alter_column("t", "c", type_=String(10), existing_type=Boolean(),
-                        schema='foo')
-        context.assert_(
-            'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(10)'
+        context = op_fixture("postgresql")
+        op.alter_column(
+            "t", "c", type_=String(10), existing_type=Boolean(), schema="foo"
         )
+        context.assert_("ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(10)")
 
     def test_alter_column_schema_type_existing_type_no_new_type(self):
-        context = op_fixture('postgresql')
+        context = op_fixture("postgresql")
         op.alter_column("t", "c", nullable=False, existing_type=Boolean())
-        context.assert_(
-            'ALTER TABLE t ALTER COLUMN c SET NOT NULL'
-        )
+        context.assert_("ALTER TABLE t ALTER COLUMN c SET NOT NULL")
 
     def test_alter_column_schema_schema_type_existing_type_no_new_type(self):
-        context = op_fixture('postgresql')
-        op.alter_column("t", "c", nullable=False, existing_type=Boolean(),
-                        schema='foo')
-        context.assert_(
-            'ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL'
+        context = op_fixture("postgresql")
+        op.alter_column(
+            "t", "c", nullable=False, existing_type=Boolean(), schema="foo"
         )
+        context.assert_("ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL")
 
     def test_add_foreign_key(self):
         context = op_fixture()
-        op.create_foreign_key('fk_test', 't1', 't2',
-                              ['foo', 'bar'], ['bat', 'hoho'])
+        op.create_foreign_key(
+            "fk_test", "t1", "t2", ["foo", "bar"], ["bat", "hoho"]
+        )
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
             "REFERENCES t2 (bat, hoho)"
@@ -426,9 +428,15 @@ class OpTest(TestBase):
 
     def test_add_foreign_key_schema(self):
         context = op_fixture()
-        op.create_foreign_key('fk_test', 't1', 't2',
-                              ['foo', 'bar'], ['bat', 'hoho'],
-                              source_schema='foo2', referent_schema='bar2')
+        op.create_foreign_key(
+            "fk_test",
+            "t1",
+            "t2",
+            ["foo", "bar"],
+            ["bat", "hoho"],
+            source_schema="foo2",
+            referent_schema="bar2",
+        )
         context.assert_(
             "ALTER TABLE foo2.t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
             "REFERENCES bar2.t2 (bat, hoho)"
@@ -436,9 +444,15 @@ class OpTest(TestBase):
 
     def test_add_foreign_key_schema_same_tablename(self):
         context = op_fixture()
-        op.create_foreign_key('fk_test', 't1', 't1',
-                              ['foo', 'bar'], ['bat', 'hoho'],
-                              source_schema='foo2', referent_schema='bar2')
+        op.create_foreign_key(
+            "fk_test",
+            "t1",
+            "t1",
+            ["foo", "bar"],
+            ["bat", "hoho"],
+            source_schema="foo2",
+            referent_schema="bar2",
+        )
         context.assert_(
             "ALTER TABLE foo2.t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
             "REFERENCES bar2.t1 (bat, hoho)"
@@ -446,9 +460,14 @@ class OpTest(TestBase):
 
     def test_add_foreign_key_onupdate(self):
         context = op_fixture()
-        op.create_foreign_key('fk_test', 't1', 't2',
-                              ['foo', 'bar'], ['bat', 'hoho'],
-                              onupdate='CASCADE')
+        op.create_foreign_key(
+            "fk_test",
+            "t1",
+            "t2",
+            ["foo", "bar"],
+            ["bat", "hoho"],
+            onupdate="CASCADE",
+        )
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
             "REFERENCES t2 (bat, hoho) ON UPDATE CASCADE"
@@ -456,9 +475,14 @@ class OpTest(TestBase):
 
     def test_add_foreign_key_ondelete(self):
         context = op_fixture()
-        op.create_foreign_key('fk_test', 't1', 't2',
-                              ['foo', 'bar'], ['bat', 'hoho'],
-                              ondelete='CASCADE')
+        op.create_foreign_key(
+            "fk_test",
+            "t1",
+            "t2",
+            ["foo", "bar"],
+            ["bat", "hoho"],
+            ondelete="CASCADE",
+        )
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
             "REFERENCES t2 (bat, hoho) ON DELETE CASCADE"
@@ -466,9 +490,14 @@ class OpTest(TestBase):
 
     def test_add_foreign_key_deferrable(self):
         context = op_fixture()
-        op.create_foreign_key('fk_test', 't1', 't2',
-                              ['foo', 'bar'], ['bat', 'hoho'],
-                              deferrable=True)
+        op.create_foreign_key(
+            "fk_test",
+            "t1",
+            "t2",
+            ["foo", "bar"],
+            ["bat", "hoho"],
+            deferrable=True,
+        )
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
             "REFERENCES t2 (bat, hoho) DEFERRABLE"
@@ -476,9 +505,14 @@ class OpTest(TestBase):
 
     def test_add_foreign_key_initially(self):
         context = op_fixture()
-        op.create_foreign_key('fk_test', 't1', 't2',
-                              ['foo', 'bar'], ['bat', 'hoho'],
-                              initially='INITIAL')
+        op.create_foreign_key(
+            "fk_test",
+            "t1",
+            "t2",
+            ["foo", "bar"],
+            ["bat", "hoho"],
+            initially="INITIAL",
+        )
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
             "REFERENCES t2 (bat, hoho) INITIALLY INITIAL"
@@ -487,9 +521,14 @@ class OpTest(TestBase):
     @config.requirements.foreign_key_match
     def test_add_foreign_key_match(self):
         context = op_fixture()
-        op.create_foreign_key('fk_test', 't1', 't2',
-                              ['foo', 'bar'], ['bat', 'hoho'],
-                              match='SIMPLE')
+        op.create_foreign_key(
+            "fk_test",
+            "t1",
+            "t2",
+            ["foo", "bar"],
+            ["bat", "hoho"],
+            match="SIMPLE",
+        )
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) "
             "REFERENCES t2 (bat, hoho) MATCH SIMPLE"
@@ -497,24 +536,44 @@ class OpTest(TestBase):
 
     def test_add_foreign_key_dialect_kw(self):
         op_fixture()
-        with mock.patch(
-                "sqlalchemy.schema.ForeignKeyConstraint"
-        ) as fkc:
-            op.create_foreign_key('fk_test', 't1', 't2',
-                                  ['foo', 'bar'], ['bat', 'hoho'],
-                                  foobar_arg='xyz')
+        with mock.patch("sqlalchemy.schema.ForeignKeyConstraint") as fkc:
+            op.create_foreign_key(
+                "fk_test",
+                "t1",
+                "t2",
+                ["foo", "bar"],
+                ["bat", "hoho"],
+                foobar_arg="xyz",
+            )
             if config.requirements.foreign_key_match.enabled:
-                eq_(fkc.mock_calls[0],
-                    mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'],
-                              onupdate=None, ondelete=None, name='fk_test',
-                              foobar_arg='xyz',
-                              deferrable=None, initially=None, match=None))
+                eq_(
+                    fkc.mock_calls[0],
+                    mock.call(
+                        ["foo", "bar"],
+                        ["t2.bat", "t2.hoho"],
+                        onupdate=None,
+                        ondelete=None,
+                        name="fk_test",
+                        foobar_arg="xyz",
+                        deferrable=None,
+                        initially=None,
+                        match=None,
+                    ),
+                )
             else:
-                eq_(fkc.mock_calls[0],
-                    mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'],
-                              onupdate=None, ondelete=None, name='fk_test',
-                              foobar_arg='xyz',
-                              deferrable=None, initially=None))
+                eq_(
+                    fkc.mock_calls[0],
+                    mock.call(
+                        ["foo", "bar"],
+                        ["t2.bat", "t2.hoho"],
+                        onupdate=None,
+                        ondelete=None,
+                        name="fk_test",
+                        foobar_arg="xyz",
+                        deferrable=None,
+                        initially=None,
+                    ),
+                )
 
     def test_add_foreign_key_self_referential(self):
         context = op_fixture()
@@ -541,9 +600,7 @@ class OpTest(TestBase):
     def test_add_check_constraint(self):
         context = op_fixture()
         op.create_check_constraint(
-            "ck_user_name_len",
-            "user_table",
-            func.len(column('name')) > 5
+            "ck_user_name_len", "user_table", func.len(column("name")) > 5
         )
         context.assert_(
             "ALTER TABLE user_table ADD CONSTRAINT ck_user_name_len "
@@ -555,8 +612,8 @@ class OpTest(TestBase):
         op.create_check_constraint(
             "ck_user_name_len",
             "user_table",
-            func.len(column('name')) > 5,
-            schema='foo'
+            func.len(column("name")) > 5,
+            schema="foo",
         )
         context.assert_(
             "ALTER TABLE foo.user_table ADD CONSTRAINT ck_user_name_len "
@@ -565,7 +622,7 @@ class OpTest(TestBase):
 
     def test_add_unique_constraint(self):
         context = op_fixture()
-        op.create_unique_constraint('uk_test', 't1', ['foo', 'bar'])
+        op.create_unique_constraint("uk_test", "t1", ["foo", "bar"])
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)"
         )
@@ -574,12 +631,12 @@ class OpTest(TestBase):
         context = op_fixture()
 
         op.create_foreign_key(
-            name='some_fk',
-            source='some_table',
-            referent='referred_table',
-            local_cols=['a', 'b'],
-            remote_cols=['c', 'd'],
-            ondelete='CASCADE'
+            name="some_fk",
+            source="some_table",
+            referent="referred_table",
+            local_cols=["a", "b"],
+            remote_cols=["c", "d"],
+            ondelete="CASCADE",
         )
         context.assert_(
             "ALTER TABLE some_table ADD CONSTRAINT some_fk "
@@ -590,27 +647,26 @@ class OpTest(TestBase):
     def test_add_unique_constraint_legacy_kwarg(self):
         context = op_fixture()
         op.create_unique_constraint(
-            name='uk_test',
-            source='t1',
-            local_cols=['foo', 'bar'])
+            name="uk_test", source="t1", local_cols=["foo", "bar"]
+        )
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)"
         )
 
     def test_drop_constraint_legacy_kwarg(self):
         context = op_fixture()
-        op.drop_constraint(name='pk_name',
-                           table_name='sometable',
-                           type_='primary')
-        context.assert_(
-            "ALTER TABLE sometable DROP CONSTRAINT pk_name"
+        op.drop_constraint(
+            name="pk_name", table_name="sometable", type_="primary"
         )
+        context.assert_("ALTER TABLE sometable DROP CONSTRAINT pk_name")
 
     def test_create_pk_legacy_kwarg(self):
         context = op_fixture()
-        op.create_primary_key(name=None,
-                              table_name='sometable',
-                              cols=['router_id', 'l3_agent_id'])
+        op.create_primary_key(
+            name=None,
+            table_name="sometable",
+            cols=["router_id", "l3_agent_id"],
+        )
         context.assert_(
             "ALTER TABLE sometable ADD PRIMARY KEY (router_id, l3_agent_id)"
         )
@@ -623,57 +679,50 @@ class OpTest(TestBase):
             "missing required positional argument: columns",
             op.create_primary_key,
             name=None,
-            table_name='sometable',
-            wrong_cols=['router_id', 'l3_agent_id']
+            table_name="sometable",
+            wrong_cols=["router_id", "l3_agent_id"],
         )
 
     def test_add_unique_constraint_schema(self):
         context = op_fixture()
         op.create_unique_constraint(
-            'uk_test', 't1', ['foo', 'bar'], schema='foo')
+            "uk_test", "t1", ["foo", "bar"], schema="foo"
+        )
         context.assert_(
             "ALTER TABLE foo.t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)"
         )
 
     def test_drop_constraint(self):
         context = op_fixture()
-        op.drop_constraint('foo_bar_bat', 't1')
-        context.assert_(
-            "ALTER TABLE t1 DROP CONSTRAINT foo_bar_bat"
-        )
+        op.drop_constraint("foo_bar_bat", "t1")
+        context.assert_("ALTER TABLE t1 DROP CONSTRAINT foo_bar_bat")
 
     def test_drop_constraint_schema(self):
         context = op_fixture()
-        op.drop_constraint('foo_bar_bat', 't1', schema='foo')
-        context.assert_(
-            "ALTER TABLE foo.t1 DROP CONSTRAINT foo_bar_bat"
-        )
+        op.drop_constraint("foo_bar_bat", "t1", schema="foo")
+        context.assert_("ALTER TABLE foo.t1 DROP CONSTRAINT foo_bar_bat")
 
     def test_create_index(self):
         context = op_fixture()
-        op.create_index('ik_test', 't1', ['foo', 'bar'])
-        context.assert_(
-            "CREATE INDEX ik_test ON t1 (foo, bar)"
-        )
+        op.create_index("ik_test", "t1", ["foo", "bar"])
+        context.assert_("CREATE INDEX ik_test ON t1 (foo, bar)")
 
     def test_create_unique_index(self):
         context = op_fixture()
-        op.create_index('ik_test', 't1', ['foo', 'bar'], unique=True)
-        context.assert_(
-            "CREATE UNIQUE INDEX ik_test ON t1 (foo, bar)"
-        )
+        op.create_index("ik_test", "t1", ["foo", "bar"], unique=True)
+        context.assert_("CREATE UNIQUE INDEX ik_test ON t1 (foo, bar)")
 
     def test_create_index_quote_flag(self):
         context = op_fixture()
-        op.create_index('ik_test', 't1', ['foo', 'bar'], quote=True)
-        context.assert_(
-            'CREATE INDEX "ik_test" ON t1 (foo, bar)'
-        )
+        op.create_index("ik_test", "t1", ["foo", "bar"], quote=True)
+        context.assert_('CREATE INDEX "ik_test" ON t1 (foo, bar)')
 
     def test_create_index_table_col_event(self):
         context = op_fixture()
 
-        op.create_index('ik_test', 'tbl_with_auto_appended_column', ['foo', 'bar'])
+        op.create_index(
+            "ik_test", "tbl_with_auto_appended_column", ["foo", "bar"]
+        )
         context.assert_(
             "CREATE INDEX ik_test ON tbl_with_auto_appended_column (foo, bar)"
         )
@@ -681,8 +730,8 @@ class OpTest(TestBase):
     def test_add_unique_constraint_col_event(self):
         context = op_fixture()
         op.create_unique_constraint(
-            'ik_test',
-            'tbl_with_auto_appended_column', ['foo', 'bar'])
+            "ik_test", "tbl_with_auto_appended_column", ["foo", "bar"]
+        )
         context.assert_(
             "ALTER TABLE tbl_with_auto_appended_column "
             "ADD CONSTRAINT ik_test UNIQUE (foo, bar)"
@@ -690,45 +739,35 @@ class OpTest(TestBase):
 
     def test_create_index_schema(self):
         context = op_fixture()
-        op.create_index('ik_test', 't1', ['foo', 'bar'], schema='foo')
-        context.assert_(
-            "CREATE INDEX ik_test ON foo.t1 (foo, bar)"
-        )
+        op.create_index("ik_test", "t1", ["foo", "bar"], schema="foo")
+        context.assert_("CREATE INDEX ik_test ON foo.t1 (foo, bar)")
 
     def test_drop_index(self):
         context = op_fixture()
-        op.drop_index('ik_test')
-        context.assert_(
-            "DROP INDEX ik_test"
-        )
+        op.drop_index("ik_test")
+        context.assert_("DROP INDEX ik_test")
 
     def test_drop_index_schema(self):
         context = op_fixture()
-        op.drop_index('ik_test', schema='foo')
-        context.assert_(
-            "DROP INDEX foo.ik_test"
-        )
+        op.drop_index("ik_test", schema="foo")
+        context.assert_("DROP INDEX foo.ik_test")
 
     def test_drop_table(self):
         context = op_fixture()
-        op.drop_table('tb_test')
-        context.assert_(
-            "DROP TABLE tb_test"
-        )
+        op.drop_table("tb_test")
+        context.assert_("DROP TABLE tb_test")
 
     def test_drop_table_schema(self):
         context = op_fixture()
-        op.drop_table('tb_test', schema='foo')
-        context.assert_(
-            "DROP TABLE foo.tb_test"
-        )
+        op.drop_table("tb_test", schema="foo")
+        context.assert_("DROP TABLE foo.tb_test")
 
     def test_create_table_selfref(self):
         context = op_fixture()
         op.create_table(
             "some_table",
-            Column('id', Integer, primary_key=True),
-            Column('st_id', Integer, ForeignKey('some_table.id'))
+            Column("id", Integer, primary_key=True),
+            Column("st_id", Integer, ForeignKey("some_table.id")),
         )
         context.assert_(
             "CREATE TABLE some_table ("
@@ -742,9 +781,9 @@ class OpTest(TestBase):
         context = op_fixture()
         t1 = op.create_table(
             "some_table",
-            Column('id', Integer, primary_key=True),
-            Column('foo_id', Integer, ForeignKey('foo.id')),
-            schema='schema'
+            Column("id", Integer, primary_key=True),
+            Column("foo_id", Integer, ForeignKey("foo.id")),
+            schema="schema",
         )
         context.assert_(
             "CREATE TABLE schema.some_table ("
@@ -760,9 +799,9 @@ class OpTest(TestBase):
         context = op_fixture()
         t1 = op.create_table(
             "some_table",
-            Column('x', Integer),
-            Column('y', Integer),
-            Column('z', Integer),
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("z", Integer),
         )
         context.assert_(
             "CREATE TABLE some_table (x INTEGER, y INTEGER, z INTEGER)"
@@ -773,9 +812,9 @@ class OpTest(TestBase):
         context = op_fixture()
         op.create_table(
             "some_table",
-            Column('id', Integer, primary_key=True),
-            Column('foo_id', Integer, ForeignKey('foo.id')),
-            Column('foo_bar', Integer, ForeignKey('foo.bar')),
+            Column("id", Integer, primary_key=True),
+            Column("foo_id", Integer, ForeignKey("foo.id")),
+            Column("foo_bar", Integer, ForeignKey("foo.bar")),
         )
         context.assert_(
             "CREATE TABLE some_table ("
@@ -792,27 +831,26 @@ class OpTest(TestBase):
         from sqlalchemy.sql import table, column
         from sqlalchemy import String, Integer
 
-        account = table('account',
-                        column('name', String),
-                        column('id', Integer)
-                        )
+        account = table(
+            "account", column("name", String), column("id", Integer)
+        )
         op.execute(
-            account.update().
-            where(account.c.name == op.inline_literal('account 1')).
-            values({'name': op.inline_literal('account 2')})
+            account.update()
+            .where(account.c.name == op.inline_literal("account 1"))
+            .values({"name": op.inline_literal("account 2")})
         )
         op.execute(
-            account.update().
-            where(account.c.id == op.inline_literal(1)).
-            values({'id': op.inline_literal(2)})
+            account.update()
+            .where(account.c.id == op.inline_literal(1))
+            .values({"id": op.inline_literal(2)})
         )
         context.assert_(
             "UPDATE account SET name='account 2' WHERE account.name = 'account 1'",
-            "UPDATE account SET id=2 WHERE account.id = 1"
+            "UPDATE account SET id=2 WHERE account.id = 1",
         )
 
     def test_cant_op(self):
-        if hasattr(op, '_proxy'):
+        if hasattr(op, "_proxy"):
             del op._proxy
         assert_raises_message(
             NameError,
@@ -820,7 +858,8 @@ class OpTest(TestBase):
             "proxy object has not yet been established "
             "for the Alembic 'Operations' class.  "
             "Try placing this code inside a callable.",
-            op.inline_literal, "asdf"
+            op.inline_literal,
+            "asdf",
         )
 
     def test_naming_changes(self):
@@ -832,17 +871,17 @@ class OpTest(TestBase):
         op.alter_column("t", "c", new_column_name="x")
         context.assert_("ALTER TABLE t RENAME c TO x")
 
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.drop_constraint("f1", "t1", type="foreignkey")
         context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1")
 
-        context = op_fixture('mysql')
+        context = op_fixture("mysql")
         op.drop_constraint("f1", "t1", type_="foreignkey")
         context.assert_("ALTER TABLE t1 DROP FOREIGN KEY f1")
 
     def test_naming_changes_drop_idx(self):
-        context = op_fixture('mssql')
-        op.drop_index('ik_test', tablename='t1')
+        context = op_fixture("mssql")
+        op.drop_index("ik_test", tablename="t1")
         context.assert_("DROP INDEX ik_test ON t1")
 
 
@@ -852,20 +891,19 @@ class SQLModeOpTest(TestBase):
         from sqlalchemy.sql import table, column
         from sqlalchemy import String, Integer
 
-        account = table('account',
-                        column('name', String),
-                        column('id', Integer)
-                        )
+        account = table(
+            "account", column("name", String), column("id", Integer)
+        )
         op.execute(
-            account.update().
-            where(account.c.name == op.inline_literal('account 1')).
-            values({'name': op.inline_literal('account 2')})
+            account.update()
+            .where(account.c.name == op.inline_literal("account 1"))
+            .values({"name": op.inline_literal("account 2")})
         )
-        op.execute(text("update table set foo=:bar").bindparams(bar='bat'))
+        op.execute(text("update table set foo=:bar").bindparams(bar="bat"))
         context.assert_(
             "UPDATE account SET name='account 2' "
             "WHERE account.name = 'account 1'",
-            "update table set foo='bat'"
+            "update table set foo='bat'",
         )
 
     def test_create_table_literal_binds(self):
@@ -873,8 +911,8 @@ class SQLModeOpTest(TestBase):
 
         op.create_table(
             "some_table",
-            Column('id', Integer, primary_key=True),
-            Column('st_id', Integer, ForeignKey('some_table.id'))
+            Column("id", Integer, primary_key=True),
+            Column("st_id", Integer, ForeignKey("some_table.id")),
         )
 
         context.assert_(
@@ -907,7 +945,7 @@ class CustomOpTest(TestBase):
             operations.execute("CREATE SEQUENCE %s" % operation.sequence_name)
 
         context = op_fixture()
-        op.create_sequence('foob')
+        op.create_sequence("foob")
         context.assert_("CREATE SEQUENCE foob")
 
 
@@ -923,48 +961,36 @@ class EnsureOrigObjectFromToTest(TestBase):
 
     def test_drop_index(self):
         schema_obj = schemaobj.SchemaObjects()
-        idx = schema_obj.index('x', 'y', ['z'])
+        idx = schema_obj.index("x", "y", ["z"])
         op = ops.DropIndexOp.from_index(idx)
-        is_(
-            op.to_index(), idx
-        )
+        is_(op.to_index(), idx)
 
     def test_create_index(self):
         schema_obj = schemaobj.SchemaObjects()
-        idx = schema_obj.index('x', 'y', ['z'])
+        idx = schema_obj.index("x", "y", ["z"])
         op = ops.CreateIndexOp.from_index(idx)
-        is_(
-            op.to_index(), idx
-        )
+        is_(op.to_index(), idx)
 
     def test_drop_table(self):
         schema_obj = schemaobj.SchemaObjects()
-        table = schema_obj.table('x', Column('q', Integer))
+        table = schema_obj.table("x", Column("q", Integer))
         op = ops.DropTableOp.from_table(table)
-        is_(
-            op.to_table(), table
-        )
+        is_(op.to_table(), table)
 
     def test_create_table(self):
         schema_obj = schemaobj.SchemaObjects()
-        table = schema_obj.table('x', Column('q', Integer))
+        table = schema_obj.table("x", Column("q", Integer))
         op = ops.CreateTableOp.from_table(table)
-        is_(
-            op.to_table(), table
-        )
+        is_(op.to_table(), table)
 
     def test_drop_unique_constraint(self):
         schema_obj = schemaobj.SchemaObjects()
-        const = schema_obj.unique_constraint('x', 'foobar', ['a'])
+        const = schema_obj.unique_constraint("x", "foobar", ["a"])
         op = ops.DropConstraintOp.from_constraint(const)
-        is_(
-            op.to_constraint(), const
-        )
+        is_(op.to_constraint(), const)
 
     def test_drop_constraint_not_available(self):
-        op = ops.DropConstraintOp('x', 'y', type_='unique')
+        op = ops.DropConstraintOp("x", "y", type_="unique")
         assert_raises_message(
-            ValueError,
-            "constraint cannot be produced",
-            op.to_constraint
+            ValueError, "constraint cannot be produced", op.to_constraint
         )
index fd70faafcf2726261ce10da432525da8f6b4d8af..fbcd181545ab07dec8118e0aaf962d5c55da2ce3 100644 (file)
@@ -1,5 +1,11 @@
-from sqlalchemy import Integer, Column, \
-    Table, Boolean, MetaData, CheckConstraint
+from sqlalchemy import (
+    Integer,
+    Column,
+    Table,
+    Boolean,
+    MetaData,
+    CheckConstraint,
+)
 from sqlalchemy.sql import column, func
 
 from alembic import op
@@ -9,16 +15,14 @@ from alembic.testing.fixtures import TestBase
 
 
 class AutoNamingConventionTest(TestBase):
-    __requires__ = ('sqlalchemy_094', )
+    __requires__ = ("sqlalchemy_094",)
 
     def test_add_check_constraint(self):
-        context = op_fixture(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"
-        })
+        context = op_fixture(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
         op.create_check_constraint(
-            "foo",
-            "user_table",
-            func.len(column('name')) > 5
+            "foo", "user_table", func.len(column("name")) > 5
         )
         context.assert_(
             "ALTER TABLE user_table ADD CONSTRAINT ck_user_table_foo "
@@ -26,13 +30,9 @@ class AutoNamingConventionTest(TestBase):
         )
 
     def test_add_check_constraint_name_is_none(self):
-        context = op_fixture(naming_convention={
-            "ck": "ck_%(table_name)s_foo"
-        })
+        context = op_fixture(naming_convention={"ck": "ck_%(table_name)s_foo"})
         op.create_check_constraint(
-            None,
-            "user_table",
-            func.len(column('name')) > 5
+            None, "user_table", func.len(column("name")) > 5
         )
         context.assert_(
             "ALTER TABLE user_table ADD CONSTRAINT ck_user_table_foo "
@@ -40,44 +40,29 @@ class AutoNamingConventionTest(TestBase):
         )
 
     def test_add_unique_constraint_name_is_none(self):
-        context = op_fixture(naming_convention={
-            "uq": "uq_%(table_name)s_foo"
-        })
-        op.create_unique_constraint(
-            None,
-            "user_table",
-            'x'
-        )
+        context = op_fixture(naming_convention={"uq": "uq_%(table_name)s_foo"})
+        op.create_unique_constraint(None, "user_table", "x")
         context.assert_(
             "ALTER TABLE user_table ADD CONSTRAINT uq_user_table_foo UNIQUE (x)"
         )
 
     def test_add_index_name_is_none(self):
-        context = op_fixture(naming_convention={
-            "ix": "ix_%(table_name)s_foo"
-        })
-        op.create_index(
-            None,
-            "user_table",
-            'x'
-        )
-        context.assert_(
-            "CREATE INDEX ix_user_table_foo ON user_table (x)"
-        )
+        context = op_fixture(naming_convention={"ix": "ix_%(table_name)s_foo"})
+        op.create_index(None, "user_table", "x")
+        context.assert_("CREATE INDEX ix_user_table_foo ON user_table (x)")
 
     def test_add_check_constraint_already_named_from_schema(self):
         m1 = MetaData(
-            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
         ck = CheckConstraint("im a constraint", name="cc1")
-        Table('t', m1, Column('x'), ck)
+        Table("t", m1, Column("x"), ck)
 
         context = op_fixture(
-            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
-
-        op.create_table(
-            "some_table",
-            Column('x', Integer, ck),
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
         )
+
+        op.create_table("some_table", Column("x", Integer, ck))
         context.assert_(
             "CREATE TABLE some_table "
             "(x INTEGER CONSTRAINT ck_t_cc1 CHECK (im a constraint))"
@@ -85,11 +70,12 @@ class AutoNamingConventionTest(TestBase):
 
     def test_add_check_constraint_inline_on_table(self):
         context = op_fixture(
-            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
         op.create_table(
             "some_table",
-            Column('x', Integer),
-            CheckConstraint("im a constraint", name="cc1")
+            Column("x", Integer),
+            CheckConstraint("im a constraint", name="cc1"),
         )
         context.assert_(
             "CREATE TABLE some_table "
@@ -98,11 +84,12 @@ class AutoNamingConventionTest(TestBase):
 
     def test_add_check_constraint_inline_on_table_w_f(self):
         context = op_fixture(
-            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
         op.create_table(
             "some_table",
-            Column('x', Integer),
-            CheckConstraint("im a constraint", name=op.f("ck_some_table_cc1"))
+            Column("x", Integer),
+            CheckConstraint("im a constraint", name=op.f("ck_some_table_cc1")),
         )
         context.assert_(
             "CREATE TABLE some_table "
@@ -111,10 +98,13 @@ class AutoNamingConventionTest(TestBase):
 
     def test_add_check_constraint_inline_on_column(self):
         context = op_fixture(
-            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
         op.create_table(
             "some_table",
-            Column('x', Integer, CheckConstraint("im a constraint", name="cc1"))
+            Column(
+                "x", Integer, CheckConstraint("im a constraint", name="cc1")
+            ),
         )
         context.assert_(
             "CREATE TABLE some_table "
@@ -123,12 +113,15 @@ class AutoNamingConventionTest(TestBase):
 
     def test_add_check_constraint_inline_on_column_w_f(self):
         context = op_fixture(
-            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"})
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
         op.create_table(
             "some_table",
             Column(
-                'x', Integer,
-                CheckConstraint("im a constraint", name=op.f("ck_q_cc1")))
+                "x",
+                Integer,
+                CheckConstraint("im a constraint", name=op.f("ck_q_cc1")),
+            ),
         )
         context.assert_(
             "CREATE TABLE some_table "
@@ -136,22 +129,23 @@ class AutoNamingConventionTest(TestBase):
         )
 
     def test_add_column_schema_type(self):
-        context = op_fixture(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"
-        })
-        op.add_column('t1', Column('c1', Boolean(name='foo'), nullable=False))
+        context = op_fixture(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
+        op.add_column("t1", Column("c1", Boolean(name="foo"), nullable=False))
         context.assert_(
-            'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
-            'ALTER TABLE t1 ADD CONSTRAINT ck_t1_foo CHECK (c1 IN (0, 1))'
+            "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL",
+            "ALTER TABLE t1 ADD CONSTRAINT ck_t1_foo CHECK (c1 IN (0, 1))",
         )
 
     def test_add_column_schema_type_w_f(self):
-        context = op_fixture(naming_convention={
-            "ck": "ck_%(table_name)s_%(constraint_name)s"
-        })
+        context = op_fixture(
+            naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
+        )
         op.add_column(
-            't1', Column('c1', Boolean(name=op.f('foo')), nullable=False))
+            "t1", Column("c1", Boolean(name=op.f("foo")), nullable=False)
+        )
         context.assert_(
-            'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL',
-            'ALTER TABLE t1 ADD CONSTRAINT foo CHECK (c1 IN (0, 1))'
+            "ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL",
+            "ALTER TABLE t1 ADD CONSTRAINT foo CHECK (c1 IN (0, 1))",
         )
index 8b9c9e5ff3629c60d7c2916d0938475f2a373fba..86e0ecec4557eca267ee1a985b61fa49b146b376 100644 (file)
@@ -1,23 +1,24 @@
-
 from sqlalchemy import Integer, Column
 
 from alembic import op, command
 from alembic.testing.fixtures import TestBase
 
 from alembic.testing.fixtures import op_fixture, capture_context_buffer
-from alembic.testing.env import _no_sql_testing_config, staging_env, \
-    three_rev_fixture, clear_staging_env
+from alembic.testing.env import (
+    _no_sql_testing_config,
+    staging_env,
+    three_rev_fixture,
+    clear_staging_env,
+)
 
 
 class FullEnvironmentTests(TestBase):
-
     @classmethod
     def setup_class(cls):
         staging_env()
         cls.cfg = cfg = _no_sql_testing_config("oracle")
 
-        cls.a, cls.b, cls.c = \
-            three_rev_fixture(cfg)
+        cls.a, cls.b, cls.c = three_rev_fixture(cfg)
 
     @classmethod
     def teardown_class(cls):
@@ -42,113 +43,99 @@ class FullEnvironmentTests(TestBase):
 
 
 class OpTest(TestBase):
-
     def test_add_column(self):
-        context = op_fixture('oracle')
-        op.add_column('t1', Column('c1', Integer, nullable=False))
+        context = op_fixture("oracle")
+        op.add_column("t1", Column("c1", Integer, nullable=False))
         context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
 
     def test_add_column_with_default(self):
         context = op_fixture("oracle")
         op.add_column(
-            't1', Column('c1', Integer, nullable=False, server_default="12"))
+            "t1", Column("c1", Integer, nullable=False, server_default="12")
+        )
         context.assert_("ALTER TABLE t1 ADD c1 INTEGER DEFAULT '12' NOT NULL")
 
     def test_alter_column_rename_oracle(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column("t", "c", name="x")
-        context.assert_(
-            "ALTER TABLE t RENAME COLUMN c TO x"
-        )
+        context.assert_("ALTER TABLE t RENAME COLUMN c TO x")
 
     def test_alter_column_new_type(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column("t", "c", type_=Integer)
-        context.assert_(
-            'ALTER TABLE t MODIFY c INTEGER'
-        )
+        context.assert_("ALTER TABLE t MODIFY c INTEGER")
 
     def test_drop_index(self):
-        context = op_fixture('oracle')
-        op.drop_index('my_idx', 'my_table')
+        context = op_fixture("oracle")
+        op.drop_index("my_idx", "my_table")
         context.assert_contains("DROP INDEX my_idx")
 
     def test_drop_column_w_default(self):
-        context = op_fixture('oracle')
-        op.drop_column('t1', 'c1')
-        context.assert_(
-            "ALTER TABLE t1 DROP COLUMN c1"
-        )
+        context = op_fixture("oracle")
+        op.drop_column("t1", "c1")
+        context.assert_("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_drop_column_w_check(self):
-        context = op_fixture('oracle')
-        op.drop_column('t1', 'c1')
-        context.assert_(
-            "ALTER TABLE t1 DROP COLUMN c1"
-        )
+        context = op_fixture("oracle")
+        op.drop_column("t1", "c1")
+        context.assert_("ALTER TABLE t1 DROP COLUMN c1")
 
     def test_alter_column_nullable_w_existing_type(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column("t", "c", nullable=True, existing_type=Integer)
-        context.assert_(
-            "ALTER TABLE t MODIFY c NULL"
-        )
+        context.assert_("ALTER TABLE t MODIFY c NULL")
 
     def test_alter_column_not_nullable_w_existing_type(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column("t", "c", nullable=False, existing_type=Integer)
-        context.assert_(
-            "ALTER TABLE t MODIFY c NOT NULL"
-        )
+        context.assert_("ALTER TABLE t MODIFY c NOT NULL")
 
     def test_alter_column_nullable_w_new_type(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column("t", "c", nullable=True, type_=Integer)
         context.assert_(
-            "ALTER TABLE t MODIFY c NULL",
-            'ALTER TABLE t MODIFY c INTEGER'
+            "ALTER TABLE t MODIFY c NULL", "ALTER TABLE t MODIFY c INTEGER"
         )
 
     def test_alter_column_not_nullable_w_new_type(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column("t", "c", nullable=False, type_=Integer)
         context.assert_(
-            "ALTER TABLE t MODIFY c NOT NULL",
-            "ALTER TABLE t MODIFY c INTEGER"
+            "ALTER TABLE t MODIFY c NOT NULL", "ALTER TABLE t MODIFY c INTEGER"
         )
 
     def test_alter_add_server_default(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column("t", "c", server_default="5")
-        context.assert_(
-            "ALTER TABLE t MODIFY c DEFAULT '5'"
-        )
+        context.assert_("ALTER TABLE t MODIFY c DEFAULT '5'")
 
     def test_alter_replace_server_default(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column(
-            "t", "c", server_default="5", existing_server_default="6")
-        context.assert_(
-            "ALTER TABLE t MODIFY c DEFAULT '5'"
+            "t", "c", server_default="5", existing_server_default="6"
         )
+        context.assert_("ALTER TABLE t MODIFY c DEFAULT '5'")
 
     def test_alter_remove_server_default(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column("t", "c", server_default=None)
-        context.assert_(
-            "ALTER TABLE t MODIFY c DEFAULT NULL"
-        )
+        context.assert_("ALTER TABLE t MODIFY c DEFAULT NULL")
 
     def test_alter_do_everything(self):
-        context = op_fixture('oracle')
+        context = op_fixture("oracle")
         op.alter_column(
-            "t", "c", name="c2", nullable=True,
-            type_=Integer, server_default="5")
+            "t",
+            "c",
+            name="c2",
+            nullable=True,
+            type_=Integer,
+            server_default="5",
+        )
         context.assert_(
-            'ALTER TABLE t MODIFY c NULL',
+            "ALTER TABLE t MODIFY c NULL",
             "ALTER TABLE t MODIFY c DEFAULT '5'",
-            'ALTER TABLE t MODIFY c INTEGER',
-            'ALTER TABLE t RENAME COLUMN c TO c2'
+            "ALTER TABLE t MODIFY c INTEGER",
+            "ALTER TABLE t RENAME COLUMN c TO c2",
         )
 
     # TODO: when we add schema support
index 23ec49ca3dbbadde12d6dedf6214abff508bf965..61ba2d10b66ee1c040f337229bf1067d77123f44 100644 (file)
@@ -1,13 +1,27 @@
-
-from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \
-    String, Interval, Sequence, Numeric, BigInteger, Float, Numeric
+from sqlalchemy import (
+    DateTime,
+    MetaData,
+    Table,
+    Column,
+    text,
+    Integer,
+    String,
+    Interval,
+    Sequence,
+    Numeric,
+    BigInteger,
+    Float,
+)
 from sqlalchemy.dialects.postgresql import ARRAY, UUID, BYTEA
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy import types
 from alembic.operations import Operations
 from sqlalchemy.sql import table, column
-from alembic.autogenerate.compare import \
-    _compare_server_default, _compare_tables, _render_server_default_for_compare
+from alembic.autogenerate.compare import (
+    _compare_server_default,
+    _compare_tables,
+    _render_server_default_for_compare,
+)
 
 from alembic.operations import ops
 from alembic import command, util
@@ -16,8 +31,12 @@ from alembic.script import ScriptDirectory
 from alembic.autogenerate import api
 
 from alembic.testing import eq_, provide_metadata
-from alembic.testing.env import staging_env, clear_staging_env, \
-    _no_sql_testing_config, write_script
+from alembic.testing.env import (
+    staging_env,
+    clear_staging_env,
+    _no_sql_testing_config,
+    write_script,
+)
 from alembic.testing.fixtures import capture_context_buffer
 from alembic.testing.fixtures import TestBase
 from alembic.testing.fixtures import op_fixture
@@ -36,79 +55,79 @@ if util.sqla_09:
 
 
 class PostgresqlOpTest(TestBase):
-
     def test_rename_table_postgresql(self):
         context = op_fixture("postgresql")
-        op.rename_table('t1', 't2')
+        op.rename_table("t1", "t2")
         context.assert_("ALTER TABLE t1 RENAME TO t2")
 
     def test_rename_table_schema_postgresql(self):
         context = op_fixture("postgresql")
-        op.rename_table('t1', 't2', schema="foo")
+        op.rename_table("t1", "t2", schema="foo")
         context.assert_("ALTER TABLE foo.t1 RENAME TO t2")
 
     def test_create_index_postgresql_expressions(self):
         context = op_fixture("postgresql")
         op.create_index(
-            'geocoded',
-            'locations',
-            [text('lower(coordinates)')],
-            postgresql_where=text("locations.coordinates != Null"))
+            "geocoded",
+            "locations",
+            [text("lower(coordinates)")],
+            postgresql_where=text("locations.coordinates != Null"),
+        )
         context.assert_(
             "CREATE INDEX geocoded ON locations (lower(coordinates)) "
-            "WHERE locations.coordinates != Null")
+            "WHERE locations.coordinates != Null"
+        )
 
     def test_create_index_postgresql_where(self):
         context = op_fixture("postgresql")
         op.create_index(
-            'geocoded',
-            'locations',
-            ['coordinates'],
-            postgresql_where=text("locations.coordinates != Null"))
+            "geocoded",
+            "locations",
+            ["coordinates"],
+            postgresql_where=text("locations.coordinates != Null"),
+        )
         context.assert_(
             "CREATE INDEX geocoded ON locations (coordinates) "
-            "WHERE locations.coordinates != Null")
+            "WHERE locations.coordinates != Null"
+        )
 
     @config.requirements.fail_before_sqla_099
     def test_create_index_postgresql_concurrently(self):
         context = op_fixture("postgresql")
         op.create_index(
-            'geocoded',
-            'locations',
-            ['coordinates'],
-            postgresql_concurrently=True)
+            "geocoded",
+            "locations",
+            ["coordinates"],
+            postgresql_concurrently=True,
+        )
         context.assert_(
-            "CREATE INDEX CONCURRENTLY geocoded ON locations (coordinates)")
+            "CREATE INDEX CONCURRENTLY geocoded ON locations (coordinates)"
+        )
 
     @config.requirements.fail_before_sqla_110
     def test_drop_index_postgresql_concurrently(self):
         context = op_fixture("postgresql")
-        op.drop_index(
-            'geocoded',
-            'locations',
-            postgresql_concurrently=True)
-        context.assert_(
-            "DROP INDEX CONCURRENTLY geocoded")
+        op.drop_index("geocoded", "locations", postgresql_concurrently=True)
+        context.assert_("DROP INDEX CONCURRENTLY geocoded")
 
     def test_alter_column_type_using(self):
-        context = op_fixture('postgresql')
-        op.alter_column("t", "c", type_=Integer, postgresql_using='c::integer')
+        context = op_fixture("postgresql")
+        op.alter_column("t", "c", type_=Integer, postgresql_using="c::integer")
         context.assert_(
-            'ALTER TABLE t ALTER COLUMN c TYPE INTEGER USING c::integer'
+            "ALTER TABLE t ALTER COLUMN c TYPE INTEGER USING c::integer"
         )
 
     def test_col_w_pk_is_serial(self):
         context = op_fixture("postgresql")
-        op.add_column("some_table", Column('q', Integer, primary_key=True))
-        context.assert_(
-            'ALTER TABLE some_table ADD COLUMN q SERIAL NOT NULL'
-        )
+        op.add_column("some_table", Column("q", Integer, primary_key=True))
+        context.assert_("ALTER TABLE some_table ADD COLUMN q SERIAL NOT NULL")
 
     @config.requirements.fail_before_sqla_100
     def test_create_exclude_constraint(self):
         context = op_fixture("postgresql")
         op.create_exclude_constraint(
-            "ex1", "t1", ('x', '>'), where='x > 5', using="gist")
+            "ex1", "t1", ("x", ">"), where="x > 5", using="gist"
+        )
         context.assert_(
             "ALTER TABLE t1 ADD CONSTRAINT ex1 EXCLUDE USING gist (x WITH >) "
             "WHERE (x > 5)"
@@ -118,8 +137,12 @@ class PostgresqlOpTest(TestBase):
     def test_create_exclude_constraint_quoted_literal(self):
         context = op_fixture("postgresql")
         op.create_exclude_constraint(
-            "ex1", "SomeTable", ('"SomeColumn"', '>'),
-            where='"SomeColumn" > 5', using="gist")
+            "ex1",
+            "SomeTable",
+            ('"SomeColumn"', ">"),
+            where='"SomeColumn" > 5',
+            using="gist",
+        )
         context.assert_(
             'ALTER TABLE "SomeTable" ADD CONSTRAINT ex1 EXCLUDE USING gist '
             '("SomeColumn" WITH >) WHERE ("SomeColumn" > 5)'
@@ -129,8 +152,12 @@ class PostgresqlOpTest(TestBase):
     def test_create_exclude_constraint_quoted_column(self):
         context = op_fixture("postgresql")
         op.create_exclude_constraint(
-            "ex1", "SomeTable", (column("SomeColumn"), '>'),
-            where=column("SomeColumn") > 5, using="gist")
+            "ex1",
+            "SomeTable",
+            (column("SomeColumn"), ">"),
+            where=column("SomeColumn") > 5,
+            using="gist",
+        )
         context.assert_(
             'ALTER TABLE "SomeTable" ADD CONSTRAINT ex1 EXCLUDE '
             'USING gist ("SomeColumn" WITH >) WHERE ("SomeColumn" > 5)'
@@ -138,7 +165,6 @@ class PostgresqlOpTest(TestBase):
 
 
 class PGOfflineEnumTest(TestBase):
-
     def setUp(self):
         staging_env()
         self.cfg = cfg = _no_sql_testing_config()
@@ -152,7 +178,10 @@ class PGOfflineEnumTest(TestBase):
         clear_staging_env()
 
     def _inline_enum_script(self):
-        write_script(self.script, self.rid, """
+        write_script(
+            self.script,
+            self.rid,
+            """
 revision = '%s'
 down_revision = None
 
@@ -169,10 +198,15 @@ def upgrade():
 
 def downgrade():
     op.drop_table("sometable")
-""" % self.rid)
+"""
+            % self.rid,
+        )
 
     def _distinct_enum_script(self):
-        write_script(self.script, self.rid, """
+        write_script(
+            self.script,
+            self.rid,
+            """
 revision = '%s'
 down_revision = None
 
@@ -193,14 +227,18 @@ def downgrade():
     op.drop_table("sometable")
     ENUM(name="pgenum").drop(op.get_bind(), checkfirst=False)
 
-""" % self.rid)
+"""
+            % self.rid,
+        )
 
     def test_offline_inline_enum_create(self):
         self._inline_enum_script()
         with capture_context_buffer() as buf:
             command.upgrade(self.cfg, self.rid, sql=True)
-        assert "CREATE TYPE pgenum AS "\
+        assert (
+            "CREATE TYPE pgenum AS "
             "ENUM ('one', 'two', 'three')" in buf.getvalue()
+        )
         assert "CREATE TABLE sometable (\n    data pgenum\n)" in buf.getvalue()
 
     def test_offline_inline_enum_drop(self):
@@ -215,8 +253,10 @@ def downgrade():
         self._distinct_enum_script()
         with capture_context_buffer() as buf:
             command.upgrade(self.cfg, self.rid, sql=True)
-        assert "CREATE TYPE pgenum AS ENUM "\
+        assert (
+            "CREATE TYPE pgenum AS ENUM "
             "('one', 'two', 'three')" in buf.getvalue()
+        )
         assert "CREATE TABLE sometable (\n    data pgenum\n)" in buf.getvalue()
 
     def test_offline_distinct_enum_drop(self):
@@ -228,23 +268,27 @@ def downgrade():
 
 
 class PostgresqlInlineLiteralTest(TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
 
     @classmethod
     def setup_class(cls):
         cls.bind = config.db
-        cls.bind.execute("""
+        cls.bind.execute(
+            """
             create table tab (
                 col varchar(50)
             )
-        """)
-        cls.bind.execute("""
+        """
+        )
+        cls.bind.execute(
+            """
             insert into tab (col) values
                 ('old data 1'),
                 ('old data 2.1'),
                 ('old data 3')
-        """)
+        """
+        )
 
     @classmethod
     def teardown_class(cls):
@@ -260,35 +304,32 @@ class PostgresqlInlineLiteralTest(TestBase):
 
     def test_inline_percent(self):
         # TODO: here's the issue, you need to escape this.
-        tab = table('tab', column('col'))
+        tab = table("tab", column("col"))
         self.op.execute(
-            tab.update().where(
-                tab.c.col.like(self.op.inline_literal('%.%'))
-            ).values(col=self.op.inline_literal('new data')),
-            execution_options={'no_parameters': True}
+            tab.update()
+            .where(tab.c.col.like(self.op.inline_literal("%.%")))
+            .values(col=self.op.inline_literal("new data")),
+            execution_options={"no_parameters": True},
         )
         eq_(
             self.conn.execute(
-                "select count(*) from tab where col='new data'").scalar(),
+                "select count(*) from tab where col='new data'"
+            ).scalar(),
             1,
         )
 
 
 class PostgresqlDefaultCompareTest(TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
 
-
     @classmethod
     def setup_class(cls):
         cls.bind = config.db
         staging_env()
         cls.migration_context = MigrationContext.configure(
             connection=cls.bind.connect(),
-            opts={
-                'compare_type': True,
-                'compare_server_default': True
-            }
+            opts={"compare_type": True, "compare_server_default": True},
         )
 
     def setUp(self):
@@ -303,216 +344,166 @@ class PostgresqlDefaultCompareTest(TestBase):
         self.metadata.drop_all()
 
     def _compare_default_roundtrip(
-            self, type_, orig_default, alternate=None, diff_expected=None):
-        diff_expected = diff_expected \
-            if diff_expected is not None \
+        self, type_, orig_default, alternate=None, diff_expected=None
+    ):
+        diff_expected = (
+            diff_expected
+            if diff_expected is not None
             else alternate is not None
+        )
         if alternate is None:
             alternate = orig_default
 
-        t1 = Table("test", self.metadata,
-                   Column("somecol", type_, server_default=orig_default))
-        t2 = Table("test", MetaData(),
-                   Column("somecol", type_, server_default=alternate))
+        t1 = Table(
+            "test",
+            self.metadata,
+            Column("somecol", type_, server_default=orig_default),
+        )
+        t2 = Table(
+            "test",
+            MetaData(),
+            Column("somecol", type_, server_default=alternate),
+        )
 
         t1.create(self.bind)
 
         insp = Inspector.from_engine(self.bind)
         cols = insp.get_columns(t1.name)
-        insp_col = Column("somecol", cols[0]['type'],
-                          server_default=text(cols[0]['default']))
+        insp_col = Column(
+            "somecol", cols[0]["type"], server_default=text(cols[0]["default"])
+        )
         op = ops.AlterColumnOp("test", "somecol")
         _compare_server_default(
-            self.autogen_context, op,
-            None, "test", "somecol", insp_col, t2.c.somecol)
+            self.autogen_context,
+            op,
+            None,
+            "test",
+            "somecol",
+            insp_col,
+            t2.c.somecol,
+        )
 
         diffs = op.to_diff_tuple()
         eq_(bool(diffs), diff_expected)
 
-    def _compare_default(
-        self,
-        t1, t2, col,
-        rendered
-    ):
+    def _compare_default(self, t1, t2, col, rendered):
         t1.create(self.bind, checkfirst=True)
         insp = Inspector.from_engine(self.bind)
         cols = insp.get_columns(t1.name)
         ctx = self.autogen_context.migration_context
 
         return ctx.impl.compare_server_default(
-            None,
-            col,
-            rendered,
-            cols[0]['default'])
+            None, col, rendered, cols[0]["default"]
+        )
 
     def test_compare_interval_str(self):
         # this form shouldn't be used but testing here
         # for compatibility
-        self._compare_default_roundtrip(
-            Interval,
-            "14 days"
-        )
+        self._compare_default_roundtrip(Interval, "14 days")
 
     @config.requirements.postgresql_uuid_ossp
     def test_compare_uuid_text(self):
-        self._compare_default_roundtrip(
-            UUID,
-            text("uuid_generate_v4()")
-        )
+        self._compare_default_roundtrip(UUID, text("uuid_generate_v4()"))
 
     def test_compare_interval_text(self):
-        self._compare_default_roundtrip(
-            Interval,
-            text("'14 days'")
-        )
+        self._compare_default_roundtrip(Interval, text("'14 days'"))
 
     def test_compare_array_of_integer_text(self):
         self._compare_default_roundtrip(
-            ARRAY(Integer),
-            text("(ARRAY[]::integer[])")
+            ARRAY(Integer), text("(ARRAY[]::integer[])")
         )
 
     def test_compare_current_timestamp_text(self):
         self._compare_default_roundtrip(
-            DateTime(),
-            text("TIMEZONE('utc', CURRENT_TIMESTAMP)"),
+            DateTime(), text("TIMEZONE('utc', CURRENT_TIMESTAMP)")
         )
 
     def test_compare_integer_str(self):
-        self._compare_default_roundtrip(
-            Integer(),
-            "5",
-        )
+        self._compare_default_roundtrip(Integer(), "5")
 
     def test_compare_integer_text(self):
-        self._compare_default_roundtrip(
-            Integer(),
-            text("5"),
-        )
+        self._compare_default_roundtrip(Integer(), text("5"))
 
     def test_compare_integer_text_diff(self):
-        self._compare_default_roundtrip(
-            Integer(),
-            text("5"), "7"
-        )
+        self._compare_default_roundtrip(Integer(), text("5"), "7")
 
     def test_compare_float_str(self):
-        self._compare_default_roundtrip(
-            Float(),
-            "5.2",
-        )
+        self._compare_default_roundtrip(Float(), "5.2")
 
     def test_compare_float_text(self):
-        self._compare_default_roundtrip(
-            Float(),
-            text("5.2"),
-        )
+        self._compare_default_roundtrip(Float(), text("5.2"))
 
     def test_compare_float_no_diff1(self):
         self._compare_default_roundtrip(
-            Float(),
-            text("5.2"), "5.2",
-            diff_expected=False
+            Float(), text("5.2"), "5.2", diff_expected=False
         )
 
     def test_compare_float_no_diff2(self):
         self._compare_default_roundtrip(
-            Float(),
-            "5.2", text("5.2"),
-            diff_expected=False
+            Float(), "5.2", text("5.2"), diff_expected=False
         )
 
     def test_compare_float_no_diff3(self):
         self._compare_default_roundtrip(
-            Float(),
-            text("5"), text("5.0"),
-            diff_expected=False
+            Float(), text("5"), text("5.0"), diff_expected=False
         )
 
     def test_compare_float_no_diff4(self):
         self._compare_default_roundtrip(
-            Float(),
-            "5", "5.0",
-            diff_expected=False
+            Float(), "5", "5.0", diff_expected=False
         )
 
     def test_compare_float_no_diff5(self):
         self._compare_default_roundtrip(
-            Float(),
-            text("5"), "5.0",
-            diff_expected=False
+            Float(), text("5"), "5.0", diff_expected=False
         )
 
     def test_compare_float_no_diff6(self):
         self._compare_default_roundtrip(
-            Float(),
-            "5", text("5.0"),
-            diff_expected=False
+            Float(), "5", text("5.0"), diff_expected=False
         )
 
     def test_compare_numeric_no_diff(self):
         self._compare_default_roundtrip(
-            Numeric(),
-            text("5"), "5.0",
-            diff_expected=False
+            Numeric(), text("5"), "5.0", diff_expected=False
         )
 
     def test_compare_unicode_literal(self):
-        self._compare_default_roundtrip(
-            String(),
-            u'im a default'
-        )
+        self._compare_default_roundtrip(String(), u"im a default")
 
    # TODO: will need to actually eval() the repr() and
     # spend more effort figuring out exactly the kind of expression
     # to use
     def _TODO_test_compare_character_str_w_singlequote(self):
-        self._compare_default_roundtrip(
-            String(),
-            "hel''lo",
-        )
+        self._compare_default_roundtrip(String(), "hel''lo")
 
     def test_compare_character_str(self):
-        self._compare_default_roundtrip(
-            String(),
-            "hello",
-        )
+        self._compare_default_roundtrip(String(), "hello")
 
     def test_compare_character_text(self):
-        self._compare_default_roundtrip(
-            String(),
-            text("'hello'"),
-        )
+        self._compare_default_roundtrip(String(), text("'hello'"))
 
     def test_compare_character_str_diff(self):
-        self._compare_default_roundtrip(
-            String(),
-            "hello",
-            "there"
-        )
+        self._compare_default_roundtrip(String(), "hello", "there")
 
     def test_compare_character_text_diff(self):
         self._compare_default_roundtrip(
-            String(),
-            text("'hello'"),
-            text("'there'")
+            String(), text("'hello'"), text("'there'")
         )
 
     def test_primary_key_skip(self):
         """Test that SERIAL cols are just skipped"""
-        t1 = Table("sometable", self.metadata,
-                   Column("id", Integer, primary_key=True)
-                   )
-        t2 = Table("sometable", MetaData(),
-                   Column("id", Integer, primary_key=True)
-                   )
-        assert not self._compare_default(
-            t1, t2, t2.c.id, ""
+        t1 = Table(
+            "sometable", self.metadata, Column("id", Integer, primary_key=True)
+        )
+        t2 = Table(
+            "sometable", MetaData(), Column("id", Integer, primary_key=True)
         )
+        assert not self._compare_default(t1, t2, t2.c.id, "")
 
 
 class PostgresqlDetectSerialTest(TestBase):
-    __only_on__ = 'postgresql'
+    __only_on__ = "postgresql"
     __backend__ = True
 
     @classmethod
@@ -522,10 +513,7 @@ class PostgresqlDetectSerialTest(TestBase):
         staging_env()
         cls.migration_context = MigrationContext.configure(
             connection=cls.conn,
-            opts={
-                'compare_type': True,
-                'compare_server_default': True
-            }
+            opts={"compare_type": True, "compare_server_default": True},
         )
 
     def setUp(self):
@@ -538,7 +526,7 @@ class PostgresqlDetectSerialTest(TestBase):
 
     @provide_metadata
     def _expect_default(self, c_expected, col, seq=None):
-        Table('t', self.metadata, col)
+        Table("t", self.metadata, col)
 
         self.autogen_context.metadata = self.metadata
 
@@ -550,43 +538,50 @@ class PostgresqlDetectSerialTest(TestBase):
 
         uo = ops.UpgradeOps(ops=[])
         _compare_tables(
-            set([(None, 't')]), set([]),
-            insp, uo, self.autogen_context)
+            set([(None, "t")]), set([]), insp, uo, self.autogen_context
+        )
         diffs = uo.as_diffs()
         tab = diffs[0][1]
 
-        eq_(_render_server_default_for_compare(
-            tab.c.x.server_default, tab.c.x, self.autogen_context),
-            c_expected)
+        eq_(
+            _render_server_default_for_compare(
+                tab.c.x.server_default, tab.c.x, self.autogen_context
+            ),
+            c_expected,
+        )
 
         insp = Inspector.from_engine(config.db)
         uo = ops.UpgradeOps(ops=[])
         m2 = MetaData()
-        Table('t', m2, Column('x', BigInteger()))
+        Table("t", m2, Column("x", BigInteger()))
         self.autogen_context.metadata = m2
         _compare_tables(
-            set([(None, 't')]), set([(None, 't')]),
-            insp, uo, self.autogen_context)
+            set([(None, "t")]),
+            set([(None, "t")]),
+            insp,
+            uo,
+            self.autogen_context,
+        )
         diffs = uo.as_diffs()
-        server_default = diffs[0][0][4]['existing_server_default']
-        eq_(_render_server_default_for_compare(
-            server_default, tab.c.x, self.autogen_context),
-            c_expected)
+        server_default = diffs[0][0][4]["existing_server_default"]
+        eq_(
+            _render_server_default_for_compare(
+                server_default, tab.c.x, self.autogen_context
+            ),
+            c_expected,
+        )
 
     def test_serial(self):
-        self._expect_default(
-            None,
-            Column('x', Integer, primary_key=True)
-        )
+        self._expect_default(None, Column("x", Integer, primary_key=True))
 
     def test_separate_seq(self):
         seq = Sequence("x_id_seq")
         self._expect_default(
             "nextval('x_id_seq'::regclass)",
             Column(
-                'x', Integer,
-                server_default=seq.next_value(), primary_key=True),
-            seq
+                "x", Integer, server_default=seq.next_value(), primary_key=True
+            ),
+            seq,
         )
 
     def test_numeric(self):
@@ -594,29 +589,29 @@ class PostgresqlDetectSerialTest(TestBase):
         self._expect_default(
             "nextval('x_id_seq'::regclass)",
             Column(
-                'x', Numeric(8, 2), server_default=seq.next_value(),
-                primary_key=True),
-            seq
+                "x",
+                Numeric(8, 2),
+                server_default=seq.next_value(),
+                primary_key=True,
+            ),
+            seq,
         )
 
     def test_no_default(self):
         self._expect_default(
-            None,
-            Column('x', Integer, autoincrement=False, primary_key=True)
+            None, Column("x", Integer, autoincrement=False, primary_key=True)
         )
 
 
 class PostgresqlAutogenRenderTest(TestBase):
-
     def setUp(self):
         ctx_opts = {
-            'sqlalchemy_module_prefix': 'sa.',
-            'alembic_module_prefix': 'op.',
-            'target_metadata': MetaData()
+            "sqlalchemy_module_prefix": "sa.",
+            "alembic_module_prefix": "op.",
+            "target_metadata": MetaData(),
         }
         context = MigrationContext.configure(
-            dialect_name="postgresql",
-            opts=ctx_opts
+            dialect_name="postgresql", opts=ctx_opts
         )
 
         self.autogen_context = api.AutogenContext(context)
@@ -625,13 +620,11 @@ class PostgresqlAutogenRenderTest(TestBase):
         autogen_context = self.autogen_context
 
         m = MetaData()
-        t = Table('t', m,
-                  Column('x', String),
-                  Column('y', String)
-                  )
+        t = Table("t", m, Column("x", String), Column("y", String))
 
-        idx = Index('foo_idx', t.c.x, t.c.y,
-                    postgresql_where=(t.c.y == 'something'))
+        idx = Index(
+            "foo_idx", t.c.x, t.c.y, postgresql_where=(t.c.y == "something")
+        )
 
         op_obj = ops.CreateIndexOp.from_index(idx)
 
@@ -639,114 +632,124 @@ class PostgresqlAutogenRenderTest(TestBase):
             autogenerate.render_op_text(autogen_context, op_obj),
             """op.create_index('foo_idx', 't', \
 ['x', 'y'], unique=False, """
-            """postgresql_where=sa.text(!U"y = 'something'"))"""
+            """postgresql_where=sa.text(!U"y = 'something'"))""",
         )
 
     def test_render_server_default_native_boolean(self):
         c = Column(
-            'updated_at', Boolean(),
-            server_default=false(),
-            nullable=False)
-        result = autogenerate.render._render_column(
-            c, self.autogen_context,
+            "updated_at", Boolean(), server_default=false(), nullable=False
         )
+        result = autogenerate.render._render_column(c, self.autogen_context)
         eq_ignore_whitespace(
             result,
-            'sa.Column(\'updated_at\', sa.Boolean(), '
-            'server_default=sa.text(!U\'false\'), '
-            'nullable=False)'
+            "sa.Column('updated_at', sa.Boolean(), "
+            "server_default=sa.text(!U'false'), "
+            "nullable=False)",
         )
 
     def test_postgresql_array_type(self):
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                ARRAY(Integer), self.autogen_context),
-            "postgresql.ARRAY(sa.Integer())"
+                ARRAY(Integer), self.autogen_context
+            ),
+            "postgresql.ARRAY(sa.Integer())",
         )
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                ARRAY(DateTime(timezone=True)), self.autogen_context),
-            "postgresql.ARRAY(sa.DateTime(timezone=True))"
+                ARRAY(DateTime(timezone=True)), self.autogen_context
+            ),
+            "postgresql.ARRAY(sa.DateTime(timezone=True))",
         )
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                ARRAY(BYTEA, as_tuple=True, dimensions=2),
-                self.autogen_context),
-            "postgresql.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)"
+                ARRAY(BYTEA, as_tuple=True, dimensions=2), self.autogen_context
+            ),
+            "postgresql.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)",
         )
 
-        assert 'from sqlalchemy.dialects import postgresql' in \
-            self.autogen_context.imports
+        assert (
+            "from sqlalchemy.dialects import postgresql"
+            in self.autogen_context.imports
+        )
 
     @config.requirements.sqlalchemy_110
     def test_postgresql_hstore_subtypes(self):
         eq_ignore_whitespace(
-            autogenerate.render._repr_type(
-                HSTORE(), self.autogen_context),
-            "postgresql.HSTORE(text_type=sa.Text())"
+            autogenerate.render._repr_type(HSTORE(), self.autogen_context),
+            "postgresql.HSTORE(text_type=sa.Text())",
         )
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                HSTORE(text_type=String()), self.autogen_context),
-            "postgresql.HSTORE(text_type=sa.String())"
+                HSTORE(text_type=String()), self.autogen_context
+            ),
+            "postgresql.HSTORE(text_type=sa.String())",
         )
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                HSTORE(text_type=BYTEA()), self.autogen_context),
-            "postgresql.HSTORE(text_type=postgresql.BYTEA())"
+                HSTORE(text_type=BYTEA()), self.autogen_context
+            ),
+            "postgresql.HSTORE(text_type=postgresql.BYTEA())",
         )
 
-        assert 'from sqlalchemy.dialects import postgresql' in \
-            self.autogen_context.imports
+        assert (
+            "from sqlalchemy.dialects import postgresql"
+            in self.autogen_context.imports
+        )
 
     @config.requirements.sqlalchemy_110
     def test_generic_array_type(self):
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                types.ARRAY(Integer), self.autogen_context),
-            "sa.ARRAY(sa.Integer())"
+                types.ARRAY(Integer), self.autogen_context
+            ),
+            "sa.ARRAY(sa.Integer())",
         )
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                types.ARRAY(DateTime(timezone=True)), self.autogen_context),
-            "sa.ARRAY(sa.DateTime(timezone=True))"
+                types.ARRAY(DateTime(timezone=True)), self.autogen_context
+            ),
+            "sa.ARRAY(sa.DateTime(timezone=True))",
         )
 
-        assert 'from sqlalchemy.dialects import postgresql' not in \
-            self.autogen_context.imports
+        assert (
+            "from sqlalchemy.dialects import postgresql"
+            not in self.autogen_context.imports
+        )
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
                 types.ARRAY(BYTEA, as_tuple=True, dimensions=2),
-                self.autogen_context),
-            "sa.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)"
+                self.autogen_context,
+            ),
+            "sa.ARRAY(postgresql.BYTEA(), as_tuple=True, dimensions=2)",
         )
 
-        assert 'from sqlalchemy.dialects import postgresql' in \
-            self.autogen_context.imports
+        assert (
+            "from sqlalchemy.dialects import postgresql"
+            in self.autogen_context.imports
+        )
 
     def test_array_type_user_defined_inner(self):
         def repr_type(typestring, object_, autogen_context):
-            if typestring == 'type' and isinstance(object_, String):
+            if typestring == "type" and isinstance(object_, String):
                 return "foobar.MYVARCHAR"
             else:
                 return False
 
-        self.autogen_context.opts.update(
-            render_item=repr_type
-        )
+        self.autogen_context.opts.update(render_item=repr_type)
 
         eq_ignore_whitespace(
             autogenerate.render._repr_type(
-                ARRAY(String), self.autogen_context),
-            "postgresql.ARRAY(foobar.MYVARCHAR)"
+                ARRAY(String), self.autogen_context
+            ),
+            "postgresql.ARRAY(foobar.MYVARCHAR)",
         )
 
     @config.requirements.fail_before_sqla_100
@@ -756,22 +759,18 @@ class PostgresqlAutogenRenderTest(TestBase):
         autogen_context = self.autogen_context
 
         m = MetaData()
-        t = Table('t', m,
-                  Column('x', String),
-                  Column('y', String)
-                  )
-
-        op_obj = ops.AddConstraintOp.from_constraint(ExcludeConstraint(
-            (t.c.x, ">"),
-            where=t.c.x != 2,
-            using="gist",
-            name="t_excl_x"
-        ))
+        t = Table("t", m, Column("x", String), Column("y", String))
+
+        op_obj = ops.AddConstraintOp.from_constraint(
+            ExcludeConstraint(
+                (t.c.x, ">"), where=t.c.x != 2, using="gist", name="t_excl_x"
+            )
+        )
 
         eq_ignore_whitespace(
             autogenerate.render_op_text(autogen_context, op_obj),
             "op.create_exclude_constraint('t_excl_x', 't', (sa.column('x'), '>'), "
-            "where=sa.text(!U'x != 2'), using='gist')"
+            "where=sa.text(!U'x != 2'), using='gist')",
         )
 
     @config.requirements.fail_before_sqla_100
@@ -781,25 +780,25 @@ class PostgresqlAutogenRenderTest(TestBase):
         autogen_context = self.autogen_context
 
         m = MetaData()
-        t = Table('TTAble', m,
-                  Column('XColumn', String),
-                  Column('YColumn', String)
-                  )
+        t = Table(
+            "TTAble", m, Column("XColumn", String), Column("YColumn", String)
+        )
 
-        op_obj = ops.AddConstraintOp.from_constraint(ExcludeConstraint(
-            (t.c.XColumn, ">"),
-            where=t.c.XColumn != 2,
-            using="gist",
-            name="t_excl_x"
-        ))
+        op_obj = ops.AddConstraintOp.from_constraint(
+            ExcludeConstraint(
+                (t.c.XColumn, ">"),
+                where=t.c.XColumn != 2,
+                using="gist",
+                name="t_excl_x",
+            )
+        )
 
         eq_ignore_whitespace(
             autogenerate.render_op_text(autogen_context, op_obj),
             "op.create_exclude_constraint('t_excl_x', 'TTAble', (sa.column('XColumn'), '>'), "
-            "where=sa.text(!U'\"XColumn\" != 2'), using='gist')"
+            "where=sa.text(!U'\"XColumn\" != 2'), using='gist')",
         )
 
-
     @config.requirements.fail_before_sqla_100
     def test_inline_exclude_constraint(self):
         from sqlalchemy.dialects.postgresql import ExcludeConstraint
@@ -808,15 +807,13 @@ class PostgresqlAutogenRenderTest(TestBase):
 
         m = MetaData()
         t = Table(
-            't', m,
-            Column('x', String),
-            Column('y', String),
+            "t",
+            m,
+            Column("x", String),
+            Column("y", String),
             ExcludeConstraint(
-                ('x', ">"),
-                using="gist",
-                where='x != 2',
-                name="t_excl_x"
-            )
+                ("x", ">"), using="gist", where="x != 2", name="t_excl_x"
+            ),
         )
 
         op_obj = ops.CreateTableOp.from_table(t)
@@ -827,7 +824,7 @@ class PostgresqlAutogenRenderTest(TestBase):
             "sa.Column('y', sa.String(), nullable=True),"
             "postgresql.ExcludeConstraint((!U'x', '>'), "
             "where=sa.text(!U'x != 2'), using='gist', name='t_excl_x')"
-            ")"
+            ")",
         )
 
     @config.requirements.fail_before_sqla_100
@@ -838,15 +835,13 @@ class PostgresqlAutogenRenderTest(TestBase):
 
         m = MetaData()
         t = Table(
-            'TTable', m,
-            Column('XColumn', String),
-            Column('YColumn', String),
+            "TTable", m, Column("XColumn", String), Column("YColumn", String)
         )
         ExcludeConstraint(
             (t.c.XColumn, ">"),
             using="gist",
             where='"XColumn" != 2',
-            name="TExclX"
+            name="TExclX",
         )
 
         op_obj = ops.CreateTableOp.from_table(t)
@@ -858,33 +853,29 @@ class PostgresqlAutogenRenderTest(TestBase):
             "sa.Column('YColumn', sa.String(), nullable=True),"
             "postgresql.ExcludeConstraint((sa.column('XColumn'), '>'), "
             "where=sa.text(!U'\"XColumn\" != 2'), using='gist', "
-            "name='TExclX'))"
+            "name='TExclX'))",
         )
 
     def test_json_type(self):
         if config.requirements.sqlalchemy_110.enabled:
             eq_ignore_whitespace(
-                autogenerate.render._repr_type(
-                    JSON(), self.autogen_context),
-                "postgresql.JSON(astext_type=sa.Text())"
+                autogenerate.render._repr_type(JSON(), self.autogen_context),
+                "postgresql.JSON(astext_type=sa.Text())",
             )
         else:
             eq_ignore_whitespace(
-                autogenerate.render._repr_type(
-                    JSON(), self.autogen_context),
-                "postgresql.JSON()"
+                autogenerate.render._repr_type(JSON(), self.autogen_context),
+                "postgresql.JSON()",
             )
 
     def test_jsonb_type(self):
         if config.requirements.sqlalchemy_110.enabled:
             eq_ignore_whitespace(
-                autogenerate.render._repr_type(
-                    JSONB(), self.autogen_context),
-                "postgresql.JSONB(astext_type=sa.Text())"
+                autogenerate.render._repr_type(JSONB(), self.autogen_context),
+                "postgresql.JSONB(astext_type=sa.Text())",
             )
         else:
             eq_ignore_whitespace(
-                autogenerate.render._repr_type(
-                    JSONB(), self.autogen_context),
-                "postgresql.JSONB()"
+                autogenerate.render._repr_type(JSONB(), self.autogen_context),
+                "postgresql.JSONB()",
             )
index 1f1c342011e5e919fc790be6aa765a8652239a3e..41a713ebc8eba1f52d6ef735d4ce384f65d5767d 100644 (file)
@@ -1,7 +1,11 @@
 from alembic.testing.fixtures import TestBase
 from alembic.testing import eq_, assert_raises_message
-from alembic.script.revision import RevisionMap, Revision, MultipleHeads, \
-    RevisionError
+from alembic.script.revision import (
+    RevisionMap,
+    Revision,
+    MultipleHeads,
+    RevisionError,
+)
 from . import _large_map
 
 
@@ -9,136 +13,144 @@ class APITest(TestBase):
     def test_add_revision_one_head(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c', ('b',)),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c", ("b",)),
             ]
         )
-        eq_(map_.heads, ('c', ))
+        eq_(map_.heads, ("c",))
 
-        map_.add_revision(Revision('d', ('c', )))
-        eq_(map_.heads, ('d', ))
+        map_.add_revision(Revision("d", ("c",)))
+        eq_(map_.heads, ("d",))
 
     def test_add_revision_two_head(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c1', ('b',)),
-                Revision('c2', ('b',)),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c1", ("b",)),
+                Revision("c2", ("b",)),
             ]
         )
-        eq_(map_.heads, ('c1', 'c2'))
+        eq_(map_.heads, ("c1", "c2"))
 
-        map_.add_revision(Revision('d1', ('c1', )))
-        eq_(map_.heads, ('c2', 'd1'))
+        map_.add_revision(Revision("d1", ("c1",)))
+        eq_(map_.heads, ("c2", "d1"))
 
     def test_get_revision_head_single(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c', ('b',)),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c", ("b",)),
             ]
         )
-        eq_(map_.get_revision('head'), map_._revision_map['c'])
+        eq_(map_.get_revision("head"), map_._revision_map["c"])
 
     def test_get_revision_base_single(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c', ('b',)),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c", ("b",)),
             ]
         )
-        eq_(map_.get_revision('base'), None)
+        eq_(map_.get_revision("base"), None)
 
     def test_get_revision_head_multiple(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c1', ('b',)),
-                Revision('c2', ('b',)),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c1", ("b",)),
+                Revision("c2", ("b",)),
             ]
         )
         assert_raises_message(
             MultipleHeads,
             "Multiple heads are present",
-            map_.get_revision, 'head'
+            map_.get_revision,
+            "head",
         )
 
     def test_get_revision_heads_multiple(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c1', ('b',)),
-                Revision('c2', ('b',)),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c1", ("b",)),
+                Revision("c2", ("b",)),
             ]
         )
         assert_raises_message(
             MultipleHeads,
             "Multiple heads are present",
-            map_.get_revision, "heads"
+            map_.get_revision,
+            "heads",
         )
 
     def test_get_revision_base_multiple(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c', ()),
-                Revision('d', ('c',)),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c", ()),
+                Revision("d", ("c",)),
             ]
         )
-        eq_(map_.get_revision('base'), None)
+        eq_(map_.get_revision("base"), None)
 
     def test_iterate_tolerates_dupe_targets(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c', ('b',)),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c", ("b",)),
             ]
         )
 
         eq_(
-            [
-                r.revision for r in
-                map_._iterate_revisions(('c', 'c'), 'a')
-            ],
-            ['c', 'b', 'a']
+            [r.revision for r in map_._iterate_revisions(("c", "c"), "a")],
+            ["c", "b", "a"],
         )
 
     def test_repr_revs(self):
         map_ = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b', ('a',)),
-                Revision('c', (), dependencies=('a', 'b')),
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c", (), dependencies=("a", "b")),
             ]
         )
-        c = map_._revision_map['c']
+        c = map_._revision_map["c"]
         eq_(repr(c), "Revision('c', None, dependencies=('a', 'b'))")
 
 
 class DownIterateTest(TestBase):
     def _assert_iteration(
-            self, upper, lower, assertion, inclusive=True, map_=None,
-            implicit_base=False, select_for_downgrade=False):
+        self,
+        upper,
+        lower,
+        assertion,
+        inclusive=True,
+        map_=None,
+        implicit_base=False,
+        select_for_downgrade=False,
+    ):
         if map_ is None:
             map_ = self.map
         eq_(
             [
-                rev.revision for rev in
-                map_.iterate_revisions(
-                    upper, lower,
-                    inclusive=inclusive, implicit_base=implicit_base,
-                    select_for_downgrade=select_for_downgrade
+                rev.revision
+                for rev in map_.iterate_revisions(
+                    upper,
+                    lower,
+                    inclusive=inclusive,
+                    implicit_base=implicit_base,
+                    select_for_downgrade=select_for_downgrade,
                 )
             ],
-            assertion
+            assertion,
         )
 
 
@@ -146,173 +158,141 @@ class DiamondTest(DownIterateTest):
     def setUp(self):
         self.map = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b1', ('a',)),
-                Revision('b2', ('a',)),
-                Revision('c', ('b1', 'b2')),
-                Revision('d', ('c',)),
+                Revision("a", ()),
+                Revision("b1", ("a",)),
+                Revision("b2", ("a",)),
+                Revision("c", ("b1", "b2")),
+                Revision("d", ("c",)),
             ]
         )
 
     def test_iterate_simple_diamond(self):
-        self._assert_iteration(
-            "d", "a",
-            ["d", "c", "b1", "b2", "a"]
-        )
+        self._assert_iteration("d", "a", ["d", "c", "b1", "b2", "a"])
 
 
 class EmptyMapTest(DownIterateTest):
     # see issue #258
 
     def setUp(self):
-        self.map = RevisionMap(
-            lambda: []
-        )
+        self.map = RevisionMap(lambda: [])
 
     def test_iterate(self):
-        self._assert_iteration(
-            "head", "base",
-            []
-        )
+        self._assert_iteration("head", "base", [])
 
 
 class LabeledBranchTest(DownIterateTest):
     def test_dupe_branch_collection(self):
         fn = lambda: [
-            Revision('a', ()),
-            Revision('b', ('a',)),
-            Revision('c', ('b',), branch_labels=['xy1']),
-            Revision('d', ()),
-            Revision('e', ('d',), branch_labels=['xy1']),
-            Revision('f', ('e',))
+            Revision("a", ()),
+            Revision("b", ("a",)),
+            Revision("c", ("b",), branch_labels=["xy1"]),
+            Revision("d", ()),
+            Revision("e", ("d",), branch_labels=["xy1"]),
+            Revision("f", ("e",)),
         ]
         assert_raises_message(
             RevisionError,
             r"Branch name 'xy1' in revision (?:e|c) already "
             "used by revision (?:e|c)",
-            getattr, RevisionMap(fn), "_revision_map"
+            getattr,
+            RevisionMap(fn),
+            "_revision_map",
         )
 
     def test_filter_for_lineage_labeled_head_across_merge(self):
         fn = lambda: [
-            Revision('a', ()),
-            Revision('b', ('a', )),
-            Revision('c1', ('b', ), branch_labels='c1branch'),
-            Revision('c2', ('b', )),
-            Revision('d', ('c1', 'c2')),
-
+            Revision("a", ()),
+            Revision("b", ("a",)),
+            Revision("c1", ("b",), branch_labels="c1branch"),
+            Revision("c2", ("b",)),
+            Revision("d", ("c1", "c2")),
         ]
         map_ = RevisionMap(fn)
-        c1 = map_.get_revision('c1')
-        c2 = map_.get_revision('c2')
-        d = map_.get_revision('d')
-        eq_(
-            map_.filter_for_lineage([c1, c2, d], "c1branch@head"),
-            [c1, c2, d]
-        )
+        c1 = map_.get_revision("c1")
+        c2 = map_.get_revision("c2")
+        d = map_.get_revision("d")
+        eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), [c1, c2, d])
 
     def test_filter_for_lineage_heads(self):
         eq_(
-            self.map.filter_for_lineage(
-                [self.map.get_revision("f")],
-                "heads"
-            ),
-            [self.map.get_revision("f")]
+            self.map.filter_for_lineage([self.map.get_revision("f")], "heads"),
+            [self.map.get_revision("f")],
         )
 
     def setUp(self):
-        self.map = RevisionMap(lambda: [
-            Revision('a', (), branch_labels='abranch'),
-            Revision('b', ('a',)),
-            Revision('somelongername', ('b',)),
-            Revision('c', ('somelongername',)),
-            Revision('d', ()),
-            Revision('e', ('d',), branch_labels=['ebranch']),
-            Revision('someothername', ('e',)),
-            Revision('f', ('someothername',)),
-        ])
+        self.map = RevisionMap(
+            lambda: [
+                Revision("a", (), branch_labels="abranch"),
+                Revision("b", ("a",)),
+                Revision("somelongername", ("b",)),
+                Revision("c", ("somelongername",)),
+                Revision("d", ()),
+                Revision("e", ("d",), branch_labels=["ebranch"]),
+                Revision("someothername", ("e",)),
+                Revision("f", ("someothername",)),
+            ]
+        )
 
     def test_get_base_revisions_labeled(self):
-        eq_(
-            self.map._get_base_revisions("somelongername@base"),
-            ['a']
-        )
+        eq_(self.map._get_base_revisions("somelongername@base"), ["a"])
 
     def test_get_current_named_rev(self):
-        eq_(
-            self.map.get_revision("ebranch@head"),
-            self.map.get_revision("f")
-        )
+        eq_(self.map.get_revision("ebranch@head"), self.map.get_revision("f"))
 
     def test_get_base_revisions(self):
-        eq_(
-            self.map._get_base_revisions("base"),
-            ['a', 'd']
-        )
+        eq_(self.map._get_base_revisions("base"), ["a", "d"])
 
     def test_iterate_head_to_named_base(self):
         self._assert_iteration(
-            "heads", "ebranch@base",
-            ['f', 'someothername', 'e', 'd']
+            "heads", "ebranch@base", ["f", "someothername", "e", "d"]
         )
 
         self._assert_iteration(
-            "heads", "abranch@base",
-            ['c', 'somelongername', 'b', 'a']
+            "heads", "abranch@base", ["c", "somelongername", "b", "a"]
         )
 
     def test_iterate_named_head_to_base(self):
         self._assert_iteration(
-            "ebranch@head", "base",
-            ['f', 'someothername', 'e', 'd']
+            "ebranch@head", "base", ["f", "someothername", "e", "d"]
         )
 
         self._assert_iteration(
-            "abranch@head", "base",
-            ['c', 'somelongername', 'b', 'a']
+            "abranch@head", "base", ["c", "somelongername", "b", "a"]
         )
 
     def test_iterate_named_head_to_heads(self):
-        self._assert_iteration(
-            "heads", "ebranch@head",
-            ['f'],
-            inclusive=True
-        )
+        self._assert_iteration("heads", "ebranch@head", ["f"], inclusive=True)
 
     def test_iterate_named_rev_to_heads(self):
         self._assert_iteration(
-            "heads", "ebranch@d",
-            ['f', 'someothername', 'e', 'd'],
-            inclusive=True
+            "heads",
+            "ebranch@d",
+            ["f", "someothername", "e", "d"],
+            inclusive=True,
         )
 
     def test_iterate_head_to_version_specific_base(self):
         self._assert_iteration(
-            "heads", "e@base",
-            ['f', 'someothername', 'e', 'd']
+            "heads", "e@base", ["f", "someothername", "e", "d"]
         )
 
         self._assert_iteration(
-            "heads", "c@base",
-            ['c', 'somelongername', 'b', 'a']
+            "heads", "c@base", ["c", "somelongername", "b", "a"]
         )
 
     def test_iterate_to_branch_at_rev(self):
         self._assert_iteration(
-            "heads", "ebranch@d",
-            ['f', 'someothername', 'e', 'd']
+            "heads", "ebranch@d", ["f", "someothername", "e", "d"]
         )
 
     def test_branch_w_down_relative(self):
         self._assert_iteration(
-            "heads", "ebranch@-2",
-            ['f', 'someothername', 'e']
+            "heads", "ebranch@-2", ["f", "someothername", "e"]
         )
 
     def test_branch_w_up_relative(self):
         self._assert_iteration(
-            "ebranch@+2", "base",
-            ['someothername', 'e', 'd']
+            "ebranch@+2", "base", ["someothername", "e", "d"]
         )
 
     def test_partial_id_resolve(self):
@@ -320,43 +300,43 @@ class LabeledBranchTest(DownIterateTest):
         eq_(self.map.get_revision("abranch@some").revision, "somelongername")
 
     def test_branch_at_heads(self):
-        eq_(
-            self.map.get_revision("abranch@heads").revision,
-            "c"
-        )
+        eq_(self.map.get_revision("abranch@heads").revision, "c")
 
     def test_branch_at_syntax(self):
-        eq_(self.map.get_revision("abranch@head").revision, 'c')
+        eq_(self.map.get_revision("abranch@head").revision, "c")
         eq_(self.map.get_revision("abranch@base"), None)
-        eq_(self.map.get_revision("ebranch@head").revision, 'f')
+        eq_(self.map.get_revision("ebranch@head").revision, "f")
         eq_(self.map.get_revision("abranch@base"), None)
-        eq_(self.map.get_revision("ebranch@d").revision, 'd')
+        eq_(self.map.get_revision("ebranch@d").revision, "d")
 
     def test_branch_at_self(self):
-        eq_(self.map.get_revision("ebranch@ebranch").revision, 'e')
+        eq_(self.map.get_revision("ebranch@ebranch").revision, "e")
 
     def test_retrieve_branch_revision(self):
-        eq_(self.map.get_revision("abranch").revision, 'a')
-        eq_(self.map.get_revision("ebranch").revision, 'e')
+        eq_(self.map.get_revision("abranch").revision, "a")
+        eq_(self.map.get_revision("ebranch").revision, "e")
 
     def test_rev_not_in_branch(self):
         assert_raises_message(
             RevisionError,
             "Revision b is not a member of branch 'ebranch'",
-            self.map.get_revision, "ebranch@b"
+            self.map.get_revision,
+            "ebranch@b",
         )
 
         assert_raises_message(
             RevisionError,
             "Revision d is not a member of branch 'abranch'",
-            self.map.get_revision, "abranch@d"
+            self.map.get_revision,
+            "abranch@d",
         )
 
     def test_no_revision_exists(self):
         assert_raises_message(
             RevisionError,
             "No such revision or branch 'q'",
-            self.map.get_revision, "abranch@q"
+            self.map.get_revision,
+            "abranch@q",
         )
 
     def test_not_actually_a_branch(self):
@@ -367,9 +347,7 @@ class LabeledBranchTest(DownIterateTest):
 
     def test_no_such_branch(self):
         assert_raises_message(
-            RevisionError,
-            "No such branch: 'x'",
-            self.map.get_revision, "x@d"
+            RevisionError, "No such branch: 'x'", self.map.get_revision, "x@d"
         )
 
 
@@ -377,19 +355,18 @@ class LongShortBranchTest(DownIterateTest):
     def setUp(self):
         self.map = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b1', ('a',)),
-                Revision('b2', ('a',)),
-                Revision('c1', ('b1',)),
-                Revision('d11', ('c1',)),
-                Revision('d12', ('c1',)),
+                Revision("a", ()),
+                Revision("b1", ("a",)),
+                Revision("b2", ("a",)),
+                Revision("c1", ("b1",)),
+                Revision("d11", ("c1",)),
+                Revision("d12", ("c1",)),
             ]
         )
 
     def test_iterate_full(self):
         self._assert_iteration(
-            "heads", "base",
-            ['b2', 'd11', 'd12', 'c1', 'b1', 'a']
+            "heads", "base", ["b2", "d11", "d12", "c1", "b1", "a"]
         )
 
 
@@ -397,56 +374,46 @@ class MultipleBranchTest(DownIterateTest):
     def setUp(self):
         self.map = RevisionMap(
             lambda: [
-                Revision('a', ()),
-                Revision('b1', ('a',)),
-                Revision('b2', ('a',)),
-                Revision('cb1', ('b1',)),
-                Revision('cb2', ('b2',)),
-                Revision('d1cb1', ('cb1',)),  # head
-                Revision('d2cb1', ('cb1',)),  # head
-                Revision('d1cb2', ('cb2',)),
-                Revision('d2cb2', ('cb2',)),
-                Revision('d3cb2', ('cb2',)),  # head
-                Revision('d1d2cb2', ('d1cb2', 'd2cb2'))  # head + merge point
+                Revision("a", ()),
+                Revision("b1", ("a",)),
+                Revision("b2", ("a",)),
+                Revision("cb1", ("b1",)),
+                Revision("cb2", ("b2",)),
+                Revision("d1cb1", ("cb1",)),  # head
+                Revision("d2cb1", ("cb1",)),  # head
+                Revision("d1cb2", ("cb2",)),
+                Revision("d2cb2", ("cb2",)),
+                Revision("d3cb2", ("cb2",)),  # head
+                Revision("d1d2cb2", ("d1cb2", "d2cb2")),  # head + merge point
             ]
         )
 
     def test_iterate_from_merge_point(self):
         self._assert_iteration(
-            "d1d2cb2", "a",
-            ['d1d2cb2', 'd1cb2', 'd2cb2', 'cb2', 'b2', 'a']
+            "d1d2cb2", "a", ["d1d2cb2", "d1cb2", "d2cb2", "cb2", "b2", "a"]
         )
 
     def test_iterate_multiple_heads(self):
         self._assert_iteration(
-            ["d2cb2", "d3cb2"], "a",
-            ['d2cb2', 'd3cb2', 'cb2', 'b2', 'a']
+            ["d2cb2", "d3cb2"], "a", ["d2cb2", "d3cb2", "cb2", "b2", "a"]
         )
 
     def test_iterate_single_branch(self):
-        self._assert_iteration(
-            "d3cb2", "a",
-            ['d3cb2', 'cb2', 'b2', 'a']
-        )
+        self._assert_iteration("d3cb2", "a", ["d3cb2", "cb2", "b2", "a"])
 
     def test_iterate_single_branch_to_base(self):
-        self._assert_iteration(
-            "d3cb2", "base",
-            ['d3cb2', 'cb2', 'b2', 'a']
-        )
+        self._assert_iteration("d3cb2", "base", ["d3cb2", "cb2", "b2", "a"])
 
     def test_iterate_multiple_branch_to_base(self):
         self._assert_iteration(
-            ["d3cb2", "cb1"], "base",
-            ['d3cb2', 'cb2', 'b2', 'cb1', 'b1', 'a']
+            ["d3cb2", "cb1"], "base", ["d3cb2", "cb2", "b2", "cb1", "b1", "a"]
         )
 
     def test_iterate_multiple_heads_single_base(self):
         # head d1cb1 is omitted as it is not
         # a descendant of b2
         self._assert_iteration(
-            ["d1cb1", "d2cb2", "d3cb2"], "b2",
-            ["d2cb2", 'd3cb2', 'cb2', 'b2']
+            ["d1cb1", "d2cb2", "d3cb2"], "b2", ["d2cb2", "d3cb2", "cb2", "b2"]
         )
 
     def test_same_branch_wrong_direction(self):
@@ -456,7 +423,7 @@ class MultipleBranchTest(DownIterateTest):
             RevisionError,
             r"Revision d1cb1 is not an ancestor of revision b1",
             list,
-            self.map._iterate_revisions('b1', 'd1cb1')
+            self.map._iterate_revisions("b1", "d1cb1"),
         )
 
     def test_distinct_branches(self):
@@ -465,7 +432,7 @@ class MultipleBranchTest(DownIterateTest):
             RevisionError,
             r"Revision b1 is not an ancestor of revision d2cb2",
             list,
-            self.map._iterate_revisions('d2cb2', 'b1')
+            self.map._iterate_revisions("d2cb2", "b1"),
         )
 
     def test_wrong_direction_to_base_as_none(self):
@@ -475,7 +442,7 @@ class MultipleBranchTest(DownIterateTest):
             RevisionError,
             r"Revision d1cb1 is not an ancestor of revision base",
             list,
-            self.map._iterate_revisions(None, 'd1cb1')
+            self.map._iterate_revisions(None, "d1cb1"),
         )
 
     def test_wrong_direction_to_base_as_empty(self):
@@ -485,7 +452,7 @@ class MultipleBranchTest(DownIterateTest):
             RevisionError,
             r"Revision d1cb1 is not an ancestor of revision base",
             list,
-            self.map._iterate_revisions((), 'd1cb1')
+            self.map._iterate_revisions((), "d1cb1"),
         )
 
 
@@ -501,22 +468,20 @@ class BranchTravellingTest(DownIterateTest):
     def setUp(self):
         self.map = RevisionMap(
             lambda: [
-                Revision('a1', ()),
-                Revision('a2', ('a1',)),
-                Revision('a3', ('a2',)),
-                Revision('b1', ('a3',)),
-                Revision('b2', ('a3',)),
-                Revision('cb1', ('b1',)),
-                Revision('cb2', ('b2',)),
-                Revision('db1', ('cb1',)),
-                Revision('db2', ('cb2',)),
-
-                Revision('e1b1', ('db1',)),
-                Revision('fe1b1', ('e1b1',)),
-
-                Revision('e2b1', ('db1',)),
-                Revision('e2b2', ('db2',)),
-                Revision("merge", ('e2b1', 'e2b2'))
+                Revision("a1", ()),
+                Revision("a2", ("a1",)),
+                Revision("a3", ("a2",)),
+                Revision("b1", ("a3",)),
+                Revision("b2", ("a3",)),
+                Revision("cb1", ("b1",)),
+                Revision("cb2", ("b2",)),
+                Revision("db1", ("cb1",)),
+                Revision("db2", ("cb2",)),
+                Revision("e1b1", ("db1",)),
+                Revision("fe1b1", ("e1b1",)),
+                Revision("e2b1", ("db1",)),
+                Revision("e2b2", ("db2",)),
+                Revision("merge", ("e2b1", "e2b2")),
             ]
         )
 
@@ -524,19 +489,31 @@ class BranchTravellingTest(DownIterateTest):
         # test that when we hit a merge point, implicit base will
         # ensure all branches that supply the merge point are filled in
         self._assert_iteration(
-            "merge", "db1",
-            ['merge',
-                'e2b1', 'db1',
-                'e2b2', 'db2', 'cb2', 'b2'],
-            implicit_base=True
+            "merge",
+            "db1",
+            ["merge", "e2b1", "db1", "e2b2", "db2", "cb2", "b2"],
+            implicit_base=True,
         )
 
     def test_three_branches_end_in_single_branch(self):
 
         self._assert_iteration(
-            ["merge", "fe1b1"], "a3",
-            ['merge', 'e2b1', 'e2b2', 'db2', 'cb2', 'b2',
-             'fe1b1', 'e1b1', 'db1', 'cb1', 'b1', 'a3']
+            ["merge", "fe1b1"],
+            "a3",
+            [
+                "merge",
+                "e2b1",
+                "e2b2",
+                "db2",
+                "cb2",
+                "b2",
+                "fe1b1",
+                "e1b1",
+                "db1",
+                "cb1",
+                "b1",
+                "a3",
+            ],
         )
 
     def test_two_branches_to_root(self):
@@ -544,80 +521,103 @@ class BranchTravellingTest(DownIterateTest):
         # here we want 'a3' as a "stop" branch point, but *not*
         # 'db1', as we don't have multiple traversals on db1
         self._assert_iteration(
-            "merge", "a1",
-            ['merge',
-                'e2b1', 'db1', 'cb1', 'b1',  # e2b1 branch
-                'e2b2', 'db2', 'cb2', 'b2',  # e2b2 branch
-                'a3',  # both terminate at a3
-                'a2', 'a1'  # finish out
-            ]  # noqa
+            "merge",
+            "a1",
+            [
+                "merge",
+                "e2b1",
+                "db1",
+                "cb1",
+                "b1",  # e2b1 branch
+                "e2b2",
+                "db2",
+                "cb2",
+                "b2",  # e2b2 branch
+                "a3",  # both terminate at a3
+                "a2",
+                "a1",  # finish out
+            ],  # noqa
         )
 
     def test_two_branches_end_in_branch(self):
         self._assert_iteration(
-            "merge", "b1",
+            "merge",
+            "b1",
             # 'b1' is local to 'e2b1'
             # branch so that is all we get
-            ['merge', 'e2b1', 'db1', 'cb1', 'b1',
-
-        ]  # noqa
+            ["merge", "e2b1", "db1", "cb1", "b1"],  # noqa
         )
 
     def test_two_branches_end_behind_branch(self):
         self._assert_iteration(
-            "merge", "a2",
-            ['merge',
-                'e2b1', 'db1', 'cb1', 'b1',  # e2b1 branch
-                'e2b2', 'db2', 'cb2', 'b2',  # e2b2 branch
-                'a3',  # both terminate at a3
-                'a2'
-            ]  # noqa
+            "merge",
+            "a2",
+            [
+                "merge",
+                "e2b1",
+                "db1",
+                "cb1",
+                "b1",  # e2b1 branch
+                "e2b2",
+                "db2",
+                "cb2",
+                "b2",  # e2b2 branch
+                "a3",  # both terminate at a3
+                "a2",
+            ],  # noqa
         )
 
     def test_three_branches_to_root(self):
 
         # in this case, both "a3" and "db1" are stop points
         self._assert_iteration(
-            ["merge", "fe1b1"], "a1",
-            ['merge',
-                'e2b1',  # e2b1 branch
-                'e2b2', 'db2', 'cb2', 'b2',  # e2b2 branch
-                'fe1b1', 'e1b1',  # fe1b1 branch
-                'db1',  # fe1b1 and e2b1 branches terminate at db1
-                'cb1', 'b1',  # e2b1 branch continued....might be nicer
-                              # if this was before the e2b2 branch...
-                'a3',  # e2b1 and e2b2 branches terminate at a3
-                'a2', 'a1'  # finish out
-            ]  # noqa
+            ["merge", "fe1b1"],
+            "a1",
+            [
+                "merge",
+                "e2b1",  # e2b1 branch
+                "e2b2",
+                "db2",
+                "cb2",
+                "b2",  # e2b2 branch
+                "fe1b1",
+                "e1b1",  # fe1b1 branch
+                "db1",  # fe1b1 and e2b1 branches terminate at db1
+                "cb1",
+                "b1",  # e2b1 branch continued....might be nicer
+                # if this was before the e2b2 branch...
+                "a3",  # e2b1 and e2b2 branches terminate at a3
+                "a2",
+                "a1",  # finish out
+            ],  # noqa
         )
 
     def test_three_branches_end_multiple_bases(self):
 
         # in this case, both "a3" and "db1" are stop points
         self._assert_iteration(
-            ["merge", "fe1b1"], ["cb1", "cb2"],
+            ["merge", "fe1b1"],
+            ["cb1", "cb2"],
             [
-                'merge',
-                'e2b1',
-                'e2b2', 'db2', 'cb2',
-                'fe1b1', 'e1b1',
-                'db1',
-                'cb1'
-            ]
+                "merge",
+                "e2b1",
+                "e2b2",
+                "db2",
+                "cb2",
+                "fe1b1",
+                "e1b1",
+                "db1",
+                "cb1",
+            ],
         )
 
     def test_three_branches_end_multiple_bases_exclusive(self):
 
         self._assert_iteration(
-            ["merge", "fe1b1"], ["cb1", "cb2"],
-            [
-                'merge',
-                'e2b1',
-                'e2b2', 'db2',
-                'fe1b1', 'e1b1',
-                'db1',
-            ],
-            inclusive=False
+            ["merge", "fe1b1"],
+            ["cb1", "cb2"],
+            ["merge", "e2b1", "e2b2", "db2", "fe1b1", "e1b1", "db1"],
+            inclusive=False,
         )
 
     def test_detect_invalid_head_selection(self):
@@ -627,26 +627,34 @@ class BranchTravellingTest(DownIterateTest):
             "Requested revision fe1b1 overlaps "
             "with other requested revisions",
             list,
-            self.map._iterate_revisions(["db1", "b2", "fe1b1"], ())
+            self.map._iterate_revisions(["db1", "b2", "fe1b1"], ()),
         )
 
     def test_three_branches_end_multiple_bases_exclusive_blank(self):
         self._assert_iteration(
-            ["e2b1", "b2", "fe1b1"], (),
+            ["e2b1", "b2", "fe1b1"],
+            (),
             [
-                'e2b1',
-                'b2',
-                'fe1b1', 'e1b1',
-                'db1', 'cb1', 'b1', 'a3', 'a2', 'a1'
+                "e2b1",
+                "b2",
+                "fe1b1",
+                "e1b1",
+                "db1",
+                "cb1",
+                "b1",
+                "a3",
+                "a2",
+                "a1",
             ],
-            inclusive=False
+            inclusive=False,
         )
 
     def test_iterate_to_symbolic_base(self):
         self._assert_iteration(
-            ["fe1b1"], "base",
-            ['fe1b1', 'e1b1', 'db1', 'cb1', 'b1', 'a3', 'a2', 'a1'],
-            inclusive=False
+            ["fe1b1"],
+            "base",
+            ["fe1b1", "e1b1", "db1", "cb1", "b1", "a3", "a2", "a1"],
+            inclusive=False,
         )
 
     def test_ancestor_nodes(self):
@@ -656,8 +664,22 @@ class BranchTravellingTest(DownIterateTest):
                 rev.revision
                 for rev in self.map._get_ancestor_nodes([merge], check=True)
             ),
-            set(['a1', 'e2b2', 'e2b1', 'cb2', 'merge',
-                'a3', 'a2', 'b1', 'b2', 'db1', 'db2', 'cb1'])
+            set(
+                [
+                    "a1",
+                    "e2b2",
+                    "e2b1",
+                    "cb2",
+                    "merge",
+                    "a3",
+                    "a2",
+                    "b1",
+                    "b2",
+                    "db1",
+                    "db2",
+                    "cb1",
+                ]
+            ),
         )
 
 
@@ -665,125 +687,153 @@ class MultipleBaseTest(DownIterateTest):
     def setUp(self):
         self.map = RevisionMap(
             lambda: [
-                Revision('base1', ()),
-                Revision('base2', ()),
-                Revision('base3', ()),
-
-                Revision('a1a', ('base1',)),
-                Revision('a1b', ('base1',)),
-                Revision('a2', ('base2',)),
-                Revision('a3', ('base3',)),
-
-                Revision('b1a', ('a1a',)),
-                Revision('b1b', ('a1b',)),
-                Revision('b2', ('a2',)),
-                Revision('b3', ('a3',)),
-
-                Revision('c2', ('b2',)),
-                Revision('d2', ('c2',)),
-
-                Revision('mergeb3d2', ('b3', 'd2'))
+                Revision("base1", ()),
+                Revision("base2", ()),
+                Revision("base3", ()),
+                Revision("a1a", ("base1",)),
+                Revision("a1b", ("base1",)),
+                Revision("a2", ("base2",)),
+                Revision("a3", ("base3",)),
+                Revision("b1a", ("a1a",)),
+                Revision("b1b", ("a1b",)),
+                Revision("b2", ("a2",)),
+                Revision("b3", ("a3",)),
+                Revision("c2", ("b2",)),
+                Revision("d2", ("c2",)),
+                Revision("mergeb3d2", ("b3", "d2")),
             ]
         )
 
     def test_heads_to_base(self):
         self._assert_iteration(
-            "heads", "base",
+            "heads",
+            "base",
             [
-                'b1a', 'a1a',
-                'b1b', 'a1b',
-                'mergeb3d2',
-                    'b3', 'a3', 'base3',
-                    'd2', 'c2', 'b2', 'a2', 'base2',
-                'base1'
-            ]
+                "b1a",
+                "a1a",
+                "b1b",
+                "a1b",
+                "mergeb3d2",
+                "b3",
+                "a3",
+                "base3",
+                "d2",
+                "c2",
+                "b2",
+                "a2",
+                "base2",
+                "base1",
+            ],
         )
 
     def test_heads_to_base_exclusive(self):
         self._assert_iteration(
-            "heads", "base",
+            "heads",
+            "base",
             [
-                'b1a', 'a1a',
-                'b1b', 'a1b',
-                'mergeb3d2',
-                    'b3', 'a3', 'base3',
-                    'd2', 'c2', 'b2', 'a2', 'base2',
-                    'base1',
+                "b1a",
+                "a1a",
+                "b1b",
+                "a1b",
+                "mergeb3d2",
+                "b3",
+                "a3",
+                "base3",
+                "d2",
+                "c2",
+                "b2",
+                "a2",
+                "base2",
+                "base1",
             ],
-            inclusive=False
+            inclusive=False,
         )
 
     def test_heads_to_blank(self):
         self._assert_iteration(
-            "heads", None,
+            "heads",
+            None,
             [
-                'b1a', 'a1a',
-                'b1b', 'a1b',
-                'mergeb3d2',
-                    'b3', 'a3', 'base3',
-                    'd2', 'c2', 'b2', 'a2', 'base2',
-                'base1'
-            ]
+                "b1a",
+                "a1a",
+                "b1b",
+                "a1b",
+                "mergeb3d2",
+                "b3",
+                "a3",
+                "base3",
+                "d2",
+                "c2",
+                "b2",
+                "a2",
+                "base2",
+                "base1",
+            ],
         )
 
     def test_detect_invalid_base_selection(self):
         assert_raises_message(
             RevisionError,
-            "Requested revision a2 overlaps with "
-            "other requested revisions",
+            "Requested revision a2 overlaps with " "other requested revisions",
             list,
-            self.map._iterate_revisions(["c2"], ["a2", "b2"])
+            self.map._iterate_revisions(["c2"], ["a2", "b2"]),
         )
 
     def test_heads_to_revs_plus_implicit_base_exclusive(self):
         self._assert_iteration(
-            "heads", ["c2"],
+            "heads",
+            ["c2"],
             [
-                'b1a', 'a1a',
-                'b1b', 'a1b',
-                'mergeb3d2',
-                    'b3', 'a3', 'base3',
-                    'd2',
-                'base1'
+                "b1a",
+                "a1a",
+                "b1b",
+                "a1b",
+                "mergeb3d2",
+                "b3",
+                "a3",
+                "base3",
+                "d2",
+                "base1",
             ],
             inclusive=False,
-            implicit_base=True
+            implicit_base=True,
         )
 
     def test_heads_to_revs_base_exclusive(self):
         self._assert_iteration(
-            "heads", ["c2"],
-            [
-                'mergeb3d2', 'd2'
-            ],
-            inclusive=False
+            "heads", ["c2"], ["mergeb3d2", "d2"], inclusive=False
         )
 
     def test_heads_to_revs_plus_implicit_base_inclusive(self):
         self._assert_iteration(
-            "heads", ["c2"],
+            "heads",
+            ["c2"],
             [
-                'b1a', 'a1a',
-                'b1b', 'a1b',
-                'mergeb3d2',
-                    'b3', 'a3', 'base3',
-                    'd2', 'c2',
-                'base1'
+                "b1a",
+                "a1a",
+                "b1b",
+                "a1b",
+                "mergeb3d2",
+                "b3",
+                "a3",
+                "base3",
+                "d2",
+                "c2",
+                "base1",
             ],
-            implicit_base=True
+            implicit_base=True,
         )
 
     def test_specific_path_one(self):
-        self._assert_iteration(
-            "b3", "base3",
-            ['b3', 'a3', 'base3']
-        )
+        self._assert_iteration("b3", "base3", ["b3", "a3", "base3"])
 
     def test_specific_path_two_implicit_base(self):
         self._assert_iteration(
-            ["b3", "b2"], "base3",
-            ['b3', 'a3', 'b2', 'a2', 'base2'],
-            inclusive=False, implicit_base=True
+            ["b3", "b2"],
+            "base3",
+            ["b3", "a3", "b2", "a2", "base2"],
+            inclusive=False,
+            implicit_base=True,
         )
 
 
@@ -808,21 +858,19 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest):
         """
         self.map = RevisionMap(
             lambda: [
-                Revision('base1', (), branch_labels='b_1'),
-                Revision('a1a', ('base1',)),
-                Revision('a1b', ('base1',)),
-                Revision('b1a', ('a1a',)),
-                Revision('b1b', ('a1b', ), dependencies='a3'),
-
-                Revision('base2', (), branch_labels='b_2'),
-                Revision('a2', ('base2',)),
-                Revision('b2', ('a2',)),
-                Revision('c2', ('b2', ), dependencies='a3'),
-                Revision('d2', ('c2',)),
-
-                Revision('base3', (), branch_labels='b_3'),
-                Revision('a3', ('base3',)),
-                Revision('b3', ('a3',)),
+                Revision("base1", (), branch_labels="b_1"),
+                Revision("a1a", ("base1",)),
+                Revision("a1b", ("base1",)),
+                Revision("b1a", ("a1a",)),
+                Revision("b1b", ("a1b",), dependencies="a3"),
+                Revision("base2", (), branch_labels="b_2"),
+                Revision("a2", ("base2",)),
+                Revision("b2", ("a2",)),
+                Revision("c2", ("b2",), dependencies="a3"),
+                Revision("d2", ("c2",)),
+                Revision("base3", (), branch_labels="b_3"),
+                Revision("a3", ("base3",)),
+                Revision("b3", ("a3",)),
             ]
         )
 
@@ -831,25 +879,45 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest):
 
     def test_heads_to_base(self):
         self._assert_iteration(
-            "heads", "base",
+            "heads",
+            "base",
             [
-
-                'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2',
-                'b3', 'a3', 'base3',
-                'base1'
-            ]
+                "b1a",
+                "a1a",
+                "b1b",
+                "a1b",
+                "d2",
+                "c2",
+                "b2",
+                "a2",
+                "base2",
+                "b3",
+                "a3",
+                "base3",
+                "base1",
+            ],
         )
 
     def test_heads_to_base_downgrade(self):
         self._assert_iteration(
-            "heads", "base",
+            "heads",
+            "base",
             [
-
-                'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2',
-                'b3', 'a3', 'base3',
-                'base1'
+                "b1a",
+                "a1a",
+                "b1b",
+                "a1b",
+                "d2",
+                "c2",
+                "b2",
+                "a2",
+                "base2",
+                "b3",
+                "a3",
+                "base3",
+                "base1",
             ],
-            select_for_downgrade=True
+            select_for_downgrade=True,
         )
 
     def test_same_branch_wrong_direction(self):
@@ -857,83 +925,79 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest):
             RevisionError,
             r"Revision d2 is not an ancestor of revision b2",
             list,
-            self.map._iterate_revisions('b2', 'd2')
+            self.map._iterate_revisions("b2", "d2"),
         )
 
     def test_different_branch_not_wrong_direction(self):
-        self._assert_iteration(
-            "b3", "d2",
-            []
-        )
+        self._assert_iteration("b3", "d2", [])
 
     def test_we_need_head2_upgrade(self):
         # the 2 branch relies on the 3 branch
         self._assert_iteration(
-            "b_2@head", "base",
-            ['d2', 'c2', 'b2', 'a2', 'base2', 'a3', 'base3']
+            "b_2@head",
+            "base",
+            ["d2", "c2", "b2", "a2", "base2", "a3", "base3"],
         )
 
     def test_we_need_head2_downgrade(self):
         # the 2 branch relies on the 3 branch, but
         # on the downgrade side, don't need to touch the 3 branch
         self._assert_iteration(
-            "b_2@head", "b_2@base",
-            ['d2', 'c2', 'b2', 'a2', 'base2'],
-            select_for_downgrade=True
+            "b_2@head",
+            "b_2@base",
+            ["d2", "c2", "b2", "a2", "base2"],
+            select_for_downgrade=True,
         )
 
     def test_we_need_head3_upgrade(self):
         # the 3 branch can be upgraded alone.
-        self._assert_iteration(
-            "b_3@head", "base",
-            ['b3', 'a3', 'base3']
-        )
+        self._assert_iteration("b_3@head", "base", ["b3", "a3", "base3"])
 
     def test_we_need_head3_downgrade(self):
         # the 3 branch can be downgraded alone.
         self._assert_iteration(
-            "b_3@head", "base",
-            ['b3', 'a3', 'base3'],
-            select_for_downgrade=True
+            "b_3@head",
+            "base",
+            ["b3", "a3", "base3"],
+            select_for_downgrade=True,
         )
 
     def test_we_need_head1_upgrade(self):
         # the 1 branch relies on the 3 branch
         self._assert_iteration(
-            "b1b@head", "base",
-            ['b1b', 'a1b', 'base1', 'a3', 'base3']
+            "b1b@head", "base", ["b1b", "a1b", "base1", "a3", "base3"]
         )
 
     def test_we_need_head1_downgrade(self):
         # going down we don't need a3-> base3, as long
         # as we are limiting the base target
         self._assert_iteration(
-            "b1b@head", "b1b@base",
-            ['b1b', 'a1b', 'base1'],
-            select_for_downgrade=True
+            "b1b@head",
+            "b1b@base",
+            ["b1b", "a1b", "base1"],
+            select_for_downgrade=True,
         )
 
     def test_we_need_base2_upgrade(self):
         # consider a downgrade to b_2@base - we
         # want to run through all the "2"s alone, and we're done.
         self._assert_iteration(
-            "heads", "b_2@base",
-            ['d2', 'c2', 'b2', 'a2', 'base2']
+            "heads", "b_2@base", ["d2", "c2", "b2", "a2", "base2"]
         )
 
     def test_we_need_base2_downgrade(self):
         # consider a downgrade to b_2@base - we
         # want to run through all the "2"s alone, and we're done.
         self._assert_iteration(
-            "heads", "b_2@base",
-            ['d2', 'c2', 'b2', 'a2', 'base2'],
-            select_for_downgrade=True
+            "heads",
+            "b_2@base",
+            ["d2", "c2", "b2", "a2", "base2"],
+            select_for_downgrade=True,
         )
 
     def test_we_need_base3_upgrade(self):
         self._assert_iteration(
-            "heads", "b_3@base",
-            ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3']
+            "heads", "b_3@base", ["b1b", "d2", "c2", "b3", "a3", "base3"]
         )
 
     def test_we_need_base3_downgrade(self):
@@ -942,9 +1006,10 @@ class MultipleBaseCrossDependencyTestOne(DownIterateTest):
         # as well, which means b1b and c2.  Then we can downgrade
         # the 3s.
         self._assert_iteration(
-            "heads", "b_3@base",
-            ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3'],
-            select_for_downgrade=True
+            "heads",
+            "b_3@base",
+            ["b1b", "d2", "c2", "b3", "a3", "base3"],
+            select_for_downgrade=True,
         )
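Many targets in this class are branch-qualified symbols rather than bare ids: a token of the form label@head or label@base (b_2@head, b_3@base) resolves against only the branch carrying that branch_labels value, which is how a single base's lineage is upgraded or downgraded while the other bases stay put. Reusing the sketch above (ids hypothetical):

    # only the b_2 lineage, pulling in whatever it depends on for upgrades
    list(rev_map.iterate_revisions("b_2@head", "base"))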
 
 
@@ -952,22 +1017,20 @@ class MultipleBaseCrossDependencyTestTwo(DownIterateTest):
     def setUp(self):
         self.map = RevisionMap(
             lambda: [
-                Revision('base1', (), branch_labels='b_1'),
-                Revision('a1', 'base1'),
-                Revision('b1', 'a1'),
-                Revision('c1', 'b1'),
-
-                Revision('base2', (), dependencies='b_1', branch_labels='b_2'),
-                Revision('a2', 'base2'),
-                Revision('b2', 'a2'),
-                Revision('c2', 'b2'),
-                Revision('d2', 'c2'),
-
-                Revision('base3', (), branch_labels='b_3'),
-                Revision('a3', 'base3'),
-                Revision('b3', 'a3'),
-                Revision('c3', 'b3', dependencies='b2'),
-                Revision('d3', 'c3'),
+                Revision("base1", (), branch_labels="b_1"),
+                Revision("a1", "base1"),
+                Revision("b1", "a1"),
+                Revision("c1", "b1"),
+                Revision("base2", (), dependencies="b_1", branch_labels="b_2"),
+                Revision("a2", "base2"),
+                Revision("b2", "a2"),
+                Revision("c2", "b2"),
+                Revision("d2", "c2"),
+                Revision("base3", (), branch_labels="b_3"),
+                Revision("a3", "base3"),
+                Revision("b3", "a3"),
+                Revision("c3", "b3", dependencies="b2"),
+                Revision("d3", "c3"),
             ]
         )
 
@@ -976,55 +1039,68 @@ class MultipleBaseCrossDependencyTestTwo(DownIterateTest):
 
     def test_heads_to_base(self):
         self._assert_iteration(
-            "heads", "base",
+            "heads",
+            "base",
             [
-                'c1', 'b1', 'a1',
-                'd2', 'c2',
-                'd3', 'c3', 'b3', 'a3', 'base3',
-                'b2', 'a2', 'base2',
-                'base1'
-            ]
+                "c1",
+                "b1",
+                "a1",
+                "d2",
+                "c2",
+                "d3",
+                "c3",
+                "b3",
+                "a3",
+                "base3",
+                "b2",
+                "a2",
+                "base2",
+                "base1",
+            ],
         )
 
     def test_we_need_head2(self):
         self._assert_iteration(
-            "b_2@head", "base",
-            ['d2', 'c2', 'b2', 'a2', 'base2', 'base1']
+            "b_2@head", "base", ["d2", "c2", "b2", "a2", "base2", "base1"]
         )
 
     def test_we_need_head3(self):
         self._assert_iteration(
-            "b_3@head", "base",
-            ['d3', 'c3', 'b3', 'a3', 'base3', 'b2', 'a2', 'base2', 'base1']
+            "b_3@head",
+            "base",
+            ["d3", "c3", "b3", "a3", "base3", "b2", "a2", "base2", "base1"],
         )
 
     def test_we_need_head1(self):
-        self._assert_iteration(
-            "b_1@head", "base",
-            ['c1', 'b1', 'a1', 'base1']
-        )
+        self._assert_iteration("b_1@head", "base", ["c1", "b1", "a1", "base1"])
 
     def test_we_need_base1(self):
         self._assert_iteration(
-            "heads", "b_1@base",
+            "heads",
+            "b_1@base",
             [
-                'c1', 'b1', 'a1',
-                'd2', 'c2',
-                'd3', 'c3', 'b2', 'a2', 'base2',
-                'base1'
-            ]
+                "c1",
+                "b1",
+                "a1",
+                "d2",
+                "c2",
+                "d3",
+                "c3",
+                "b2",
+                "a2",
+                "base2",
+                "base1",
+            ],
         )
 
     def test_we_need_base2(self):
         self._assert_iteration(
-            "heads", "b_2@base",
-            ['d2', 'c2', 'd3', 'c3', 'b2', 'a2', 'base2']
+            "heads", "b_2@base", ["d2", "c2", "d3", "c3", "b2", "a2", "base2"]
         )
 
     def test_we_need_base3(self):
         self._assert_iteration(
-            "heads", "b_3@base",
-            ['d3', 'c3', 'b3', 'a3', 'base3']
+            "heads", "b_3@base", ["d3", "c3", "b3", "a3", "base3"]
         )
 
 
@@ -1035,24 +1111,21 @@ class LargeMapTest(DownIterateTest):
     def test_all(self):
         raw = [r for r in self.map._revision_map.values() if r is not None]
 
-        revs = [
-            rev for rev in
-            self.map.iterate_revisions(
-                "heads", "base"
-            )
-        ]
+        revs = [rev for rev in self.map.iterate_revisions("heads", "base")]
 
         eq_(set(raw), set(revs))
 
         for idx, rev in enumerate(revs):
-            ancestors = set(
-                self.map._get_ancestor_nodes([rev])).difference([rev])
+            ancestors = set(self.map._get_ancestor_nodes([rev])).difference(
+                [rev]
+            )
             descendants = set(
-                self.map._get_descendant_nodes([rev])).difference([rev])
+                self.map._get_descendant_nodes([rev])
+            ).difference([rev])
 
             assert not ancestors.intersection(descendants)
 
-            remaining = set(revs[idx + 1:])
+            remaining = set(revs[idx + 1 :])
             if remaining:
                 assert remaining.intersection(ancestors)
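One subtler black artifact shows up here: revs[idx + 1:] becomes revs[idx + 1 :]. Following PEP 8, black treats the slice colon like a binary operator and spaces both sides whenever an operand is a compound expression; the two spellings are equivalent:

    revs = list(range(10))
    idx = 3
    assert revs[idx + 1:] == revs[idx + 1 :]   # same slice, different formatting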
 
@@ -1061,22 +1134,20 @@ class DepResolutionFailedTest(DownIterateTest):
     def setUp(self):
         self.map = RevisionMap(
             lambda: [
-                Revision('base1', ()),
-                Revision('a1', 'base1'),
-                Revision('a2', 'base1'),
-                Revision('b1', 'a1'),
-                Revision('c1', 'b1'),
+                Revision("base1", ()),
+                Revision("a1", "base1"),
+                Revision("a2", "base1"),
+                Revision("b1", "a1"),
+                Revision("c1", "b1"),
             ]
         )
         # intentionally make a broken map
-        self.map._revision_map['fake'] = self.map._revision_map['a2']
-        self.map._revision_map['b1'].dependencies = 'fake'
-        self.map._revision_map['b1']._resolved_dependencies = ('fake', )
+        self.map._revision_map["fake"] = self.map._revision_map["a2"]
+        self.map._revision_map["b1"].dependencies = "fake"
+        self.map._revision_map["b1"]._resolved_dependencies = ("fake",)
 
     def test_failure_message(self):
         iter_ = self.map.iterate_revisions("c1", "base1")
         assert_raises_message(
-            RevisionError,
-            "Dependency resolution failed;",
-            list, iter_
+            RevisionError, "Dependency resolution failed;", list, iter_
         )
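assert_raises_message, used all through these tests, invokes a callable with the supplied arguments and checks both the exception type and that the message matches the given pattern (the pattern is a regular expression searched in the stringified exception, which is why messages elsewhere in this diff escape characters like parentheses). A minimal sketch against a builtin:

    from alembic.testing import assert_raises_message

    assert_raises_message(
        ValueError,
        "invalid literal",   # regex that must match int()'s error message
        int,                 # the callable under test
        "not-a-number",      # positional arguments forwarded to it
    )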
index b394784b848279ad77ba7ec6b9a1cb218a98c5f7..749b1734a35549141011b70b0dd2261261d8c4a1 100644
@@ -7,9 +7,16 @@ import textwrap
 from alembic import command, util
 from alembic.util import compat
 from alembic.script import ScriptDirectory, Script
-from alembic.testing.env import clear_staging_env, staging_env, \
-    _sqlite_testing_config, write_script, _sqlite_file_db, \
-    three_rev_fixture, _no_sql_testing_config, env_file_fixture
+from alembic.testing.env import (
+    clear_staging_env,
+    staging_env,
+    _sqlite_testing_config,
+    write_script,
+    _sqlite_file_db,
+    three_rev_fixture,
+    _no_sql_testing_config,
+    env_file_fixture,
+)
 from alembic.testing import eq_, assert_raises_message
 from alembic.testing.fixtures import TestBase, capture_context_buffer
 from alembic.environment import EnvironmentContext
@@ -18,7 +25,7 @@ from alembic.testing import mock
 
 
 class ApplyVersionsFunctionalTest(TestBase):
-    __only_on__ = 'sqlite'
+    __only_on__ = "sqlite"
 
     sourceless = False
 
@@ -46,7 +53,10 @@ class ApplyVersionsFunctionalTest(TestBase):
 
         script = ScriptDirectory.from_config(self.cfg)
         script.generate_revision(a, None, refresh=True)
-        write_script(script, a, """
+        write_script(
+            script,
+            a,
+            """
     revision = '%s'
     down_revision = None
 
@@ -60,10 +70,16 @@ class ApplyVersionsFunctionalTest(TestBase):
     def downgrade():
         op.execute("DROP TABLE foo")
 
-    """ % a, sourceless=self.sourceless)
+    """
+            % a,
+            sourceless=self.sourceless,
+        )
 
         script.generate_revision(b, None, refresh=True)
-        write_script(script, b, """
+        write_script(
+            script,
+            b,
+            """
     revision = '%s'
     down_revision = '%s'
 
@@ -77,10 +93,16 @@ class ApplyVersionsFunctionalTest(TestBase):
     def downgrade():
         op.execute("DROP TABLE bar")
 
-    """ % (b, a), sourceless=self.sourceless)
+    """
+            % (b, a),
+            sourceless=self.sourceless,
+        )
 
         script.generate_revision(c, None, refresh=True)
-        write_script(script, c, """
+        write_script(
+            script,
+            c,
+            """
     revision = '%s'
     down_revision = '%s'
 
@@ -94,49 +116,52 @@ class ApplyVersionsFunctionalTest(TestBase):
     def downgrade():
         op.execute("DROP TABLE bat")
 
-    """ % (c, b), sourceless=self.sourceless)
+    """
+            % (c, b),
+            sourceless=self.sourceless,
+        )
 
     def _test_002_upgrade(self):
         command.upgrade(self.cfg, self.c)
         db = self.bind
-        assert db.dialect.has_table(db.connect(), 'foo')
-        assert db.dialect.has_table(db.connect(), 'bar')
-        assert db.dialect.has_table(db.connect(), 'bat')
+        assert db.dialect.has_table(db.connect(), "foo")
+        assert db.dialect.has_table(db.connect(), "bar")
+        assert db.dialect.has_table(db.connect(), "bat")
 
     def _test_003_downgrade(self):
         command.downgrade(self.cfg, self.a)
         db = self.bind
-        assert db.dialect.has_table(db.connect(), 'foo')
-        assert not db.dialect.has_table(db.connect(), 'bar')
-        assert not db.dialect.has_table(db.connect(), 'bat')
+        assert db.dialect.has_table(db.connect(), "foo")
+        assert not db.dialect.has_table(db.connect(), "bar")
+        assert not db.dialect.has_table(db.connect(), "bat")
 
     def _test_004_downgrade(self):
-        command.downgrade(self.cfg, 'base')
+        command.downgrade(self.cfg, "base")
         db = self.bind
-        assert not db.dialect.has_table(db.connect(), 'foo')
-        assert not db.dialect.has_table(db.connect(), 'bar')
-        assert not db.dialect.has_table(db.connect(), 'bat')
+        assert not db.dialect.has_table(db.connect(), "foo")
+        assert not db.dialect.has_table(db.connect(), "bar")
+        assert not db.dialect.has_table(db.connect(), "bat")
 
     def _test_005_upgrade(self):
         command.upgrade(self.cfg, self.b)
         db = self.bind
-        assert db.dialect.has_table(db.connect(), 'foo')
-        assert db.dialect.has_table(db.connect(), 'bar')
-        assert not db.dialect.has_table(db.connect(), 'bat')
+        assert db.dialect.has_table(db.connect(), "foo")
+        assert db.dialect.has_table(db.connect(), "bar")
+        assert not db.dialect.has_table(db.connect(), "bat")
 
     def _test_006_upgrade_again(self):
         command.upgrade(self.cfg, self.b)
         db = self.bind
-        assert db.dialect.has_table(db.connect(), 'foo')
-        assert db.dialect.has_table(db.connect(), 'bar')
-        assert not db.dialect.has_table(db.connect(), 'bat')
+        assert db.dialect.has_table(db.connect(), "foo")
+        assert db.dialect.has_table(db.connect(), "bar")
+        assert not db.dialect.has_table(db.connect(), "bat")
 
     def _test_007_stamp_upgrade(self):
         command.stamp(self.cfg, self.c)
         db = self.bind
-        assert db.dialect.has_table(db.connect(), 'foo')
-        assert db.dialect.has_table(db.connect(), 'bar')
-        assert not db.dialect.has_table(db.connect(), 'bat')
+        assert db.dialect.has_table(db.connect(), "foo")
+        assert db.dialect.has_table(db.connect(), "bar")
+        assert not db.dialect.has_table(db.connect(), "bat")
 
 
 class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
@@ -144,25 +169,29 @@ class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
 
 
 class NewFangledSourcelessEnvOnlyApplyVersionsTest(
-        ApplyVersionsFunctionalTest):
+    ApplyVersionsFunctionalTest
+):
     sourceless = "pep3147_envonly"
 
-    __requires__ = "pep3147",
+    __requires__ = ("pep3147",)
 
 
 class NewFangledSourcelessEverythingApplyVersionsTest(
-        ApplyVersionsFunctionalTest):
+    ApplyVersionsFunctionalTest
+):
     sourceless = "pep3147_everything"
 
-    __requires__ = "pep3147",
+    __requires__ = ("pep3147",)
 
 
 class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
-    exp_kwargs = frozenset(('ctx', 'heads', 'run_args', 'step'))
+    exp_kwargs = frozenset(("ctx", "heads", "run_args", "step"))
 
     @staticmethod
     def _env_file_fixture():
-        env_file_fixture(textwrap.dedent("""\
+        env_file_fixture(
+            textwrap.dedent(
+                """\
             import alembic
             from alembic import context
             from sqlalchemy import engine_from_config, pool
@@ -199,13 +228,16 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
                 run_migrations_offline()
             else:
                 run_migrations_online()
-            """))
+            """
+            )
+        )
 
     def test_steps(self):
         import alembic
+
         alembic.mock_event_listener = None
         self._env_file_fixture()
-        with mock.patch('alembic.mock_event_listener', mock.Mock()) as mymock:
+        with mock.patch("alembic.mock_event_listener", mock.Mock()) as mymock:
             super(CallbackEnvironmentTest, self).test_steps()
         calls = mymock.call_args_list
         assert calls
@@ -213,27 +245,27 @@ class CallbackEnvironmentTest(ApplyVersionsFunctionalTest):
             args, kw = call
             assert not args
             assert set(kw.keys()) >= self.exp_kwargs
-            assert kw['run_args'] == {}
-            assert hasattr(kw['ctx'], 'get_current_revision')
+            assert kw["run_args"] == {}
+            assert hasattr(kw["ctx"], "get_current_revision")
 
-            step = kw['step']
+            step = kw["step"]
             assert isinstance(step.is_upgrade, bool)
             assert isinstance(step.is_stamp, bool)
             assert isinstance(step.is_migration, bool)
             assert isinstance(step.up_revision_id, compat.string_types)
             assert isinstance(step.up_revision, Script)
 
-            for revtype in 'up', 'down', 'source', 'destination':
-                revs = getattr(step, '%s_revisions' % revtype)
+            for revtype in "up", "down", "source", "destination":
+                revs = getattr(step, "%s_revisions" % revtype)
                 assert isinstance(revs, tuple)
                 for rev in revs:
                     assert isinstance(rev, Script)
-                revids = getattr(step, '%s_revision_ids' % revtype)
+                revids = getattr(step, "%s_revision_ids" % revtype)
                 for revid in revids:
                     assert isinstance(revid, compat.string_types)
 
-            heads = kw['heads']
-            assert hasattr(heads, '__iter__')
+            heads = kw["heads"]
+            assert hasattr(heads, "__iter__")
             for h in heads:
                 assert h is None or isinstance(h, compat.string_types)
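The patched alembic.mock_event_listener stands in for a per-step migration callback; the keyword set asserted here (ctx, heads, run_args, step) matches what alembic passes to callables registered through the on_version_apply hook, presumably what the elided middle of the env.py fixture wires up. A hedged sketch of such a listener inside env.py:

    def report_step(ctx, step, heads, run_args, **extra):
        # ctx is the MigrationContext; step describes one upgrade/downgrade/stamp
        print(step.up_revision_id, step.is_upgrade, step.is_stamp)

    context.configure(
        connection=connection,
        on_version_apply=report_step,
    )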
 
@@ -242,8 +274,8 @@ class OfflineTransactionalDDLTest(TestBase):
     def setUp(self):
         self.env = staging_env()
         self.cfg = cfg = _no_sql_testing_config()
-        cfg.set_main_option('dialect_name', 'sqlite')
-        cfg.remove_main_option('url')
+        cfg.set_main_option("dialect_name", "sqlite")
+        cfg.remove_main_option("url")
 
         self.a, self.b, self.c = three_rev_fixture(cfg)
 
@@ -254,11 +286,12 @@ class OfflineTransactionalDDLTest(TestBase):
         with capture_context_buffer(transactional_ddl=True) as buf:
             command.upgrade(self.cfg, self.c, sql=True)
         assert re.match(
-            (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % self.a) +
-            (r".*%s" % self.b) +
-            (r".*%s.*?COMMIT;.*$" % self.c),
-
-            buf.getvalue(), re.S)
+            (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % self.a)
+            + (r".*%s" % self.b)
+            + (r".*%s.*?COMMIT;.*$" % self.c),
+            buf.getvalue(),
+            re.S,
+        )
 
     def test_begin_commit_nontransactional_ddl(self):
         with capture_context_buffer(transactional_ddl=False) as buf:
@@ -270,11 +303,12 @@ class OfflineTransactionalDDLTest(TestBase):
         with capture_context_buffer(transaction_per_migration=True) as buf:
             command.upgrade(self.cfg, self.c, sql=True)
         assert re.match(
-            (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % self.a) +
-            (r"BEGIN;.*?%s.*?COMMIT;.*" % self.b) +
-            (r"BEGIN;.*?%s.*?COMMIT;.*$" % self.c),
-
-            buf.getvalue(), re.S)
+            (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % self.a)
+            + (r"BEGIN;.*?%s.*?COMMIT;.*" % self.b)
+            + (r"BEGIN;.*?%s.*?COMMIT;.*$" % self.c),
+            buf.getvalue(),
+            re.S,
+        )
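These offline tests drive alembic's SQL-script mode: with sql=True the command renders DDL into the output buffer instead of executing it, and transactional_ddl / transaction_per_migration decide whether BEGIN/COMMIT wraps the whole run, each migration, or nothing. Programmatically, this is just the flag used above (cfg standing in for any Config):

    from alembic import command

    # writes "BEGIN; ... COMMIT;" style SQL as text rather than executing it
    command.upgrade(cfg, "head", sql=True)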
 
 
 class OnlineTransactionalDDLTest(TestBase):
@@ -290,7 +324,10 @@ class OnlineTransactionalDDLTest(TestBase):
         b = util.rev_id()
         c = util.rev_id()
         script.generate_revision(a, "revision a", refresh=True)
-        write_script(script, a, """
+        write_script(
+            script,
+            a,
+            """
 "rev a"
 
 revision = '%s'
@@ -302,9 +339,14 @@ def upgrade():
 def downgrade():
     pass
 
-""" % (a, ))
+"""
+            % (a,),
+        )
         script.generate_revision(b, "revision b", refresh=True)
-        write_script(script, b, """
+        write_script(
+            script,
+            b,
+            """
 "rev b"
 revision = '%s'
 down_revision = '%s'
@@ -320,9 +362,14 @@ def upgrade():
 def downgrade():
     pass
 
-""" % (b, a))
+"""
+            % (b, a),
+        )
         script.generate_revision(c, "revision c", refresh=True)
-        write_script(script, c, """
+        write_script(
+            script,
+            c,
+            """
 "rev c"
 revision = '%s'
 down_revision = '%s'
@@ -337,7 +384,9 @@ def upgrade():
 def downgrade():
     pass
 
-""" % (c, b))
+"""
+            % (c, b),
+        )
         return a, b, c
 
     @contextmanager
@@ -347,7 +396,8 @@ def downgrade():
         def configure(*arg, **opt):
             opt.update(
                 transactional_ddl=transactional_ddl,
-                transaction_per_migration=transaction_per_migration)
+                transaction_per_migration=transaction_per_migration,
+            )
             return conf(*arg, **opt)
 
         with mock.patch.object(EnvironmentContext, "configure", configure):
@@ -357,39 +407,47 @@ def downgrade():
         a, b, c = self._opened_transaction_fixture()
 
         with self._patch_environment(
-                transactional_ddl=False, transaction_per_migration=False):
+            transactional_ddl=False, transaction_per_migration=False
+        ):
             assert_raises_message(
                 util.CommandError,
                 r'Migration "upgrade .*, rev b" has left an uncommitted '
-                r'transaction opened; transactional_ddl is False so Alembic '
-                r'is not committing transactions',
-                command.upgrade, self.cfg, c
+                r"transaction opened; transactional_ddl is False so Alembic "
+                r"is not committing transactions",
+                command.upgrade,
+                self.cfg,
+                c,
             )
 
     def test_raise_when_rev_leaves_open_transaction_tpm(self):
         a, b, c = self._opened_transaction_fixture()
 
         with self._patch_environment(
-                transactional_ddl=False, transaction_per_migration=True):
+            transactional_ddl=False, transaction_per_migration=True
+        ):
             assert_raises_message(
                 util.CommandError,
                 r'Migration "upgrade .*, rev b" has left an uncommitted '
-                r'transaction opened; transactional_ddl is False so Alembic '
-                r'is not committing transactions',
-                command.upgrade, self.cfg, c
+                r"transaction opened; transactional_ddl is False so Alembic "
+                r"is not committing transactions",
+                command.upgrade,
+                self.cfg,
+                c,
             )
 
     def test_noerr_rev_leaves_open_transaction_transactional_ddl(self):
         a, b, c = self._opened_transaction_fixture()
 
         with self._patch_environment(
-                transactional_ddl=True, transaction_per_migration=False):
+            transactional_ddl=True, transaction_per_migration=False
+        ):
             command.upgrade(self.cfg, c)
 
     def test_noerr_transaction_opened_externally(self):
         a, b, c = self._opened_transaction_fixture()
 
-        env_file_fixture("""
+        env_file_fixture(
+            """
 from sqlalchemy import engine_from_config, pool
 
 def run_migrations_online():
@@ -411,22 +469,27 @@ def run_migrations_online():
 
 run_migrations_online()
 
-""")
+"""
+        )
 
         command.stamp(self.cfg, c)
 
 
 class EncodingTest(TestBase):
-
     def setUp(self):
         self.env = staging_env()
         self.cfg = cfg = _no_sql_testing_config()
-        cfg.set_main_option('dialect_name', 'sqlite')
-        cfg.remove_main_option('url')
+        cfg.set_main_option("dialect_name", "sqlite")
+        cfg.remove_main_option("url")
         self.a = util.rev_id()
         script = ScriptDirectory.from_config(cfg)
         script.generate_revision(self.a, "revision a", refresh=True)
-        write_script(script, self.a, (compat.u("""# coding: utf-8
+        write_script(
+            script,
+            self.a,
+            (
+                compat.u(
+                    """# coding: utf-8
 from __future__ import unicode_literals
 revision = '%s'
 down_revision = None
@@ -439,22 +502,25 @@ def upgrade():
 def downgrade():
     op.execute("drôle de petite voix m’a réveillé")
 
-""") % self.a), encoding='utf-8')
+"""
+                )
+                % self.a
+            ),
+            encoding="utf-8",
+        )
 
     def tearDown(self):
         clear_staging_env()
 
     def test_encode(self):
         with capture_context_buffer(
-            bytes_io=True,
-            output_encoding='utf-8'
+            bytes_io=True, output_encoding="utf-8"
         ) as buf:
             command.upgrade(self.cfg, self.a, sql=True)
         assert compat.u("« S’il vous plaît…").encode("utf-8") in buf.getvalue()
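The assertion reads the buffer as bytes because offline mode can encode its rendered SQL: with a binary output buffer, the output_encoding option decides how the text, including the UTF-8 strings in this script, is written out. A sketch in env.py terms, with url and buf as placeholders:

    # context.configure() in env.py, offline branch
    context.configure(url=url, output_buffer=buf, output_encoding="utf-8")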
 
 
 class VersionNameTemplateTest(TestBase):
-
     def setUp(self):
         self.env = staging_env()
         self.cfg = _sqlite_testing_config()
@@ -467,7 +533,10 @@ class VersionNameTemplateTest(TestBase):
         script = ScriptDirectory.from_config(self.cfg)
         a = util.rev_id()
         script.generate_revision(a, "some message", refresh=True)
-        write_script(script, a, """
+        write_script(
+            script,
+            a,
+            """
     revision = '%s'
     down_revision = None
 
@@ -481,7 +550,9 @@ class VersionNameTemplateTest(TestBase):
     def downgrade():
         op.execute("DROP TABLE foo")
 
-    """ % a)
+    """
+            % a,
+        )
 
         script = ScriptDirectory.from_config(self.cfg)
         rev = script.get_revision(a)
@@ -493,7 +564,10 @@ class VersionNameTemplateTest(TestBase):
         script = ScriptDirectory.from_config(self.cfg)
         a = util.rev_id()
         script.generate_revision(a, None, refresh=True)
-        write_script(script, a, """
+        write_script(
+            script,
+            a,
+            """
     down_revision = None
 
     from alembic import op
@@ -506,7 +580,8 @@ class VersionNameTemplateTest(TestBase):
     def downgrade():
         op.execute("DROP TABLE foo")
 
-    """)
+    """,
+        )
 
         script = ScriptDirectory.from_config(self.cfg)
         rev = script.get_revision(a)
@@ -520,8 +595,9 @@ class VersionNameTemplateTest(TestBase):
         script.generate_revision(a, "foobar", refresh=True)
 
         path = script.get_revision(a).path
-        with open(path, 'w') as fp:
-            fp.write("""
+        with open(path, "w") as fp:
+            fp.write(
+                """
 down_revision = None
 
 from alembic import op
@@ -534,7 +610,8 @@ def upgrade():
 def downgrade():
     op.execute("DROP TABLE foo")
 
-""")
+"""
+            )
         pyc_path = util.pyc_file_from_path(path)
         if pyc_path is not None and os.access(pyc_path, os.F_OK):
             os.unlink(pyc_path)
@@ -544,7 +621,10 @@ def downgrade():
             "Could not determine revision id from filename foobar_%s.py. "
             "Be sure the 'revision' variable is declared "
             "inside the script." % a,
-            Script._from_path, script, path)
+            Script._from_path,
+            script,
+            path,
+        )
 
 
 class IgnoreFilesTest(TestBase):
@@ -563,13 +643,11 @@ class IgnoreFilesTest(TestBase):
         command.revision(self.cfg, message="some rev")
         script = ScriptDirectory.from_config(self.cfg)
         path = os.path.join(script.versions, fname)
-        with open(path, 'w') as f:
-            f.write(
-                "crap, crap -> crap"
-            )
+        with open(path, "w") as f:
+            f.write("crap, crap -> crap")
         command.revision(self.cfg, message="another rev")
 
-        script.get_revision('head')
+        script.get_revision("head")
 
     def _test_ignore_init_py(self, ext):
         """test that __init__.py is ignored."""
@@ -613,17 +691,16 @@ class SimpleSourcelessIgnoreFilesTest(IgnoreFilesTest):
 class NewFangledEnvOnlySourcelessIgnoreFilesTest(IgnoreFilesTest):
     sourceless = "pep3147_envonly"
 
-    __requires__ = "pep3147",
+    __requires__ = ("pep3147",)
 
 
 class NewFangledEverythingSourcelessIgnoreFilesTest(IgnoreFilesTest):
     sourceless = "pep3147_everything"
 
-    __requires__ = "pep3147",
+    __requires__ = ("pep3147",)
 
 
 class SourcelessNeedsFlagTest(TestBase):
-
     def setUp(self):
         self.env = staging_env(sourceless=False)
         self.cfg = _sqlite_testing_config()
@@ -636,7 +713,10 @@ class SourcelessNeedsFlagTest(TestBase):
 
         script = ScriptDirectory.from_config(self.cfg)
         script.generate_revision(a, None, refresh=True)
-        write_script(script, a, """
+        write_script(
+            script,
+            a,
+            """
     revision = '%s'
     down_revision = None
 
@@ -650,7 +730,10 @@ class SourcelessNeedsFlagTest(TestBase):
     def downgrade():
         op.execute("DROP TABLE foo")
 
-    """ % a, sourceless=True)
+    """
+            % a,
+            sourceless=True,
+        )
 
         script = ScriptDirectory.from_config(self.cfg)
         eq_(script.get_heads(), [])
index af01a38a40431183c2a8e5942f52d834e2cbfd60..f7837d9f5d1670a3f382a8079eb2bd1b5fe96fae 100644
@@ -1,10 +1,20 @@
 from alembic.testing.fixtures import TestBase
 from alembic.testing import eq_, ne_, assert_raises_message, is_, assertions
-from alembic.testing.env import clear_staging_env, staging_env, \
-    _get_staging_directory, _no_sql_testing_config, env_file_fixture, \
-    script_file_fixture, _testing_config, _sqlite_testing_config, \
-    three_rev_fixture, _multi_dir_testing_config, write_script,\
-    _sqlite_file_db, _multidb_testing_config
+from alembic.testing.env import (
+    clear_staging_env,
+    staging_env,
+    _get_staging_directory,
+    _no_sql_testing_config,
+    env_file_fixture,
+    script_file_fixture,
+    _testing_config,
+    _sqlite_testing_config,
+    three_rev_fixture,
+    _multi_dir_testing_config,
+    write_script,
+    _sqlite_file_db,
+    _multidb_testing_config,
+)
 from alembic import command
 from alembic.script import ScriptDirectory
 from alembic.environment import EnvironmentContext
@@ -24,7 +34,6 @@ env, abc, def_ = None, None, None
 
 
 class GeneralOrderedTests(TestBase):
-
     def setUp(self):
         global env
         env = staging_env()
@@ -43,11 +52,8 @@ class GeneralOrderedTests(TestBase):
         self._test_008_long_name_configurable()
 
     def _test_001_environment(self):
-        assert_set = set(['env.py', 'script.py.mako', 'README'])
-        eq_(
-            assert_set.intersection(os.listdir(env.dir)),
-            assert_set
-        )
+        assert_set = set(["env.py", "script.py.mako", "README"])
+        eq_(assert_set.intersection(os.listdir(env.dir)), assert_set)
 
     def _test_002_rev_ids(self):
         global abc, def_
@@ -66,19 +72,23 @@ class GeneralOrderedTests(TestBase):
         eq_(script.revision, abc)
         eq_(script.down_revision, None)
         assert os.access(
-            os.path.join(env.dir, 'versions',
-                         '%s_this_is_a_message.py' % abc), os.F_OK)
+            os.path.join(env.dir, "versions", "%s_this_is_a_message.py" % abc),
+            os.F_OK,
+        )
         assert callable(script.module.upgrade)
         eq_(env.get_heads(), [abc])
         eq_(env.get_base(), abc)
 
     def _test_005_nextrev(self):
         script = env.generate_revision(
-            def_, "this is the next rev", refresh=True)
+            def_, "this is the next rev", refresh=True
+        )
         assert os.access(
             os.path.join(
-                env.dir, 'versions',
-                '%s_this_is_the_next_rev.py' % def_), os.F_OK)
+                env.dir, "versions", "%s_this_is_the_next_rev.py" % def_
+            ),
+            os.F_OK,
+        )
         eq_(script.revision, def_)
         eq_(script.down_revision, abc)
         eq_(env.get_revision(abc).nextrev, set([def_]))
@@ -103,32 +113,42 @@ class GeneralOrderedTests(TestBase):
 
     def _test_007_long_name(self):
         rid = util.rev_id()
-        env.generate_revision(rid,
-                              "this is a really long name with "
-                              "lots of characters and also "
-                              "I'd like it to\nhave\nnewlines")
+        env.generate_revision(
+            rid,
+            "this is a really long name with "
+            "lots of characters and also "
+            "I'd like it to\nhave\nnewlines",
+        )
         assert os.access(
             os.path.join(
-                env.dir, 'versions',
-                '%s_this_is_a_really_long_name_with_lots_of_.py' % rid),
-            os.F_OK)
+                env.dir,
+                "versions",
+                "%s_this_is_a_really_long_name_with_lots_of_.py" % rid,
+            ),
+            os.F_OK,
+        )
 
     def _test_008_long_name_configurable(self):
         env.truncate_slug_length = 60
         rid = util.rev_id()
-        env.generate_revision(rid,
-                              "this is a really long name with "
-                              "lots of characters and also "
-                              "I'd like it to\nhave\nnewlines")
+        env.generate_revision(
+            rid,
+            "this is a really long name with "
+            "lots of characters and also "
+            "I'd like it to\nhave\nnewlines",
+        )
         assert os.access(
-            os.path.join(env.dir, 'versions',
-                         '%s_this_is_a_really_long_name_with_lots_'
-                         'of_characters_and_also_.py' % rid),
-            os.F_OK)
+            os.path.join(
+                env.dir,
+                "versions",
+                "%s_this_is_a_really_long_name_with_lots_"
+                "of_characters_and_also_.py" % rid,
+            ),
+            os.F_OK,
+        )
 
 
 class ScriptNamingTest(TestBase):
-
     @classmethod
     def setup_class(cls):
         _testing_config()
@@ -143,15 +163,17 @@ class ScriptNamingTest(TestBase):
             file_template="%(rev)s_%(slug)s_"
             "%(year)s_%(month)s_"
             "%(day)s_%(hour)s_"
-            "%(minute)s_%(second)s"
+            "%(minute)s_%(second)s",
         )
         create_date = datetime.datetime(2012, 7, 25, 15, 8, 5)
         eq_(
             script._rev_path(
-                script.versions, "12345", "this is a message", create_date),
+                script.versions, "12345", "this is a message", create_date
+            ),
             os.path.abspath(
                 "%s/versions/12345_this_is_a_"
-                "message_2012_7_25_15_8_5.py" % _get_staging_directory())
+                "message_2012_7_25_15_8_5.py" % _get_staging_directory()
+            ),
         )
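The file_template being set here programmatically is the same option normally placed in alembic.ini, where each percent sign has to be doubled so that configparser interpolation leaves the tokens alone:

    # alembic.ini equivalent (note the doubled %%)
    [alembic]
    file_template = %%(rev)s_%%(slug)s_%%(year)s_%%(month)s_%%(day)s_%%(hour)s_%%(minute)s_%%(second)s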
 
     def _test_tz(self, timezone_arg, given, expected):
@@ -161,61 +183,57 @@ class ScriptNamingTest(TestBase):
             "%(year)s_%(month)s_"
             "%(day)s_%(hour)s_"
             "%(minute)s_%(second)s",
-            timezone=timezone_arg
+            timezone=timezone_arg,
         )
 
         with mock.patch(
-                "alembic.script.base.datetime",
-                mock.Mock(
-                    datetime=mock.Mock(
-                        utcnow=lambda: given,
-                        now=lambda: given
-                    )
-                )
+            "alembic.script.base.datetime",
+            mock.Mock(
+                datetime=mock.Mock(utcnow=lambda: given, now=lambda: given)
+            ),
         ):
             create_date = script._generate_create_date()
-        eq_(
-            create_date,
-            expected
-        )
+        eq_(create_date, expected)
 
     def test_custom_tz(self):
         self._test_tz(
-            'EST5EDT',
+            "EST5EDT",
             datetime.datetime(2012, 7, 25, 15, 8, 5),
             datetime.datetime(
-                2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz('EST5EDT'))
+                2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT")
+            ),
         )
 
     def test_custom_tz_lowercase(self):
         self._test_tz(
-            'est5edt',
+            "est5edt",
             datetime.datetime(2012, 7, 25, 15, 8, 5),
             datetime.datetime(
-                2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz('EST5EDT'))
+                2012, 7, 25, 11, 8, 5, tzinfo=tz.gettz("EST5EDT")
+            ),
         )
 
     def test_custom_tz_utc(self):
         self._test_tz(
-            'utc',
+            "utc",
             datetime.datetime(2012, 7, 25, 15, 8, 5),
-            datetime.datetime(
-                2012, 7, 25, 15, 8, 5, tzinfo=tz.gettz('UTC'))
+            datetime.datetime(2012, 7, 25, 15, 8, 5, tzinfo=tz.gettz("UTC")),
         )
 
     def test_custom_tzdata_tz(self):
         self._test_tz(
-            'Europe/Berlin',
+            "Europe/Berlin",
             datetime.datetime(2012, 7, 25, 15, 8, 5),
             datetime.datetime(
-                2012, 7, 25, 17, 8, 5, tzinfo=tz.gettz('Europe/Berlin'))
+                2012, 7, 25, 17, 8, 5, tzinfo=tz.gettz("Europe/Berlin")
+            ),
         )
 
     def test_default_tz(self):
         self._test_tz(
             None,
             datetime.datetime(2012, 7, 25, 15, 8, 5),
-            datetime.datetime(2012, 7, 25, 15, 8, 5)
+            datetime.datetime(2012, 7, 25, 15, 8, 5),
         )
 
     def test_tz_cant_locate(self):
@@ -225,7 +243,7 @@ class ScriptNamingTest(TestBase):
             self._test_tz,
             "fake",
             datetime.datetime(2012, 7, 25, 15, 8, 5),
-            datetime.datetime(2012, 7, 25, 15, 8, 5)
+            datetime.datetime(2012, 7, 25, 15, 8, 5),
         )
 
 
@@ -247,7 +265,8 @@ class RevisionCommandTest(TestBase):
 
     def test_create_script_splice(self):
         rev = command.revision(
-            self.cfg, message="some message", head=self.b, splice=True)
+            self.cfg, message="some message", head=self.b, splice=True
+        )
         script = ScriptDirectory.from_config(self.cfg)
         rev = script.get_revision(rev.revision)
         eq_(rev.down_revision, self.b)
@@ -260,7 +279,9 @@ class RevisionCommandTest(TestBase):
             "Revision %s is not a head revision; please specify --splice "
             "to create a new branch from this revision" % self.b,
             command.revision,
-            self.cfg, message="some message", head=self.b
+            self.cfg,
+            message="some message",
+            head=self.b,
         )
 
     def test_illegal_revision_chars(self):
@@ -269,19 +290,23 @@ class RevisionCommandTest(TestBase):
             r"Character\(s\) '-' not allowed in "
             "revision identifier 'no-dashes'",
             command.revision,
-            self.cfg, message="some message", rev_id="no-dashes"
+            self.cfg,
+            message="some message",
+            rev_id="no-dashes",
         )
 
         assert not os.path.exists(
-            os.path.join(
-                self.env.dir, "versions", "no-dashes_some_message.py"))
+            os.path.join(self.env.dir, "versions", "no-dashes_some_message.py")
+        )
 
         assert_raises_message(
             util.CommandError,
             r"Character\(s\) '@' not allowed in "
             "revision identifier 'no@atsigns'",
             command.revision,
-            self.cfg, message="some message", rev_id="no@atsigns"
+            self.cfg,
+            message="some message",
+            rev_id="no@atsigns",
         )
 
         assert_raises_message(
@@ -289,7 +314,9 @@ class RevisionCommandTest(TestBase):
             r"Character\(s\) '-, @' not allowed in revision "
             "identifier 'no@atsigns-ordashes'",
             command.revision,
-            self.cfg, message="some message", rev_id="no@atsigns-ordashes"
+            self.cfg,
+            message="some message",
+            rev_id="no@atsigns-ordashes",
         )
 
         assert_raises_message(
@@ -297,12 +324,15 @@ class RevisionCommandTest(TestBase):
             r"Character\(s\) '\+' not allowed in revision "
             r"identifier 'no\+plussignseither'",
             command.revision,
-            self.cfg, message="some message", rev_id="no+plussignseither"
+            self.cfg,
+            message="some message",
+            rev_id="no+plussignseither",
         )
 
     def test_create_script_branches(self):
         rev = command.revision(
-            self.cfg, message="some message", branch_label="foobar")
+            self.cfg, message="some message", branch_label="foobar"
+        )
         script = ScriptDirectory.from_config(self.cfg)
         rev = script.get_revision(rev.revision)
         eq_(script.get_revision("foobar"), rev)
@@ -330,7 +360,9 @@ class RevisionCommandTest(TestBase):
             "upgraded your script.py.mako to include the 'branch_labels' "
             r"section\?",
             command.revision,
-            self.cfg, message="some message", branch_label="foobar"
+            self.cfg,
+            message="some message",
+            branch_label="foobar",
         )
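These command.revision() keywords mirror the CLI flags of "alembic revision": message/-m, head/--head, splice/--splice, branch_label/--branch-label, rev_id/--rev-id. The splice case above, sketched with placeholder values (cfg and parent_id are stand-ins):

    from alembic import command

    # programmatic equivalent of:
    #   alembic revision -m "some message" --head <parent-id> --splice
    command.revision(cfg, message="some message", head=parent_id, splice=True)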
 
 
@@ -350,11 +382,17 @@ class CustomizeRevisionTest(TestBase):
             (self.model3, "model3"),
         ]:
             script.generate_revision(
-                model, name, refresh=True,
+                model,
+                name,
+                refresh=True,
                 version_path=os.path.join(_get_staging_directory(), name),
-                head="base")
+                head="base",
+            )
 
-            write_script(script, model, """\
+            write_script(
+                script,
+                model,
+                """\
 "%s"
 revision = '%s'
 down_revision = None
@@ -370,7 +408,9 @@ def upgrade():
 def downgrade():
     pass
 
-""" % (name, model, name))
+"""
+                % (name, model, name),
+            )
 
     def tearDown(self):
         clear_staging_env()
@@ -385,13 +425,13 @@ def downgrade():
                 context.configure(
                     connection=connection,
                     target_metadata=target_metadata,
-                    process_revision_directives=fn)
+                    process_revision_directives=fn,
+                )
                 with context.begin_transaction():
                     context.run_migrations()
 
         return mock.patch(
-            "alembic.script.base.ScriptDirectory.run_env",
-            run_env
+            "alembic.script.base.ScriptDirectory.run_env", run_env
         )
 
     def test_new_locations_no_autogen(self):
@@ -404,24 +444,27 @@ def downgrade():
                     ops.UpgradeOps(),
                     ops.DowngradeOps(),
                     version_path=os.path.join(
-                        _get_staging_directory(), "model1"),
-                    head="model1@head"
+                        _get_staging_directory(), "model1"
+                    ),
+                    head="model1@head",
                 ),
                 ops.MigrationScript(
                     util.rev_id(),
                     ops.UpgradeOps(),
                     ops.DowngradeOps(),
                     version_path=os.path.join(
-                        _get_staging_directory(), "model2"),
-                    head="model2@head"
+                        _get_staging_directory(), "model2"
+                    ),
+                    head="model2@head",
                 ),
                 ops.MigrationScript(
                     util.rev_id(),
                     ops.UpgradeOps(),
                     ops.DowngradeOps(),
                     version_path=os.path.join(
-                        _get_staging_directory(), "model3"),
-                    head="model3@head"
+                        _get_staging_directory(), "model3"
+                    ),
+                    head="model3@head",
                 ),
             ]
 
@@ -438,10 +481,13 @@ def downgrade():
             rev_script = script.get_revision(rev.revision)
             eq_(
                 rev_script.path,
-                os.path.abspath(os.path.join(
-                    _get_staging_directory(), model,
-                    "%s_.py" % (rev_script.revision, )
-                ))
+                os.path.abspath(
+                    os.path.join(
+                        _get_staging_directory(),
+                        model,
+                        "%s_.py" % (rev_script.revision,),
+                    )
+                ),
             )
             assert os.path.exists(rev_script.path)
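process_revision_directives, injected by the fixture above, is the documented hook for rewriting what "alembic revision" is about to emit: it receives the migration context, the resolved revision identifiers, and a mutable list of MigrationScript directives that can be edited or replaced wholesale, as the multi-directory tests here do. The smallest useful form just mutates in place:

    def process_revision_directives(context, revision, directives):
        # directives is a list of ops.MigrationScript objects
        directives[0].message = "renamed by the hook"

    # attach via context.configure(process_revision_directives=...) in env.py,
    # or pass directly to command.revision(), both of which appear in this file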
 
@@ -455,19 +501,23 @@ def downgrade():
 
         with self._env_fixture(process_revision_directives, m):
             rev = command.revision(
-                self.cfg, message="some message", head="model1@head", sql=True)
+                self.cfg, message="some message", head="model1@head", sql=True
+            )
 
         with mock.patch.object(rev.module, "op") as op_mock:
             rev.module.upgrade()
         eq_(
             op_mock.mock_calls,
-            [mock.call.create_index(
-                'some_index', 'some_table', ['a', 'b'], unique=False)]
+            [
+                mock.call.create_index(
+                    "some_index", "some_table", ["a", "b"], unique=False
+                )
+            ],
         )
 
     def test_autogen(self):
         m = sa.MetaData()
-        sa.Table('t', m, sa.Column('x', sa.Integer))
+        sa.Table("t", m, sa.Column("x", sa.Integer))
 
         def process_revision_directives(context, rev, generate_revisions):
             existing_upgrades = generate_revisions[0].upgrade_ops
@@ -483,17 +533,19 @@ def downgrade():
                     existing_upgrades,
                     ops.DowngradeOps(),
                     version_path=os.path.join(
-                        _get_staging_directory(), "model1"),
-                    head="model1@head"
+                        _get_staging_directory(), "model1"
+                    ),
+                    head="model1@head",
                 ),
                 ops.MigrationScript(
                     util.rev_id(),
                     ops.UpgradeOps(ops=existing_downgrades.ops),
                     ops.DowngradeOps(),
                     version_path=os.path.join(
-                        _get_staging_directory(), "model2"),
-                    head="model2@head"
-                )
+                        _get_staging_directory(), "model2"
+                    ),
+                    head="model2@head",
+                ),
             ]
 
         with self._env_fixture(process_revision_directives, m):
@@ -501,57 +553,57 @@ def downgrade():
 
             eq_(
                 Inspector.from_engine(self.engine).get_table_names(),
-                ["alembic_version"]
+                ["alembic_version"],
             )
 
             command.revision(
-                self.cfg, message="some message",
-                autogenerate=True)
+                self.cfg, message="some message", autogenerate=True
+            )
 
             command.upgrade(self.cfg, "model1@head")
 
             eq_(
                 Inspector.from_engine(self.engine).get_table_names(),
-                ["alembic_version", "t"]
+                ["alembic_version", "t"],
             )
 
             command.upgrade(self.cfg, "model2@head")
 
             eq_(
                 Inspector.from_engine(self.engine).get_table_names(),
-                ["alembic_version"]
+                ["alembic_version"],
             )
 
     def test_programmatic_command_option(self):
-
         def process_revision_directives(context, rev, generate_revisions):
             generate_revisions[0].message = "test programmatic"
             generate_revisions[0].upgrade_ops = ops.UpgradeOps(
                 ops=[
                     ops.CreateTableOp(
-                        'test_table',
+                        "test_table",
                         [
-                            sa.Column('id', sa.Integer(), primary_key=True),
-                            sa.Column('name', sa.String(50), nullable=False)
-                        ]
-                    ),
+                            sa.Column("id", sa.Integer(), primary_key=True),
+                            sa.Column("name", sa.String(50), nullable=False),
+                        ],
+                    )
                 ]
             )
             generate_revisions[0].downgrade_ops = ops.DowngradeOps(
-                ops=[
-                    ops.DropTableOp('test_table')
-                ]
+                ops=[ops.DropTableOp("test_table")]
             )
 
         with self._env_fixture(None, None):
             rev = command.revision(
                 self.cfg,
                 head="model1@head",
-                process_revision_directives=process_revision_directives)
+                process_revision_directives=process_revision_directives,
+            )
 
         with open(rev.path) as handle:
             result = handle.read()
-        assert ("""
+        assert (
+            (
+                """
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('test_table',
@@ -560,22 +612,19 @@ def upgrade():
     sa.PrimaryKeyConstraint('id')
     )
     # ### end Alembic commands ###
-""") in result
+"""
+            )
+            in result
+        )
 
 
 class ScriptAccessorTest(TestBase):
     def test_upgrade_downgrade_ops_list_accessors(self):
         u1 = ops.UpgradeOps(ops=[])
         d1 = ops.DowngradeOps(ops=[])
-        m1 = ops.MigrationScript(
-            "somerev", u1, d1
-        )
-        is_(
-            m1.upgrade_ops, u1
-        )
-        is_(
-            m1.downgrade_ops, d1
-        )
+        m1 = ops.MigrationScript("somerev", u1, d1)
+        is_(m1.upgrade_ops, u1)
+        is_(m1.downgrade_ops, d1)
         u2 = ops.UpgradeOps(ops=[])
         d2 = ops.DowngradeOps(ops=[])
         m1._upgrade_ops.append(u2)
@@ -585,13 +634,17 @@ class ScriptAccessorTest(TestBase):
             ValueError,
             "This MigrationScript instance has a multiple-entry list for "
             "UpgradeOps; please use the upgrade_ops_list attribute.",
-            getattr, m1, "upgrade_ops"
+            getattr,
+            m1,
+            "upgrade_ops",
         )
         assert_raises_message(
             ValueError,
             "This MigrationScript instance has a multiple-entry list for "
             "DowngradeOps; please use the downgrade_ops_list attribute.",
-            getattr, m1, "downgrade_ops"
+            getattr,
+            m1,
+            "downgrade_ops",
         )
         eq_(m1.upgrade_ops_list, [u1, u2])
         eq_(m1.downgrade_ops_list, [d1, d2])
@@ -615,39 +668,36 @@ class ImportsTest(TestBase):
                 context.configure(
                     connection=connection,
                     target_metadata=target_metadata,
-                    **kw)
+                    **kw
+                )
                 with context.begin_transaction():
                     context.run_migrations()
 
         return mock.patch(
-            "alembic.script.base.ScriptDirectory.run_env",
-            run_env
+            "alembic.script.base.ScriptDirectory.run_env", run_env
         )
 
     def test_imports_in_script(self):
         from sqlalchemy import MetaData, Table, Column
         from sqlalchemy.dialects.mysql import VARCHAR
 
-        type_ = VARCHAR(20, charset='utf8', national=True)
+        type_ = VARCHAR(20, charset="utf8", national=True)
 
         m = MetaData()
 
-        Table(
-            't', m,
-            Column('x', type_)
-        )
+        Table("t", m, Column("x", type_))
 
         def process_revision_directives(context, rev, generate_revisions):
             generate_revisions[0].imports.add(
-                "from sqlalchemy.dialects.mysql import TINYINT")
+                "from sqlalchemy.dialects.mysql import TINYINT"
+            )
 
         with self._env_fixture(
-                m,
-                process_revision_directives=process_revision_directives
+            m, process_revision_directives=process_revision_directives
         ):
             rev = command.revision(
-                self.cfg, message="some message",
-                autogenerate=True)
+                self.cfg, message="some message", autogenerate=True
+            )
 
         with open(rev.path) as file_:
             contents = file_.read()
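The imports attribute touched by the hook above is a set of module-level import lines that autogenerate renders at the top of the emitted script, which is how dialect-specific types stay importable; the file read that follows presumably asserts the TINYINT line landed in the output. Inside any process_revision_directives hook (the parameter is named generate_revisions above):

    directives[0].imports.add("from sqlalchemy.dialects.mysql import TINYINT")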
@@ -659,24 +709,24 @@ class MultiContextTest(TestBase):
     """test the multidb template for autogenerate front-to-back"""
 
     def setUp(self):
-        self.engine1 = _sqlite_file_db(tempname='eng1.db')
-        self.engine2 = _sqlite_file_db(tempname='eng2.db')
-        self.engine3 = _sqlite_file_db(tempname='eng3.db')
+        self.engine1 = _sqlite_file_db(tempname="eng1.db")
+        self.engine2 = _sqlite_file_db(tempname="eng2.db")
+        self.engine3 = _sqlite_file_db(tempname="eng3.db")
 
         self.env = staging_env(template="multidb")
-        self.cfg = _multidb_testing_config({
-            "engine1": self.engine1,
-            "engine2": self.engine2,
-            "engine3": self.engine3
-        })
+        self.cfg = _multidb_testing_config(
+            {
+                "engine1": self.engine1,
+                "engine2": self.engine2,
+                "engine3": self.engine3,
+            }
+        )
 
     def _write_metadata(self, meta):
-        path = os.path.join(_get_staging_directory(), 'scripts', 'env.py')
+        path = os.path.join(_get_staging_directory(), "scripts", "env.py")
         with open(path) as env_:
             existing_env = env_.read()
-        existing_env = existing_env.replace(
-            "target_metadata = {}",
-            meta)
+        existing_env = existing_env.replace("target_metadata = {}", meta)
         with open(path, "w") as env_:
             env_.write(existing_env)
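
The setUp hunk above also shows how black lays out a brace-delimited literal passed as a call's only argument: the parenthesis and the brace each open their own indentation level, so the dictionary items sit two levels in and each closer gets its own line. A sketch with a hypothetical config factory:

    def make_config(engines):
        """Hypothetical factory accepting a mapping of engine URLs."""
        return dict(engines)


    cfg = make_config(
        {
            "engine1": "sqlite:///eng1.db",
            "engine2": "sqlite:///eng2.db",
            "engine3": "sqlite:///eng3.db",
        }
    )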
 
@@ -701,40 +751,30 @@ sa.Table('e3t1', m3, sa.Column('z', sa.Integer))
         )
 
         rev = command.revision(
-            self.cfg, message="some message",
-            autogenerate=True
+            self.cfg, message="some message", autogenerate=True
         )
         with mock.patch.object(rev.module, "op") as op_mock:
             rev.module.upgrade_engine1()
             eq_(
                 op_mock.mock_calls[-1],
-                mock.call.create_table('e1t1', mock.ANY)
+                mock.call.create_table("e1t1", mock.ANY),
             )
             rev.module.upgrade_engine2()
             eq_(
                 op_mock.mock_calls[-1],
-                mock.call.create_table('e2t1', mock.ANY)
+                mock.call.create_table("e2t1", mock.ANY),
             )
             rev.module.upgrade_engine3()
             eq_(
                 op_mock.mock_calls[-1],
-                mock.call.create_table('e3t1', mock.ANY)
+                mock.call.create_table("e3t1", mock.ANY),
             )
             rev.module.downgrade_engine1()
-            eq_(
-                op_mock.mock_calls[-1],
-                mock.call.drop_table('e1t1')
-            )
+            eq_(op_mock.mock_calls[-1], mock.call.drop_table("e1t1"))
             rev.module.downgrade_engine2()
-            eq_(
-                op_mock.mock_calls[-1],
-                mock.call.drop_table('e2t1')
-            )
+            eq_(op_mock.mock_calls[-1], mock.call.drop_table("e2t1"))
             rev.module.downgrade_engine3()
-            eq_(
-                op_mock.mock_calls[-1],
-                mock.call.drop_table('e3t1')
-            )
+            eq_(op_mock.mock_calls[-1], mock.call.drop_table("e3t1"))
 
 
 class RewriterTest(TestBase):
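
The drop_table assertions show the inverse operation: black joins a call back onto a single line whenever the whole expression fits within 79 columns, which is why the three-line eq_(...) checks collapse. A trivial sketch:

    def eq_(a, b):
        assert a == b, "%r != %r" % (a, b)


    # Before: split across lines although it would fit on one.
    eq_(
        len("abc"),
        3
    )

    # After black -l 79: joined, since the call fits within the limit.
    eq_(len("abc"), 3)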
@@ -744,20 +784,13 @@ class RewriterTest(TestBase):
         mocker = mock.Mock(side_effect=lambda context, revision, op: op)
         writer.rewrites(ops.MigrateOperation)(mocker)
 
-        addcolop = ops.AddColumnOp(
-            't1', sa.Column('x', sa.Integer())
-        )
+        addcolop = ops.AddColumnOp("t1", sa.Column("x", sa.Integer()))
 
         directives = [
             ops.MigrationScript(
                 util.rev_id(),
-                ops.UpgradeOps(ops=[
-                    ops.ModifyTableOps('t1', ops=[
-                        addcolop
-                    ])
-                ]),
-                ops.DowngradeOps(ops=[
-                ]),
+                ops.UpgradeOps(ops=[ops.ModifyTableOps("t1", ops=[addcolop])]),
+                ops.DowngradeOps(ops=[]),
             )
         ]
 
@@ -771,7 +804,7 @@ class RewriterTest(TestBase):
                 mock.call(ctx, rev, directives[0].upgrade_ops.ops[0]),
                 mock.call(ctx, rev, addcolop),
                 mock.call(ctx, rev, directives[0].downgrade_ops),
-            ]
+            ],
         )
 
     def test_double_migrate_table(self):
@@ -783,28 +816,33 @@ class RewriterTest(TestBase):
         def second_table(context, revision, op):
             return [
                 op,
-                ops.ModifyTableOps('t2', ops=[
-                    ops.AddColumnOp('t2', sa.Column('x', sa.Integer()))
-                ])
+                ops.ModifyTableOps(
+                    "t2",
+                    ops=[ops.AddColumnOp("t2", sa.Column("x", sa.Integer()))],
+                ),
             ]
 
         @writer.rewrites(ops.AddColumnOp)
         def add_column(context, revision, op):
-            idx_op = ops.CreateIndexOp('ixt', op.table_name, [op.column.name])
+            idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
             idx_ops.append(idx_op)
-            return [
-                op,
-                idx_op
-            ]
+            return [op, idx_op]
 
         directives = [
             ops.MigrationScript(
                 util.rev_id(),
-                ops.UpgradeOps(ops=[
-                    ops.ModifyTableOps('t1', ops=[
-                        ops.AddColumnOp('t1', sa.Column('x', sa.Integer()))
-                    ])
-                ]),
+                ops.UpgradeOps(
+                    ops=[
+                        ops.ModifyTableOps(
+                            "t1",
+                            ops=[
+                                ops.AddColumnOp(
+                                    "t1", sa.Column("x", sa.Integer())
+                                )
+                            ],
+                        )
+                    ]
+                ),
                 ops.DowngradeOps(ops=[]),
             )
         ]
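
When one-argument-per-line is still not enough, black recurses: every nested bracket pair that would overflow is split in turn, producing the staircase of UpgradeOps/ModifyTableOps/AddColumnOp constructors above. A reduced sketch of the same shape, using a generic wrapper in place of the ops classes:

    def wrap(label, items):
        return {label: items}


    directive = wrap(
        "upgrade",
        [
            wrap(
                "modify_table_with_a_long_label",
                [
                    wrap(
                        "add_column", ["a_column_name", "an_integer_type"]
                    )
                ],
            )
        ],
    )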
@@ -812,17 +850,10 @@ class RewriterTest(TestBase):
         ctx, rev = mock.Mock(), mock.Mock()
         writer(ctx, rev, directives)
         eq_(
-            [d.table_name for d in directives[0].upgrade_ops.ops],
-            ['t1', 't2']
-        )
-        is_(
-            directives[0].upgrade_ops.ops[0].ops[1],
-            idx_ops[0]
-        )
-        is_(
-            directives[0].upgrade_ops.ops[1].ops[1],
-            idx_ops[1]
+            [d.table_name for d in directives[0].upgrade_ops.ops], ["t1", "t2"]
         )
+        is_(directives[0].upgrade_ops.ops[0].ops[1], idx_ops[0])
+        is_(directives[0].upgrade_ops.ops[1].ops[1], idx_ops[1])
 
     def test_chained_ops(self):
         writer1 = autogenerate.Rewriter()
@@ -841,26 +872,32 @@ class RewriterTest(TestBase):
                         op.column.name,
                         modify_nullable=False,
                         existing_type=op.column.type,
-                    )
+                    ),
                 ]
 
         @writer2.rewrites(ops.AddColumnOp)
         def add_column_idx(context, revision, op):
-            idx_op = ops.CreateIndexOp('ixt', op.table_name, [op.column.name])
-            return [
-                op,
-                idx_op
-            ]
+            idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
+            return [op, idx_op]
 
         directives = [
             ops.MigrationScript(
                 util.rev_id(),
-                ops.UpgradeOps(ops=[
-                    ops.ModifyTableOps('t1', ops=[
-                        ops.AddColumnOp(
-                            't1', sa.Column('x', sa.Integer(), nullable=False))
-                    ])
-                ]),
+                ops.UpgradeOps(
+                    ops=[
+                        ops.ModifyTableOps(
+                            "t1",
+                            ops=[
+                                ops.AddColumnOp(
+                                    "t1",
+                                    sa.Column(
+                                        "x", sa.Integer(), nullable=False
+                                    ),
+                                )
+                            ],
+                        )
+                    ]
+                ),
                 ops.DowngradeOps(ops=[]),
             )
         ]
@@ -877,7 +914,7 @@ class RewriterTest(TestBase):
             "    op.alter_column('t1', 'x',\n"
             "               existing_type=sa.Integer(),\n"
             "               nullable=False)\n"
-            "    # ### end Alembic commands ###"
+            "    # ### end Alembic commands ###",
         )
 
 
@@ -894,7 +931,9 @@ class MultiDirRevisionCommandTest(TestBase):
             util.CommandError,
             "Multiple version locations present, please specify "
             "--version-path",
-            command.revision, self.cfg, message="some message"
+            command.revision,
+            self.cfg,
+            message="some message",
         )
 
     def test_multiple_dir_no_bases_invalid_version_path(self):
@@ -902,40 +941,46 @@ class MultiDirRevisionCommandTest(TestBase):
             util.CommandError,
             "Path foo/bar/ is not represented in current version locations",
             command.revision,
-            self.cfg, message="x",
-            version_path=os.path.join("foo/bar/")
+            self.cfg,
+            message="x",
+            version_path=os.path.join("foo/bar/"),
         )
 
     def test_multiple_dir_no_bases_version_path(self):
         script = command.revision(
-            self.cfg, message="x",
-            version_path=os.path.join(_get_staging_directory(), "model1"))
+            self.cfg,
+            message="x",
+            version_path=os.path.join(_get_staging_directory(), "model1"),
+        )
         assert os.access(script.path, os.F_OK)
 
     def test_multiple_dir_chooses_base(self):
         command.revision(
-            self.cfg, message="x",
+            self.cfg,
+            message="x",
             head="base",
-            version_path=os.path.join(_get_staging_directory(), "model1"))
+            version_path=os.path.join(_get_staging_directory(), "model1"),
+        )
 
         script2 = command.revision(
-            self.cfg, message="y",
+            self.cfg,
+            message="y",
             head="base",
-            version_path=os.path.join(_get_staging_directory(), "model2"))
+            version_path=os.path.join(_get_staging_directory(), "model2"),
+        )
 
         script3 = command.revision(
-            self.cfg, message="y2",
-            head=script2.revision)
+            self.cfg, message="y2", head=script2.revision
+        )
 
         eq_(
             os.path.dirname(script3.path),
-            os.path.abspath(os.path.join(_get_staging_directory(), "model2"))
+            os.path.abspath(os.path.join(_get_staging_directory(), "model2")),
         )
         assert os.access(script3.path, os.F_OK)
 
 
 class TemplateArgsTest(TestBase):
-
     def setUp(self):
         staging_env()
         self.cfg = _no_sql_testing_config(
@@ -949,43 +994,45 @@ class TemplateArgsTest(TestBase):
         config = _no_sql_testing_config()
         script = ScriptDirectory.from_config(config)
         template_args = {"x": "x1", "y": "y1", "z": "z1"}
-        env = EnvironmentContext(
-            config,
-            script,
-            template_args=template_args
-        )
-        env.configure(dialect_name="sqlite",
-                      template_args={"y": "y2", "q": "q1"})
-        eq_(
-            template_args,
-            {"x": "x1", "y": "y2", "z": "z1", "q": "q1"}
+        env = EnvironmentContext(config, script, template_args=template_args)
+        env.configure(
+            dialect_name="sqlite", template_args={"y": "y2", "q": "q1"}
         )
+        eq_(template_args, {"x": "x1", "y": "y2", "z": "z1", "q": "q1"})
 
     def test_tmpl_args_revision(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
-""")
-        script_file_fixture("""
+"""
+        )
+        script_file_fixture(
+            """
 # somearg: ${somearg}
 revision = ${repr(up_revision)}
 down_revision = ${repr(down_revision)}
-""")
+"""
+        )
 
         command.revision(self.cfg, message="some rev")
         script = ScriptDirectory.from_config(self.cfg)
 
-        rev = script.get_revision('head')
+        rev = script.get_revision("head")
         with open(rev.path) as f:
             text = f.read()
         assert "somearg: somevalue" in text
 
     def test_bad_render(self):
-        env_file_fixture("""
+        env_file_fixture(
+            """
 context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
-""")
-        script_file_fixture("""
+"""
+        )
+        script_file_fixture(
+            """
     <% z = x + y %>
-""")
+"""
+        )
 
         try:
             command.revision(self.cfg, message="some rev")
@@ -993,7 +1040,7 @@ context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
             m = re.match(
                 r"^Template rendering failed; see (.+?) "
                 "for a template-oriented",
-                str(ce)
+                str(ce),
             )
             assert m, "Command error did not produce a file"
             with open(m.group(1)) as handle:
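
Worth noting is what black does not touch here: the triple-quoted template bodies move inside the reformatted env_file_fixture() and script_file_fixture() calls, but their contents are opaque to the formatter, so the 80-plus-column context.configure(...) line inside the template survives verbatim, single quotes and all. A sketch:

    def fixture(text):
        """Illustrative stand-in for a helper that writes a template."""
        return text


    # The call is reformatted; the literal's long line survives as-is:
    fixture(
        """
    context.configure(dialect_name='sqlite', template_args={"somearg": "somevalue"})
    """
    )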
@@ -1003,13 +1050,12 @@ context.configure(dialect_name='sqlite', template_args={"somearg":"somevalue"})
 
 
 class DuplicateVersionLocationsTest(TestBase):
-
     def setUp(self):
         self.env = staging_env()
         self.cfg = _multi_dir_testing_config(
             # this is a duplicate of one of the paths
             # already present in this fixture
-            extra_version_location='%(here)s/model1'
+            extra_version_location="%(here)s/model1"
         )
 
         script = ScriptDirectory.from_config(self.cfg)
@@ -1022,10 +1068,16 @@ class DuplicateVersionLocationsTest(TestBase):
             (self.model3, "model3"),
         ]:
             script.generate_revision(
-                model, name, refresh=True,
+                model,
+                name,
+                refresh=True,
                 version_path=os.path.join(_get_staging_directory(), name),
-                head="base")
-            write_script(script, model, """\
+                head="base",
+            )
+            write_script(
+                script,
+                model,
+                """\
 "%s"
 revision = '%s'
 down_revision = None
@@ -1041,7 +1093,9 @@ def upgrade():
 def downgrade():
     pass
 
-""" % (name, model, name))
+"""
+                % (name, model, name),
+            )
 
     def tearDown(self):
         clear_staging_env()
@@ -1049,16 +1103,20 @@ def downgrade():
     def test_env_emits_warning(self):
         with assertions.expect_warnings(
             "File %s loaded twice! ignoring. "
-            "Please ensure version_locations is unique" % (
-                os.path.realpath(os.path.join(
-                _get_staging_directory(),
-                "model1",
-                "%s_model1.py" % self.model1
-                )))
+            "Please ensure version_locations is unique"
+            % (
+                os.path.realpath(
+                    os.path.join(
+                        _get_staging_directory(),
+                        "model1",
+                        "%s_model1.py" % self.model1,
+                    )
+                )
+            )
         ):
             script = ScriptDirectory.from_config(self.cfg)
             script.revision_map.heads
             eq_(
                 [rev.revision for rev in script.walk_revisions()],
-                [self.model1, self.model2, self.model3]
+                [self.model1, self.model2, self.model3],
             )
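
The warning-message hunk above shows black's operator-first wrapping: when a %-formatted expression overflows, the % operator begins the continuation line and its right-hand operand is split like any other bracket pair. A sketch that mirrors the shape with illustrative values:

    import os

    staging, model1 = "/tmp/staging", "abc123"

    message = (
        "File %s loaded twice! ignoring. "
        "Please ensure version_locations is unique"
        % os.path.realpath(
            os.path.join(staging, "model1", "%s_model1.py" % model1)
        )
    )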
index 75972d464144b390afd93520667c71cc78577cfe..6718be91d9eb25d77e810e2c38747db94494c334 100644 (file)
@@ -7,34 +7,29 @@ from alembic.testing.fixtures import TestBase
 
 
 class SQLiteTest(TestBase):
-
     def test_add_column(self):
-        context = op_fixture('sqlite')
-        op.add_column('t1', Column('c1', Integer))
-        context.assert_(
-            'ALTER TABLE t1 ADD COLUMN c1 INTEGER'
-        )
+        context = op_fixture("sqlite")
+        op.add_column("t1", Column("c1", Integer))
+        context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER")
 
     def test_add_column_implicit_constraint(self):
-        context = op_fixture('sqlite')
-        op.add_column('t1', Column('c1', Boolean))
-        context.assert_(
-            'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN'
-        )
+        context = op_fixture("sqlite")
+        op.add_column("t1", Column("c1", Boolean))
+        context.assert_("ALTER TABLE t1 ADD COLUMN c1 BOOLEAN")
 
     def test_add_explicit_constraint(self):
-        op_fixture('sqlite')
+        op_fixture("sqlite")
         assert_raises_message(
             NotImplementedError,
             "No support for ALTER of constraints in SQLite dialect",
             op.create_check_constraint,
             "foo",
             "sometable",
-            column('name') > 5
+            column("name") > 5,
         )
 
     def test_drop_explicit_constraint(self):
-        op_fixture('sqlite')
+        op_fixture("sqlite")
         assert_raises_message(
             NotImplementedError,
             "No support for ALTER of constraints in SQLite dialect",
index 29530c0ec4001855381691f5a15af8fa4eb33c0c..0a545cffab194065a168c4cfde39861ddc4d0c8b 100644 (file)
@@ -8,24 +8,22 @@ from alembic import migration
 
 from alembic.util import CommandError
 
-version_table = Table('version_table', MetaData(),
-                      Column('version_num', String(32), nullable=False))
+version_table = Table(
+    "version_table",
+    MetaData(),
+    Column("version_num", String(32), nullable=False),
+)
 
 
 def _up(from_, to_, branch_presence_changed=False):
-    return migration.StampStep(
-        from_, to_, True, branch_presence_changed
-    )
+    return migration.StampStep(from_, to_, True, branch_presence_changed)
 
 
 def _down(from_, to_, branch_presence_changed=False):
-    return migration.StampStep(
-        from_, to_, False, branch_presence_changed
-    )
+    return migration.StampStep(from_, to_, False, branch_presence_changed)
 
 
 class TestMigrationContext(TestBase):
-
     @classmethod
     def setup_class(cls):
         cls.bind = config.db
@@ -48,112 +46,126 @@ class TestMigrationContext(TestBase):
         if len(rows) == 0:
             return None
         eq_(len(rows), 1)
-        return rows[0]['version_num']
+        return rows[0]["version_num"]
 
     def test_config_default_version_table_name(self):
-        context = self.make_one(dialect_name='sqlite')
-        eq_(context._version.name, 'alembic_version')
+        context = self.make_one(dialect_name="sqlite")
+        eq_(context._version.name, "alembic_version")
 
     def test_config_explicit_version_table_name(self):
-        context = self.make_one(dialect_name='sqlite',
-                                opts={'version_table': 'explicit'})
-        eq_(context._version.name, 'explicit')
-        eq_(context._version.primary_key.name, 'explicit_pkc')
+        context = self.make_one(
+            dialect_name="sqlite", opts={"version_table": "explicit"}
+        )
+        eq_(context._version.name, "explicit")
+        eq_(context._version.primary_key.name, "explicit_pkc")
 
     def test_config_explicit_version_table_schema(self):
-        context = self.make_one(dialect_name='sqlite',
-                                opts={'version_table_schema': 'explicit'})
-        eq_(context._version.schema, 'explicit')
+        context = self.make_one(
+            dialect_name="sqlite", opts={"version_table_schema": "explicit"}
+        )
+        eq_(context._version.schema, "explicit")
 
     def test_config_explicit_no_pk(self):
-        context = self.make_one(dialect_name='sqlite',
-                                opts={'version_table_pk': False})
+        context = self.make_one(
+            dialect_name="sqlite", opts={"version_table_pk": False}
+        )
         eq_(len(context._version.primary_key), 0)
 
     def test_config_explicit_w_pk(self):
-        context = self.make_one(dialect_name='sqlite',
-                                opts={'version_table_pk': True})
+        context = self.make_one(
+            dialect_name="sqlite", opts={"version_table_pk": True}
+        )
         eq_(len(context._version.primary_key), 1)
         eq_(context._version.primary_key.name, "alembic_version_pkc")
 
     def test_get_current_revision_doesnt_create_version_table(self):
-        context = self.make_one(connection=self.connection,
-                                opts={'version_table': 'version_table'})
+        context = self.make_one(
+            connection=self.connection, opts={"version_table": "version_table"}
+        )
         eq_(context.get_current_revision(), None)
         insp = Inspector(self.connection)
-        assert ('version_table' not in insp.get_table_names())
+        assert "version_table" not in insp.get_table_names()
 
     def test_get_current_revision(self):
-        context = self.make_one(connection=self.connection,
-                                opts={'version_table': 'version_table'})
+        context = self.make_one(
+            connection=self.connection, opts={"version_table": "version_table"}
+        )
         version_table.create(self.connection)
         eq_(context.get_current_revision(), None)
         self.connection.execute(
-            version_table.insert().values(version_num='revid'))
-        eq_(context.get_current_revision(), 'revid')
+            version_table.insert().values(version_num="revid")
+        )
+        eq_(context.get_current_revision(), "revid")
 
     def test_get_current_revision_error_if_starting_rev_given_online(self):
-        context = self.make_one(connection=self.connection,
-                                opts={'starting_rev': 'boo'})
-        assert_raises(
-            CommandError,
-            context.get_current_revision
+        context = self.make_one(
+            connection=self.connection, opts={"starting_rev": "boo"}
         )
+        assert_raises(CommandError, context.get_current_revision)
 
     def test_get_current_revision_offline(self):
-        context = self.make_one(dialect_name='sqlite',
-                                opts={'starting_rev': 'startrev',
-                                      'as_sql': True})
-        eq_(context.get_current_revision(), 'startrev')
+        context = self.make_one(
+            dialect_name="sqlite",
+            opts={"starting_rev": "startrev", "as_sql": True},
+        )
+        eq_(context.get_current_revision(), "startrev")
 
     def test_get_current_revision_multiple_heads(self):
         version_table.create(self.connection)
-        context = self.make_one(connection=self.connection,
-                                opts={'version_table': 'version_table'})
+        context = self.make_one(
+            connection=self.connection, opts={"version_table": "version_table"}
+        )
         updater = migration.HeadMaintainer(context, ())
-        updater.update_to_step(_up(None, 'a', True))
-        updater.update_to_step(_up(None, 'b', True))
+        updater.update_to_step(_up(None, "a", True))
+        updater.update_to_step(_up(None, "b", True))
         assert_raises_message(
             CommandError,
             "Version table 'version_table' has more than one head present; "
             "please use get_current_heads()",
-            context.get_current_revision
+            context.get_current_revision,
         )
 
     def test_get_heads(self):
         version_table.create(self.connection)
-        context = self.make_one(connection=self.connection,
-                                opts={'version_table': 'version_table'})
+        context = self.make_one(
+            connection=self.connection, opts={"version_table": "version_table"}
+        )
         updater = migration.HeadMaintainer(context, ())
-        updater.update_to_step(_up(None, 'a', True))
-        updater.update_to_step(_up(None, 'b', True))
-        eq_(context.get_current_heads(), ('a', 'b'))
+        updater.update_to_step(_up(None, "a", True))
+        updater.update_to_step(_up(None, "b", True))
+        eq_(context.get_current_heads(), ("a", "b"))
 
     def test_get_heads_offline(self):
         version_table.create(self.connection)
-        context = self.make_one(connection=self.connection,
-                                opts={
-                                    'starting_rev': 'q',
-                                    'version_table': 'version_table',
-                                    'as_sql': True})
-        eq_(context.get_current_heads(), ('q', ))
+        context = self.make_one(
+            connection=self.connection,
+            opts={
+                "starting_rev": "q",
+                "version_table": "version_table",
+                "as_sql": True,
+            },
+        )
+        eq_(context.get_current_heads(), ("q",))
 
     def test_stamp_api_creates_table(self):
         context = self.make_one(connection=self.connection)
         assert (
-            'alembic_version'
-            not in Inspector(self.connection).get_table_names())
+            "alembic_version"
+            not in Inspector(self.connection).get_table_names()
+        )
 
-        script = mock.Mock(_stamp_revs=lambda revision, heads: [
-            _up(None, 'a', True),
-            _up(None, 'b', True)
-        ])
+        script = mock.Mock(
+            _stamp_revs=lambda revision, heads: [
+                _up(None, "a", True),
+                _up(None, "b", True),
+            ]
+        )
 
-        context.stamp(script, 'b')
-        eq_(context.get_current_heads(), ('a', 'b'))
+        context.stamp(script, "b")
+        eq_(context.get_current_heads(), ("a", "b"))
         assert (
-            'alembic_version'
-            in Inspector(self.connection).get_table_names())
+            "alembic_version" in Inspector(self.connection).get_table_names()
+        )
 
 
 class UpdateRevTest(TestBase):
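
One small normalization in the hunk above: the one-element tuple ('q', ) comes back as ("q",). The comma is what makes it a tuple, so it stays; only the space before the closing parenthesis is dropped. For example:

    heads = ("q",)

    assert heads == ("q", )   # same tuple; only the formatting differs
    assert ("q") == "q"       # without the comma there is no tuple at all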
@@ -166,8 +178,8 @@ class UpdateRevTest(TestBase):
     def setUp(self):
         self.connection = self.bind.connect()
         self.context = migration.MigrationContext.configure(
-            connection=self.connection,
-            opts={"version_table": "version_table"})
+            connection=self.connection, opts={"version_table": "version_table"}
+        )
         version_table.create(self.connection)
         self.updater = migration.HeadMaintainer(self.context, ())
 
@@ -180,105 +192,108 @@ class UpdateRevTest(TestBase):
         eq_(self.updater.heads, set(heads))
 
     def test_update_none_to_single(self):
-        self.updater.update_to_step(_up(None, 'a', True))
-        self._assert_heads(('a',))
+        self.updater.update_to_step(_up(None, "a", True))
+        self._assert_heads(("a",))
 
     def test_update_single_to_single(self):
-        self.updater.update_to_step(_up(None, 'a', True))
-        self.updater.update_to_step(_up('a', 'b'))
-        self._assert_heads(('b',))
+        self.updater.update_to_step(_up(None, "a", True))
+        self.updater.update_to_step(_up("a", "b"))
+        self._assert_heads(("b",))
 
     def test_update_single_to_none(self):
-        self.updater.update_to_step(_up(None, 'a', True))
-        self.updater.update_to_step(_down('a', None, True))
+        self.updater.update_to_step(_up(None, "a", True))
+        self.updater.update_to_step(_down("a", None, True))
         self._assert_heads(())
 
     def test_add_branches(self):
-        self.updater.update_to_step(_up(None, 'a', True))
-        self.updater.update_to_step(_up('a', 'b'))
-        self.updater.update_to_step(_up(None, 'c', True))
-        self._assert_heads(('b', 'c'))
-        self.updater.update_to_step(_up('c', 'd'))
-        self.updater.update_to_step(_up('d', 'e1'))
-        self.updater.update_to_step(_up('d', 'e2', True))
-        self._assert_heads(('b', 'e1', 'e2'))
+        self.updater.update_to_step(_up(None, "a", True))
+        self.updater.update_to_step(_up("a", "b"))
+        self.updater.update_to_step(_up(None, "c", True))
+        self._assert_heads(("b", "c"))
+        self.updater.update_to_step(_up("c", "d"))
+        self.updater.update_to_step(_up("d", "e1"))
+        self.updater.update_to_step(_up("d", "e2", True))
+        self._assert_heads(("b", "e1", "e2"))
 
     def test_teardown_branches(self):
-        self.updater.update_to_step(_up(None, 'd1', True))
-        self.updater.update_to_step(_up(None, 'd2', True))
-        self._assert_heads(('d1', 'd2'))
+        self.updater.update_to_step(_up(None, "d1", True))
+        self.updater.update_to_step(_up(None, "d2", True))
+        self._assert_heads(("d1", "d2"))
 
-        self.updater.update_to_step(_down('d1', 'c'))
-        self._assert_heads(('c', 'd2'))
+        self.updater.update_to_step(_down("d1", "c"))
+        self._assert_heads(("c", "d2"))
 
-        self.updater.update_to_step(_down('d2', 'c', True))
+        self.updater.update_to_step(_down("d2", "c", True))
 
-        self._assert_heads(('c',))
-        self.updater.update_to_step(_down('c', 'b'))
-        self._assert_heads(('b',))
+        self._assert_heads(("c",))
+        self.updater.update_to_step(_down("c", "b"))
+        self._assert_heads(("b",))
 
     def test_resolve_merges(self):
-        self.updater.update_to_step(_up(None, 'a', True))
-        self.updater.update_to_step(_up('a', 'b'))
-        self.updater.update_to_step(_up('b', 'c1'))
-        self.updater.update_to_step(_up('b', 'c2', True))
-        self.updater.update_to_step(_up('c1', 'd1'))
-        self.updater.update_to_step(_up('c2', 'd2'))
-        self._assert_heads(('d1', 'd2'))
-        self.updater.update_to_step(_up(('d1', 'd2'), 'e'))
-        self._assert_heads(('e',))
+        self.updater.update_to_step(_up(None, "a", True))
+        self.updater.update_to_step(_up("a", "b"))
+        self.updater.update_to_step(_up("b", "c1"))
+        self.updater.update_to_step(_up("b", "c2", True))
+        self.updater.update_to_step(_up("c1", "d1"))
+        self.updater.update_to_step(_up("c2", "d2"))
+        self._assert_heads(("d1", "d2"))
+        self.updater.update_to_step(_up(("d1", "d2"), "e"))
+        self._assert_heads(("e",))
 
     def test_unresolve_merges(self):
-        self.updater.update_to_step(_up(None, 'e', True))
+        self.updater.update_to_step(_up(None, "e", True))
 
-        self.updater.update_to_step(_down('e', ('d1', 'd2')))
-        self._assert_heads(('d2', 'd1'))
+        self.updater.update_to_step(_down("e", ("d1", "d2")))
+        self._assert_heads(("d2", "d1"))
 
-        self.updater.update_to_step(_down('d2', 'c2'))
-        self._assert_heads(('c2', 'd1'))
+        self.updater.update_to_step(_down("d2", "c2"))
+        self._assert_heads(("c2", "d1"))
 
     def test_update_no_match(self):
-        self.updater.update_to_step(_up(None, 'a', True))
-        self.updater.heads.add('x')
+        self.updater.update_to_step(_up(None, "a", True))
+        self.updater.heads.add("x")
         assert_raises_message(
             CommandError,
             "Online migration expected to match one row when updating "
             "'x' to 'b' in 'version_table'; 0 found",
-            self.updater.update_to_step, _up('x', 'b')
+            self.updater.update_to_step,
+            _up("x", "b"),
         )
 
     def test_update_multi_match(self):
-        self.connection.execute(version_table.insert(), version_num='a')
-        self.connection.execute(version_table.insert(), version_num='a')
+        self.connection.execute(version_table.insert(), version_num="a")
+        self.connection.execute(version_table.insert(), version_num="a")
 
-        self.updater.heads.add('a')
+        self.updater.heads.add("a")
         assert_raises_message(
             CommandError,
             "Online migration expected to match one row when updating "
             "'a' to 'b' in 'version_table'; 2 found",
-            self.updater.update_to_step, _up('a', 'b')
+            self.updater.update_to_step,
+            _up("a", "b"),
         )
 
     def test_delete_no_match(self):
-        self.updater.update_to_step(_up(None, 'a', True))
+        self.updater.update_to_step(_up(None, "a", True))
 
-        self.updater.heads.add('x')
+        self.updater.heads.add("x")
         assert_raises_message(
             CommandError,
             "Online migration expected to match one row when "
             "deleting 'x' in 'version_table'; 0 found",
-            self.updater.update_to_step, _down('x', None, True)
+            self.updater.update_to_step,
+            _down("x", None, True),
         )
 
     def test_delete_multi_match(self):
-        self.connection.execute(version_table.insert(), version_num='a')
-        self.connection.execute(version_table.insert(), version_num='a')
+        self.connection.execute(version_table.insert(), version_num="a")
+        self.connection.execute(version_table.insert(), version_num="a")
 
-        self.updater.heads.add('a')
+        self.updater.heads.add("a")
         assert_raises_message(
             CommandError,
             "Online migration expected to match one row when "
             "deleting 'a' in 'version_table'; 2 found",
-            self.updater.update_to_step, _down('a', None, True)
+            self.updater.update_to_step,
+            _down("a", None, True),
         )
-
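
The lone "-" closing this file's diff removes a trailing blank line: black ends a file with exactly one newline, which also keeps pycodestyle's W391 ("blank line at end of file") quiet under flake8. A sketch of the same cleanup:

    def strip_trailing_blank_lines(text):
        """Return text ending in exactly one newline."""
        return text.rstrip("\n") + "\n"


    assert strip_trailing_blank_lines("x = 1\n\n\n") == "x = 1\n"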
index f69a9bd24a3ba1a6b2d8de7c9aac2764aadac0fd..c32c0c89e32b76865c48769bf78b703570a07e11 100644 (file)
@@ -7,20 +7,15 @@ from alembic.migration import MigrationStep, HeadMaintainer
 
 
 class MigrationTest(TestBase):
-
     def up_(self, rev):
-        return MigrationStep.upgrade_from_script(
-            self.env.revision_map, rev)
+        return MigrationStep.upgrade_from_script(self.env.revision_map, rev)
 
     def down_(self, rev):
-        return MigrationStep.downgrade_from_script(
-            self.env.revision_map, rev)
+        return MigrationStep.downgrade_from_script(self.env.revision_map, rev)
 
     def _assert_downgrade(self, destination, source, expected, expected_heads):
         revs = self.env._downgrade_revs(destination, source)
-        eq_(
-            revs, expected
-        )
+        eq_(revs, expected)
         heads = set(util.to_tuple(source, default=()))
         head = HeadMaintainer(mock.Mock(), heads)
         for rev in revs:
@@ -29,9 +24,7 @@ class MigrationTest(TestBase):
 
     def _assert_upgrade(self, destination, source, expected, expected_heads):
         revs = self.env._upgrade_revs(destination, source)
-        eq_(
-            revs, expected
-        )
+        eq_(revs, expected)
         heads = set(util.to_tuple(source, default=()))
         head = HeadMaintainer(mock.Mock(), heads)
         for rev in revs:
@@ -40,15 +33,14 @@ class MigrationTest(TestBase):
 
 
 class RevisionPathTest(MigrationTest):
-
     @classmethod
     def setup_class(cls):
         cls.env = env = staging_env()
-        cls.a = env.generate_revision(util.rev_id(), '->a')
-        cls.b = env.generate_revision(util.rev_id(), 'a->b')
-        cls.c = env.generate_revision(util.rev_id(), 'b->c')
-        cls.d = env.generate_revision(util.rev_id(), 'c->d')
-        cls.e = env.generate_revision(util.rev_id(), 'd->e')
+        cls.a = env.generate_revision(util.rev_id(), "->a")
+        cls.b = env.generate_revision(util.rev_id(), "a->b")
+        cls.c = env.generate_revision(util.rev_id(), "b->c")
+        cls.d = env.generate_revision(util.rev_id(), "c->d")
+        cls.e = env.generate_revision(util.rev_id(), "d->e")
 
     @classmethod
     def teardown_class(cls):
@@ -56,58 +48,50 @@ class RevisionPathTest(MigrationTest):
 
     def test_upgrade_path(self):
         self._assert_upgrade(
-            self.e.revision, self.c.revision,
-            [
-                self.up_(self.d),
-                self.up_(self.e)
-            ],
-            set([self.e.revision])
+            self.e.revision,
+            self.c.revision,
+            [self.up_(self.d), self.up_(self.e)],
+            set([self.e.revision]),
         )
 
         self._assert_upgrade(
-            self.c.revision, None,
-            [
-                self.up_(self.a),
-                self.up_(self.b),
-                self.up_(self.c),
-            ],
-            set([self.c.revision])
+            self.c.revision,
+            None,
+            [self.up_(self.a), self.up_(self.b), self.up_(self.c)],
+            set([self.c.revision]),
         )
 
     def test_relative_upgrade_path(self):
         self._assert_upgrade(
-            "+2", self.a.revision,
-            [
-                self.up_(self.b),
-                self.up_(self.c),
-            ],
-            set([self.c.revision])
+            "+2",
+            self.a.revision,
+            [self.up_(self.b), self.up_(self.c)],
+            set([self.c.revision]),
         )
 
         self._assert_upgrade(
-            "+1", self.a.revision,
-            [
-                self.up_(self.b)
-            ],
-            set([self.b.revision])
+            "+1", self.a.revision, [self.up_(self.b)], set([self.b.revision])
         )
 
         self._assert_upgrade(
-            "+3", self.b.revision,
+            "+3",
+            self.b.revision,
             [self.up_(self.c), self.up_(self.d), self.up_(self.e)],
-            set([self.e.revision])
+            set([self.e.revision]),
         )
 
         self._assert_upgrade(
-            "%s+2" % self.b.revision, self.a.revision,
+            "%s+2" % self.b.revision,
+            self.a.revision,
             [self.up_(self.b), self.up_(self.c), self.up_(self.d)],
-            set([self.d.revision])
+            set([self.d.revision]),
         )
 
         self._assert_upgrade(
-            "%s-2" % self.d.revision, self.a.revision,
+            "%s-2" % self.d.revision,
+            self.a.revision,
             [self.up_(self.b)],
-            set([self.b.revision])
+            set([self.b.revision]),
         )
 
     def test_invalid_relative_upgrade_path(self):
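
Throughout these hunks set([self.c.revision]) keeps its set([...]) spelling: black changes layout only and never rewrites a construct into an equivalent one such as a {...} set literal; that kind of change is left to a human or to lint tooling. For instance:

    # Equivalent values, but a formatter must not swap one for the other:
    revisions = set(["abc123"])
    assert revisions == {"abc123"}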
@@ -115,53 +99,60 @@ class RevisionPathTest(MigrationTest):
         assert_raises_message(
             util.CommandError,
             "Relative revision -2 didn't produce 2 migrations",
-            self.env._upgrade_revs, "-2", self.b.revision
+            self.env._upgrade_revs,
+            "-2",
+            self.b.revision,
         )
 
         assert_raises_message(
             util.CommandError,
             r"Relative revision \+5 didn't produce 5 migrations",
-            self.env._upgrade_revs, "+5", self.b.revision
+            self.env._upgrade_revs,
+            "+5",
+            self.b.revision,
         )
 
     def test_downgrade_path(self):
 
         self._assert_downgrade(
-            self.c.revision, self.e.revision,
+            self.c.revision,
+            self.e.revision,
             [self.down_(self.e), self.down_(self.d)],
-            set([self.c.revision])
+            set([self.c.revision]),
         )
 
         self._assert_downgrade(
-            None, self.c.revision,
+            None,
+            self.c.revision,
             [self.down_(self.c), self.down_(self.b), self.down_(self.a)],
-            set()
+            set(),
         )
 
     def test_relative_downgrade_path(self):
 
         self._assert_downgrade(
-            "-1", self.c.revision,
-            [self.down_(self.c)],
-            set([self.b.revision])
+            "-1", self.c.revision, [self.down_(self.c)], set([self.b.revision])
         )
 
         self._assert_downgrade(
-            "-3", self.e.revision,
+            "-3",
+            self.e.revision,
             [self.down_(self.e), self.down_(self.d), self.down_(self.c)],
-            set([self.b.revision])
+            set([self.b.revision]),
         )
 
         self._assert_downgrade(
-            "%s+2" % self.a.revision, self.d.revision,
+            "%s+2" % self.a.revision,
+            self.d.revision,
             [self.down_(self.d)],
-            set([self.c.revision])
+            set([self.c.revision]),
         )
 
         self._assert_downgrade(
-            "%s-2" % self.c.revision, self.d.revision,
+            "%s-2" % self.c.revision,
+            self.d.revision,
             [self.down_(self.d), self.down_(self.c), self.down_(self.b)],
-            set([self.a.revision])
+            set([self.a.revision]),
         )
 
     def test_invalid_relative_downgrade_path(self):
@@ -169,13 +160,17 @@ class RevisionPathTest(MigrationTest):
         assert_raises_message(
             util.CommandError,
             "Relative revision -5 didn't produce 5 migrations",
-            self.env._downgrade_revs, "-5", self.b.revision
+            self.env._downgrade_revs,
+            "-5",
+            self.b.revision,
         )
 
         assert_raises_message(
             util.CommandError,
             r"Relative revision \+2 didn't produce 2 migrations",
-            self.env._downgrade_revs, "+2", self.b.revision
+            self.env._downgrade_revs,
+            "+2",
+            self.b.revision,
         )
 
     def test_invalid_move_rev_to_none(self):
@@ -184,7 +179,9 @@ class RevisionPathTest(MigrationTest):
             util.CommandError,
             r"Destination %s is not a valid downgrade "
             r"target from current head\(s\)" % self.b.revision[0:3],
-            self.env._downgrade_revs, self.b.revision[0:3], None
+            self.env._downgrade_revs,
+            self.b.revision[0:3],
+            None,
         )
 
     def test_invalid_move_higher_to_lower(self):
@@ -193,7 +190,9 @@ class RevisionPathTest(MigrationTest):
             util.CommandError,
             r"Destination %s is not a valid downgrade "
             r"target from current head\(s\)" % self.c.revision[0:4],
-            self.env._downgrade_revs, self.c.revision[0:4], self.b.revision
+            self.env._downgrade_revs,
+            self.c.revision[0:4],
+            self.b.revision,
         )
 
     def test_stamp_to_base(self):
@@ -204,26 +203,27 @@ class RevisionPathTest(MigrationTest):
 
 
 class BranchedPathTest(MigrationTest):
-
     @classmethod
     def setup_class(cls):
         cls.env = env = staging_env()
-        cls.a = env.generate_revision(util.rev_id(), '->a')
-        cls.b = env.generate_revision(util.rev_id(), 'a->b')
+        cls.a = env.generate_revision(util.rev_id(), "->a")
+        cls.b = env.generate_revision(util.rev_id(), "a->b")
 
         cls.c1 = env.generate_revision(
-            util.rev_id(), 'b->c1',
-            branch_labels='c1branch',
-            refresh=True)
-        cls.d1 = env.generate_revision(util.rev_id(), 'c1->d1')
+            util.rev_id(), "b->c1", branch_labels="c1branch", refresh=True
+        )
+        cls.d1 = env.generate_revision(util.rev_id(), "c1->d1")
 
         cls.c2 = env.generate_revision(
-            util.rev_id(), 'b->c2',
-            branch_labels='c2branch',
-            head=cls.b.revision, splice=True)
+            util.rev_id(),
+            "b->c2",
+            branch_labels="c2branch",
+            head=cls.b.revision,
+            splice=True,
+        )
         cls.d2 = env.generate_revision(
-            util.rev_id(), 'c2->d2',
-            head=cls.c2.revision)
+            util.rev_id(), "c2->d2", head=cls.c2.revision
+        )
 
     @classmethod
     def teardown_class(cls):
@@ -231,73 +231,87 @@ class BranchedPathTest(MigrationTest):
 
     def test_stamp_down_across_multiple_branch_to_branchpoint(self):
         heads = [self.d1.revision, self.c2.revision]
-        revs = self.env._stamp_revs(
-            self.b.revision, heads)
+        revs = self.env._stamp_revs(self.b.revision, heads)
         eq_(len(revs), 1)
         eq_(
             revs[0].merge_branch_idents(heads),
             # DELETE d1 revision, UPDATE c2 to b
-            ([self.d1.revision], self.c2.revision, self.b.revision)
+            ([self.d1.revision], self.c2.revision, self.b.revision),
         )
 
     def test_stamp_to_labeled_base_multiple_heads(self):
         revs = self.env._stamp_revs(
-            "c1branch@base", [self.d1.revision, self.c2.revision])
+            "c1branch@base", [self.d1.revision, self.c2.revision]
+        )
         eq_(len(revs), 1)
         assert revs[0].should_delete_branch
         eq_(revs[0].delete_version_num, self.d1.revision)
 
     def test_stamp_to_labeled_head_multiple_heads(self):
         heads = [self.d1.revision, self.c2.revision]
-        revs = self.env._stamp_revs(
-            "c2branch@head", heads)
+        revs = self.env._stamp_revs("c2branch@head", heads)
         eq_(len(revs), 1)
         eq_(
             revs[0].merge_branch_idents(heads),
             # the c1branch remains unchanged
-            ([], self.c2.revision, self.d2.revision)
+            ([], self.c2.revision, self.d2.revision),
         )
 
     def test_upgrade_single_branch(self):
 
         self._assert_upgrade(
-            self.d1.revision, self.b.revision,
+            self.d1.revision,
+            self.b.revision,
             [self.up_(self.c1), self.up_(self.d1)],
-            set([self.d1.revision])
+            set([self.d1.revision]),
         )
 
     def test_upgrade_multiple_branch(self):
         # move from a single head to multiple heads
 
         self._assert_upgrade(
-            (self.d1.revision, self.d2.revision), self.a.revision,
-            [self.up_(self.b), self.up_(self.c2), self.up_(self.d2),
-             self.up_(self.c1), self.up_(self.d1)],
-            set([self.d1.revision, self.d2.revision])
+            (self.d1.revision, self.d2.revision),
+            self.a.revision,
+            [
+                self.up_(self.b),
+                self.up_(self.c2),
+                self.up_(self.d2),
+                self.up_(self.c1),
+                self.up_(self.d1),
+            ],
+            set([self.d1.revision, self.d2.revision]),
         )
 
     def test_downgrade_multiple_branch(self):
         self._assert_downgrade(
-            self.a.revision, (self.d1.revision, self.d2.revision),
-            [self.down_(self.d1), self.down_(self.c1), self.down_(self.d2),
-             self.down_(self.c2), self.down_(self.b)],
-            set([self.a.revision])
+            self.a.revision,
+            (self.d1.revision, self.d2.revision),
+            [
+                self.down_(self.d1),
+                self.down_(self.c1),
+                self.down_(self.d2),
+                self.down_(self.c2),
+                self.down_(self.b),
+            ],
+            set([self.a.revision]),
         )
 
     def test_relative_upgrade(self):
 
         self._assert_upgrade(
-            "c2branch@head-1", self.b.revision,
+            "c2branch@head-1",
+            self.b.revision,
             [self.up_(self.c2)],
-            set([self.c2.revision])
+            set([self.c2.revision]),
         )
 
     def test_relative_downgrade(self):
 
         self._assert_downgrade(
-            "c2branch@base+2", [self.d2.revision, self.d1.revision],
+            "c2branch@base+2",
+            [self.d2.revision, self.d1.revision],
             [self.down_(self.d2), self.down_(self.c2), self.down_(self.d1)],
-            set([self.c1.revision])
+            set([self.c1.revision]),
         )
 
 
@@ -311,43 +325,54 @@ class BranchFromMergepointTest(MigrationTest):
     @classmethod
     def setup_class(cls):
         cls.env = env = staging_env()
-        cls.a1 = env.generate_revision(util.rev_id(), '->a1')
-        cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
-        cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1')
+        cls.a1 = env.generate_revision(util.rev_id(), "->a1")
+        cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
+        cls.c1 = env.generate_revision(util.rev_id(), "b1->c1")
 
         cls.a2 = env.generate_revision(
-            util.rev_id(), '->a2', head=(),
-            refresh=True)
+            util.rev_id(), "->a2", head=(), refresh=True
+        )
         cls.b2 = env.generate_revision(
-            util.rev_id(), 'a2->b2', head=cls.a2.revision)
+            util.rev_id(), "a2->b2", head=cls.a2.revision
+        )
         cls.c2 = env.generate_revision(
-            util.rev_id(), 'b2->c2', head=cls.b2.revision)
+            util.rev_id(), "b2->c2", head=cls.b2.revision
+        )
 
         # mergepoint between c1, c2
         # d1 dependent on c2
         cls.d1 = env.generate_revision(
-            util.rev_id(), 'd1', head=(cls.c1.revision, cls.c2.revision),
-            refresh=True)
+            util.rev_id(),
+            "d1",
+            head=(cls.c1.revision, cls.c2.revision),
+            refresh=True,
+        )
 
         # but then c2 keeps going into d2
         cls.d2 = env.generate_revision(
-            util.rev_id(), 'd2', head=cls.c2.revision,
-            refresh=True, splice=True)
+            util.rev_id(),
+            "d2",
+            head=cls.c2.revision,
+            refresh=True,
+            splice=True,
+        )
 
     def test_mergepoint_to_only_one_side_upgrade(self):
 
         self._assert_upgrade(
-            self.d1.revision, (self.d2.revision, self.b1.revision),
+            self.d1.revision,
+            (self.d2.revision, self.b1.revision),
             [self.up_(self.c1), self.up_(self.d1)],
-            set([self.d2.revision, self.d1.revision])
+            set([self.d2.revision, self.d1.revision]),
         )
 
     def test_mergepoint_to_only_one_side_downgrade(self):
 
         self._assert_downgrade(
-            self.b1.revision, (self.d2.revision, self.d1.revision),
+            self.b1.revision,
+            (self.d2.revision, self.d1.revision),
             [self.down_(self.d1), self.down_(self.c1)],
-            set([self.d2.revision, self.b1.revision])
+            set([self.d2.revision, self.b1.revision]),
         )
 
 
@@ -361,42 +386,56 @@ class BranchFrom3WayMergepointTest(MigrationTest):
     @classmethod
     def setup_class(cls):
         cls.env = env = staging_env()
-        cls.a1 = env.generate_revision(util.rev_id(), '->a1')
-        cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
-        cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1')
+        cls.a1 = env.generate_revision(util.rev_id(), "->a1")
+        cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
+        cls.c1 = env.generate_revision(util.rev_id(), "b1->c1")
 
         cls.a2 = env.generate_revision(
-            util.rev_id(), '->a2', head=(),
-            refresh=True)
+            util.rev_id(), "->a2", head=(), refresh=True
+        )
         cls.b2 = env.generate_revision(
-            util.rev_id(), 'a2->b2', head=cls.a2.revision)
+            util.rev_id(), "a2->b2", head=cls.a2.revision
+        )
         cls.c2 = env.generate_revision(
-            util.rev_id(), 'b2->c2', head=cls.b2.revision)
+            util.rev_id(), "b2->c2", head=cls.b2.revision
+        )
 
         cls.a3 = env.generate_revision(
-            util.rev_id(), '->a3', head=(),
-            refresh=True)
+            util.rev_id(), "->a3", head=(), refresh=True
+        )
         cls.b3 = env.generate_revision(
-            util.rev_id(), 'a3->b3', head=cls.a3.revision)
+            util.rev_id(), "a3->b3", head=cls.a3.revision
+        )
         cls.c3 = env.generate_revision(
-            util.rev_id(), 'b3->c3', head=cls.b3.revision)
+            util.rev_id(), "b3->c3", head=cls.b3.revision
+        )
 
         # mergepoint between c1, c2, c3
         # d1 dependent on c2, c3
         cls.d1 = env.generate_revision(
-            util.rev_id(), 'd1', head=(
-                cls.c1.revision, cls.c2.revision, cls.c3.revision),
-            refresh=True)
+            util.rev_id(),
+            "d1",
+            head=(cls.c1.revision, cls.c2.revision, cls.c3.revision),
+            refresh=True,
+        )
 
         # but then c2 keeps going into d2
         cls.d2 = env.generate_revision(
-            util.rev_id(), 'd2', head=cls.c2.revision,
-            refresh=True, splice=True)
+            util.rev_id(),
+            "d2",
+            head=cls.c2.revision,
+            refresh=True,
+            splice=True,
+        )
 
         # c3 keeps going into d3
         cls.d3 = env.generate_revision(
-            util.rev_id(), 'd3', head=cls.c3.revision,
-            refresh=True, splice=True)
+            util.rev_id(),
+            "d3",
+            head=cls.c3.revision,
+            refresh=True,
+            splice=True,
+        )
 
     def test_mergepoint_to_only_one_side_upgrade(self):
 
@@ -404,7 +443,7 @@ class BranchFrom3WayMergepointTest(MigrationTest):
             self.d1.revision,
             (self.d3.revision, self.d2.revision, self.b1.revision),
             [self.up_(self.c1), self.up_(self.d1)],
-            set([self.d3.revision, self.d2.revision, self.d1.revision])
+            set([self.d3.revision, self.d2.revision, self.d1.revision]),
         )
 
     def test_mergepoint_to_only_one_side_downgrade(self):
@@ -412,7 +451,7 @@ class BranchFrom3WayMergepointTest(MigrationTest):
             self.b1.revision,
             (self.d3.revision, self.d2.revision, self.d1.revision),
             [self.down_(self.d1), self.down_(self.c1)],
-            set([self.d3.revision, self.d2.revision, self.b1.revision])
+            set([self.d3.revision, self.d2.revision, self.b1.revision]),
         )
 
     def test_mergepoint_to_two_sides_upgrade(self):
@@ -422,14 +461,15 @@ class BranchFrom3WayMergepointTest(MigrationTest):
             (self.d3.revision, self.b2.revision, self.b1.revision),
             [self.up_(self.c2), self.up_(self.c1), self.up_(self.d1)],
             # this will merge b2 and b1 into d1
-            set([self.d3.revision, self.d1.revision])
+            set([self.d3.revision, self.d1.revision]),
         )
 
         # but then!  b2 will break out again if we keep going with it
         self._assert_upgrade(
-            self.d2.revision, (self.d3.revision, self.d1.revision),
+            self.d2.revision,
+            (self.d3.revision, self.d1.revision),
             [self.up_(self.d2)],
-            set([self.d3.revision, self.d2.revision, self.d1.revision])
+            set([self.d3.revision, self.d2.revision, self.d1.revision]),
         )
 
 
@@ -438,6 +478,7 @@ class TwinMergeTest(MigrationTest):
     originating branches.
 
     """
+
     @classmethod
     def setup_class(cls):
         """
@@ -463,44 +504,43 @@ class TwinMergeTest(MigrationTest):
         """
         cls.env = env = staging_env()
 
-        cls.a = env.generate_revision(
-            'a', 'a'
+        cls.a = env.generate_revision("a", "a")
+        cls.b1 = env.generate_revision("b1", "b1", head=cls.a.revision)
+        cls.b2 = env.generate_revision(
+            "b2", "b2", splice=True, head=cls.a.revision
+        )
+        cls.b3 = env.generate_revision(
+            "b3", "b3", splice=True, head=cls.a.revision
         )
-        cls.b1 = env.generate_revision('b1', 'b1',
-                                       head=cls.a.revision)
-        cls.b2 = env.generate_revision('b2', 'b2',
-                                       splice=True,
-                                       head=cls.a.revision)
-        cls.b3 = env.generate_revision('b3', 'b3',
-                                       splice=True,
-                                       head=cls.a.revision)
 
         cls.c1 = env.generate_revision(
-            'c1', 'c1',
-            head=(cls.b1.revision, cls.b2.revision, cls.b3.revision))
+            "c1",
+            "c1",
+            head=(cls.b1.revision, cls.b2.revision, cls.b3.revision),
+        )
 
         cls.c2 = env.generate_revision(
-            'c2', 'c2',
+            "c2",
+            "c2",
             splice=True,
-            head=(cls.b1.revision, cls.b2.revision, cls.b3.revision))
+            head=(cls.b1.revision, cls.b2.revision, cls.b3.revision),
+        )
 
-        cls.d1 = env.generate_revision(
-            'd1', 'd1', head=cls.c1.revision)
+        cls.d1 = env.generate_revision("d1", "d1", head=cls.c1.revision)
 
-        cls.d2 = env.generate_revision(
-            'd2', 'd2', head=cls.c2.revision)
+        cls.d2 = env.generate_revision("d2", "d2", head=cls.c2.revision)
 
     def test_upgrade(self):
         head = HeadMaintainer(mock.Mock(), [self.a.revision])
 
         steps = [
-            (self.up_(self.b3), ('b3',)),
-            (self.up_(self.b1), ('b1', 'b3',)),
-            (self.up_(self.b2), ('b1', 'b2', 'b3',)),
-            (self.up_(self.c2), ('c2',)),
-            (self.up_(self.d2), ('d2',)),
-            (self.up_(self.c1), ('c1', 'd2')),
-            (self.up_(self.d1), ('d1', 'd2')),
+            (self.up_(self.b3), ("b3",)),
+            (self.up_(self.b1), ("b1", "b3")),
+            (self.up_(self.b2), ("b1", "b2", "b3")),
+            (self.up_(self.c2), ("c2",)),
+            (self.up_(self.d2), ("d2",)),
+            (self.up_(self.c1), ("c1", "d2")),
+            (self.up_(self.d1), ("d1", "d2")),
         ]
         for step, assert_ in steps:
             head.update_to_step(step)
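
In the steps table above, two-element tuples written as ('b1', 'b3',) come back as ("b1", "b3"): the black release used for this commit joins short collections onto one line and apparently drops the trailing comma that is no longer needed, while genuine one-element tuples such as ("b3",) necessarily keep theirs. A sketch:

    # Trailing comma dropped where redundant, kept where it defines a
    # one-element tuple.
    steps = [
        ("b3",),
        ("b1", "b3"),
        ("b1", "b2", "b3"),
    ]
    assert steps[0] == ("b3",) and len(steps[0]) == 1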
@@ -511,6 +551,7 @@ class NotQuiteTwinMergeTest(MigrationTest):
     """Test a variant of #297.
 
     """
+
     @classmethod
     def setup_class(cls):
         """
@@ -527,32 +568,26 @@ class NotQuiteTwinMergeTest(MigrationTest):
         """
         cls.env = env = staging_env()
 
-        cls.a = env.generate_revision(
-            'a', 'a'
+        cls.a = env.generate_revision("a", "a")
+        cls.b1 = env.generate_revision("b1", "b1", head=cls.a.revision)
+        cls.b2 = env.generate_revision(
+            "b2", "b2", splice=True, head=cls.a.revision
+        )
+        cls.b3 = env.generate_revision(
+            "b3", "b3", splice=True, head=cls.a.revision
         )
-        cls.b1 = env.generate_revision('b1', 'b1',
-                                       head=cls.a.revision)
-        cls.b2 = env.generate_revision('b2', 'b2',
-                                       splice=True,
-                                       head=cls.a.revision)
-        cls.b3 = env.generate_revision('b3', 'b3',
-                                       splice=True,
-                                       head=cls.a.revision)
 
         cls.c1 = env.generate_revision(
-            'c1', 'c1',
-            head=(cls.b1.revision, cls.b2.revision))
+            "c1", "c1", head=(cls.b1.revision, cls.b2.revision)
+        )
 
         cls.c2 = env.generate_revision(
-            'c2', 'c2',
-            splice=True,
-            head=(cls.b2.revision, cls.b3.revision))
+            "c2", "c2", splice=True, head=(cls.b2.revision, cls.b3.revision)
+        )
 
-        cls.d1 = env.generate_revision(
-            'd1', 'd1', head=cls.c1.revision)
+        cls.d1 = env.generate_revision("d1", "d1", head=cls.c1.revision)
 
-        cls.d2 = env.generate_revision(
-            'd2', 'd2', head=cls.c2.revision)
+        cls.d2 = env.generate_revision("d2", "d2", head=cls.c2.revision)
 
     def test_upgrade(self):
         head = HeadMaintainer(mock.Mock(), [self.a.revision])
@@ -568,14 +603,13 @@ class NotQuiteTwinMergeTest(MigrationTest):
         """
 
         steps = [
-            (self.up_(self.b2), ('b2',)),
-            (self.up_(self.b3), ('b2', 'b3',)),
-            (self.up_(self.c2), ('c2',)),
-            (self.up_(self.d2), ('d2',)),
-
-            (self.up_(self.b1), ('b1', 'd2',)),
-            (self.up_(self.c1), ('c1', 'd2')),
-            (self.up_(self.d1), ('d1', 'd2')),
+            (self.up_(self.b2), ("b2",)),
+            (self.up_(self.b3), ("b2", "b3")),
+            (self.up_(self.c2), ("c2",)),
+            (self.up_(self.d2), ("d2",)),
+            (self.up_(self.b1), ("b1", "d2")),
+            (self.up_(self.c1), ("c1", "d2")),
+            (self.up_(self.d1), ("d1", "d2")),
         ]
         for step, assert_ in steps:
             head.update_to_step(step)
@@ -583,32 +617,35 @@ class NotQuiteTwinMergeTest(MigrationTest):
 
 
 class DependsOnBranchTestOne(MigrationTest):
-
     @classmethod
     def setup_class(cls):
         cls.env = env = staging_env()
         cls.a1 = env.generate_revision(
-            util.rev_id(), '->a1',
-            branch_labels=['lib1'])
-        cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
-        cls.c1 = env.generate_revision(util.rev_id(), 'b1->c1')
+            util.rev_id(), "->a1", branch_labels=["lib1"]
+        )
+        cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
+        cls.c1 = env.generate_revision(util.rev_id(), "b1->c1")
 
-        cls.a2 = env.generate_revision(util.rev_id(), '->a2', head=())
+        cls.a2 = env.generate_revision(util.rev_id(), "->a2", head=())
         cls.b2 = env.generate_revision(
-            util.rev_id(), 'a2->b2', head=cls.a2.revision)
+            util.rev_id(), "a2->b2", head=cls.a2.revision
+        )
         cls.c2 = env.generate_revision(
-            util.rev_id(), 'b2->c2', head=cls.b2.revision,
-            depends_on=cls.c1.revision)
+            util.rev_id(),
+            "b2->c2",
+            head=cls.b2.revision,
+            depends_on=cls.c1.revision,
+        )
 
         cls.d1 = env.generate_revision(
-            util.rev_id(), 'c1->d1',
-            head=cls.c1.revision)
+            util.rev_id(), "c1->d1", head=cls.c1.revision
+        )
         cls.e1 = env.generate_revision(
-            util.rev_id(), 'd1->e1',
-            head=cls.d1.revision)
+            util.rev_id(), "d1->e1", head=cls.d1.revision
+        )
         cls.f1 = env.generate_revision(
-            util.rev_id(), 'e1->f1',
-            head=cls.e1.revision)
+            util.rev_id(), "e1->f1", head=cls.e1.revision
+        )
 
     def test_downgrade_to_dependency(self):
         heads = [self.c2.revision, self.d1.revision]
@@ -625,7 +662,6 @@ class DependsOnBranchTestOne(MigrationTest):
 
 
 class DependsOnBranchTestTwo(MigrationTest):
-
     @classmethod
     def setup_class(cls):
         """
@@ -656,32 +692,36 @@ class DependsOnBranchTestTwo(MigrationTest):
 
         """
         cls.env = env = staging_env()
-        cls.a1 = env.generate_revision("a1", '->a1', head='base')
-        cls.a2 = env.generate_revision("a2", '->a2', head='base')
-        cls.a3 = env.generate_revision("a3", '->a3', head='base')
-        cls.amerge = env.generate_revision("amerge", 'amerge', head=[
-            cls.a1.revision, cls.a2.revision, cls.a3.revision
-        ])
-
-        cls.b1 = env.generate_revision("b1", '->b1', head='base')
-        cls.b2 = env.generate_revision("b2", '->b2', head='base')
-        cls.bmerge = env.generate_revision("bmerge", 'bmerge', head=[
-            cls.b1.revision, cls.b2.revision
-        ])
-
-        cls.c1 = env.generate_revision("c1", '->c1', head='base')
-        cls.c2 = env.generate_revision("c2", '->c2', head='base')
-        cls.c3 = env.generate_revision("c3", '->c3', head='base')
-        cls.cmerge = env.generate_revision("cmerge", 'cmerge', head=[
-            cls.c1.revision, cls.c2.revision, cls.c3.revision
-        ])
+        cls.a1 = env.generate_revision("a1", "->a1", head="base")
+        cls.a2 = env.generate_revision("a2", "->a2", head="base")
+        cls.a3 = env.generate_revision("a3", "->a3", head="base")
+        cls.amerge = env.generate_revision(
+            "amerge",
+            "amerge",
+            head=[cls.a1.revision, cls.a2.revision, cls.a3.revision],
+        )
+
+        cls.b1 = env.generate_revision("b1", "->b1", head="base")
+        cls.b2 = env.generate_revision("b2", "->b2", head="base")
+        cls.bmerge = env.generate_revision(
+            "bmerge", "bmerge", head=[cls.b1.revision, cls.b2.revision]
+        )
+
+        cls.c1 = env.generate_revision("c1", "->c1", head="base")
+        cls.c2 = env.generate_revision("c2", "->c2", head="base")
+        cls.c3 = env.generate_revision("c3", "->c3", head="base")
+        cls.cmerge = env.generate_revision(
+            "cmerge",
+            "cmerge",
+            head=[cls.c1.revision, cls.c2.revision, cls.c3.revision],
+        )
 
         cls.d1 = env.generate_revision(
-            "d1", 'o',
+            "d1",
+            "o",
             head="base",
-            depends_on=[
-                cls.a3.revision, cls.b2.revision, cls.c1.revision
-            ])
+            depends_on=[cls.a3.revision, cls.b2.revision, cls.c1.revision],
+        )
 
     def test_kaboom(self):
         # here's the upgrade path:
@@ -690,55 +730,77 @@ class DependsOnBranchTestTwo(MigrationTest):
 
         heads = [
             self.amerge.revision,
-            self.bmerge.revision, self.cmerge.revision,
-            self.d1.revision
+            self.bmerge.revision,
+            self.cmerge.revision,
+            self.d1.revision,
         ]
 
         self._assert_downgrade(
-            self.b2.revision, heads,
+            self.b2.revision,
+            heads,
             [self.down_(self.bmerge)],
-            set([
-                self.amerge.revision,
-                self.b1.revision, self.cmerge.revision, self.d1.revision])
+            set(
+                [
+                    self.amerge.revision,
+                    self.b1.revision,
+                    self.cmerge.revision,
+                    self.d1.revision,
+                ]
+            ),
         )
 
         # start with those heads..
         heads = [
-            self.amerge.revision, self.d1.revision,
-            self.b1.revision, self.cmerge.revision]
+            self.amerge.revision,
+            self.d1.revision,
+            self.b1.revision,
+            self.cmerge.revision,
+        ]
 
         # downgrade d1...
         self._assert_downgrade(
-            "d1@base", heads,
+            "d1@base",
+            heads,
             [self.down_(self.d1)],
-
             # b2 has to be INSERTed, because it was implied by d1
-            set([
-                self.amerge.revision, self.b1.revision,
-                self.b2.revision, self.cmerge.revision])
+            set(
+                [
+                    self.amerge.revision,
+                    self.b1.revision,
+                    self.b2.revision,
+                    self.cmerge.revision,
+                ]
+            ),
         )
 
         # start with those heads ...
         heads = [
-            self.amerge.revision, self.b1.revision,
-            self.b2.revision, self.cmerge.revision
+            self.amerge.revision,
+            self.b1.revision,
+            self.b2.revision,
+            self.cmerge.revision,
         ]
 
         self._assert_downgrade(
-            "base", heads,
+            "base",
+            heads,
             [
-                self.down_(self.amerge), self.down_(self.a1),
-                self.down_(self.a2), self.down_(self.a3),
-                self.down_(self.b1), self.down_(self.b2),
-                self.down_(self.cmerge), self.down_(self.c1),
-                self.down_(self.c2), self.down_(self.c3)
+                self.down_(self.amerge),
+                self.down_(self.a1),
+                self.down_(self.a2),
+                self.down_(self.a3),
+                self.down_(self.b1),
+                self.down_(self.b2),
+                self.down_(self.cmerge),
+                self.down_(self.c1),
+                self.down_(self.c2),
+                self.down_(self.c3),
             ],
-            set([])
+            set([]),
         )
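
The subtle step above is the middle downgrade: while d1 is a head, the
revisions it depends_on are implied and need not appear in the version
table, so removing d1 has to INSERT b2 back.  A minimal sketch of that
shape, using the same staging_env()/generate_revision() helpers seen in
these tests (the revision ids here are hypothetical):

    env = staging_env()
    b = env.generate_revision("b", "->b", head="base")
    d = env.generate_revision("d", "->d", head="base", depends_on=b.revision)
    # with "d" applied, "b" is implied and may be absent from
    # alembic_version; downgrading "d" must make "b" explicit again
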
 
 
 class DependsOnBranchTestThree(MigrationTest):
-
     @classmethod
     def setup_class(cls):
         """
@@ -755,14 +817,18 @@ class DependsOnBranchTestThree(MigrationTest):
 
         """
         cls.env = env = staging_env()
-        cls.a1 = env.generate_revision("a1", '->a1', head='base')
-        cls.a2 = env.generate_revision("a2", '->a2')
+        cls.a1 = env.generate_revision("a1", "->a1", head="base")
+        cls.a2 = env.generate_revision("a2", "->a2")
 
-        cls.b1 = env.generate_revision("b1", '->b1', head='base')
-        cls.b2 = env.generate_revision("b2", '->b2', depends_on='a2', head='b1')
-        cls.b3 = env.generate_revision("b3", '->b3', head='b2')
+        cls.b1 = env.generate_revision("b1", "->b1", head="base")
+        cls.b2 = env.generate_revision(
+            "b2", "->b2", depends_on="a2", head="b1"
+        )
+        cls.b3 = env.generate_revision("b3", "->b3", head="b2")
 
-        cls.a3 = env.generate_revision("a3", '->a3', head='a2', depends_on='b1')
+        cls.a3 = env.generate_revision(
+            "a3", "->a3", head="a2", depends_on="b1"
+        )
 
     def test_downgrade_over_crisscross(self):
         # this state was not possible prior to
@@ -772,9 +838,10 @@ class DependsOnBranchTestThree(MigrationTest):
         # b2 because a2 is dependent on it, hence we add the ability
         # to remove half of a merge point.
         self._assert_downgrade(
-            'b1', ['a3', 'b2'],
+            "b1",
+            ["a3", "b2"],
             [self.down_(self.b2)],
-            set(['a3'])  # we have b1 also, which is implied by a3
+            set(["a3"]),  # we have b1 also, which is implied by a3
         )
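
For reference, the crisscross wired up in setup_class (reconstructed from
the generate_revision() calls above; the class docstring with the original
diagram is elided by the hunks):

    a1 -> a2 -> a3 (depends_on b1)
    b1 -> b2 (depends_on a2) -> b3
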
 
 
@@ -783,33 +850,35 @@ class DependsOnBranchLabelTest(MigrationTest):
     def setup_class(cls):
         cls.env = env = staging_env()
         cls.a1 = env.generate_revision(
-            util.rev_id(), '->a1',
-            branch_labels=['lib1'])
-        cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
+            util.rev_id(), "->a1", branch_labels=["lib1"]
+        )
+        cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
         cls.c1 = env.generate_revision(
-            util.rev_id(), 'b1->c1',
-            branch_labels=['c1lib'])
+            util.rev_id(), "b1->c1", branch_labels=["c1lib"]
+        )
 
-        cls.a2 = env.generate_revision(util.rev_id(), '->a2', head=())
+        cls.a2 = env.generate_revision(util.rev_id(), "->a2", head=())
         cls.b2 = env.generate_revision(
-            util.rev_id(), 'a2->b2', head=cls.a2.revision)
+            util.rev_id(), "a2->b2", head=cls.a2.revision
+        )
         cls.c2 = env.generate_revision(
-            util.rev_id(), 'b2->c2', head=cls.b2.revision,
-            depends_on=['c1lib'])
+            util.rev_id(), "b2->c2", head=cls.b2.revision, depends_on=["c1lib"]
+        )
 
         cls.d1 = env.generate_revision(
-            util.rev_id(), 'c1->d1',
-            head=cls.c1.revision)
+            util.rev_id(), "c1->d1", head=cls.c1.revision
+        )
         cls.e1 = env.generate_revision(
-            util.rev_id(), 'd1->e1',
-            head=cls.d1.revision)
+            util.rev_id(), "d1->e1", head=cls.d1.revision
+        )
         cls.f1 = env.generate_revision(
-            util.rev_id(), 'e1->f1',
-            head=cls.e1.revision)
+            util.rev_id(), "e1->f1", head=cls.e1.revision
+        )
 
     def test_upgrade_path(self):
         self._assert_upgrade(
-            self.c2.revision, self.a2.revision,
+            self.c2.revision,
+            self.a2.revision,
             [
                 self.up_(self.a1),
                 self.up_(self.b1),
@@ -817,23 +886,23 @@ class DependsOnBranchLabelTest(MigrationTest):
                 self.up_(self.b2),
                 self.up_(self.c2),
             ],
-            set([self.c2.revision])
+            set([self.c2.revision]),
         )
 
 
 class ForestTest(MigrationTest):
-
     @classmethod
     def setup_class(cls):
         cls.env = env = staging_env()
-        cls.a1 = env.generate_revision(util.rev_id(), '->a1')
-        cls.b1 = env.generate_revision(util.rev_id(), 'a1->b1')
+        cls.a1 = env.generate_revision(util.rev_id(), "->a1")
+        cls.b1 = env.generate_revision(util.rev_id(), "a1->b1")
 
         cls.a2 = env.generate_revision(
-            util.rev_id(), '->a2', head=(),
-            refresh=True)
+            util.rev_id(), "->a2", head=(), refresh=True
+        )
         cls.b2 = env.generate_revision(
-            util.rev_id(), 'a2->b2', head=cls.a2.revision)
+            util.rev_id(), "a2->b2", head=cls.a2.revision
+        )
 
     @classmethod
     def teardown_class(cls):
@@ -842,8 +911,12 @@ class ForestTest(MigrationTest):
     def test_base_to_heads(self):
         eq_(
             self.env._upgrade_revs("heads", "base"),
-            [self.up_(self.a2), self.up_(self.b2),
-             self.up_(self.a1), self.up_(self.b1)]
+            [
+                self.up_(self.a2),
+                self.up_(self.b2),
+                self.up_(self.a1),
+                self.up_(self.b1),
+            ],
         )
 
     def test_stamp_to_heads(self):
@@ -851,40 +924,44 @@ class ForestTest(MigrationTest):
         eq_(len(revs), 2)
         eq_(
             set(r.to_revisions for r in revs),
-            set([(self.b1.revision,), (self.b2.revision,)])
+            set([(self.b1.revision,), (self.b2.revision,)]),
         )
 
     def test_stamp_to_heads_no_moves_needed(self):
         revs = self.env._stamp_revs(
-            "heads", (self.b1.revision, self.b2.revision))
+            "heads", (self.b1.revision, self.b2.revision)
+        )
         eq_(len(revs), 0)
 
 
 class MergedPathTest(MigrationTest):
-
     @classmethod
     def setup_class(cls):
         cls.env = env = staging_env()
-        cls.a = env.generate_revision(util.rev_id(), '->a')
-        cls.b = env.generate_revision(util.rev_id(), 'a->b')
+        cls.a = env.generate_revision(util.rev_id(), "->a")
+        cls.b = env.generate_revision(util.rev_id(), "a->b")
 
-        cls.c1 = env.generate_revision(util.rev_id(), 'b->c1')
-        cls.d1 = env.generate_revision(util.rev_id(), 'c1->d1')
+        cls.c1 = env.generate_revision(util.rev_id(), "b->c1")
+        cls.d1 = env.generate_revision(util.rev_id(), "c1->d1")
 
         cls.c2 = env.generate_revision(
-            util.rev_id(), 'b->c2',
-            branch_labels='c2branch',
-            head=cls.b.revision, splice=True)
+            util.rev_id(),
+            "b->c2",
+            branch_labels="c2branch",
+            head=cls.b.revision,
+            splice=True,
+        )
         cls.d2 = env.generate_revision(
-            util.rev_id(), 'c2->d2',
-            head=cls.c2.revision)
+            util.rev_id(), "c2->d2", head=cls.c2.revision
+        )
 
         cls.e = env.generate_revision(
-            util.rev_id(), 'merge d1 and d2',
-            head=(cls.d1.revision, cls.d2.revision)
+            util.rev_id(),
+            "merge d1 and d2",
+            head=(cls.d1.revision, cls.d2.revision),
         )
 
-        cls.f = env.generate_revision(util.rev_id(), 'e->f')
+        cls.f = env.generate_revision(util.rev_id(), "e->f")
 
     @classmethod
     def teardown_class(cls):
@@ -897,7 +974,7 @@ class MergedPathTest(MigrationTest):
         eq_(
             revs[0].merge_branch_idents(heads),
             # no deletes, UPDATE e to c2
-            ([], self.e.revision, self.c2.revision)
+            ([], self.e.revision, self.c2.revision),
         )
 
     def test_stamp_down_across_merge_prior_branching(self):
@@ -907,7 +984,7 @@ class MergedPathTest(MigrationTest):
         eq_(
             revs[0].merge_branch_idents(heads),
             # no deletes, UPDATE e to a
-            ([], self.e.revision, self.a.revision)
+            ([], self.e.revision, self.a.revision),
         )
 
     def test_stamp_up_across_merge_from_single_branch(self):
@@ -916,7 +993,7 @@ class MergedPathTest(MigrationTest):
         eq_(
             revs[0].merge_branch_idents([self.c2.revision]),
             # no deletes, UPDATE c2 to e
-            ([], self.c2.revision, self.e.revision)
+            ([], self.c2.revision, self.e.revision),
         )
 
     def test_stamp_labled_head_across_merge_from_multiple_branch(self):
@@ -924,23 +1001,23 @@ class MergedPathTest(MigrationTest):
         # d1 both in terms of "c2branch" as well as that the "head"
         # revision "f" is the head of both d1 and d2
         revs = self.env._stamp_revs(
-            "c2branch@head", [self.d1.revision, self.c2.revision])
+            "c2branch@head", [self.d1.revision, self.c2.revision]
+        )
         eq_(len(revs), 1)
         eq_(
             revs[0].merge_branch_idents([self.d1.revision, self.c2.revision]),
             # DELETE d1 revision, UPDATE c2 to f
-            ([self.d1.revision], self.c2.revision, self.f.revision)
+            ([self.d1.revision], self.c2.revision, self.f.revision),
         )
 
     def test_stamp_up_across_merge_from_multiple_branch(self):
         heads = [self.d1.revision, self.c2.revision]
-        revs = self.env._stamp_revs(
-            self.e.revision, heads)
+        revs = self.env._stamp_revs(self.e.revision, heads)
         eq_(len(revs), 1)
         eq_(
             revs[0].merge_branch_idents(heads),
             # DELETE d1 revision, UPDATE c2 to e
-            ([self.d1.revision], self.c2.revision, self.e.revision)
+            ([self.d1.revision], self.c2.revision, self.e.revision),
         )
 
     def test_stamp_up_across_merge_prior_branching(self):
@@ -950,7 +1027,7 @@ class MergedPathTest(MigrationTest):
         eq_(
             revs[0].merge_branch_idents(heads),
             # no deletes, UPDATE b to e
-            ([], self.b.revision, self.e.revision)
+            ([], self.b.revision, self.e.revision),
         )
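
A reading aid for the eq_() expectations in this class: merge_branch_idents()
evidently returns a three-tuple, inferred here from the asserted values
rather than from a documented signature:

    # inferred shape, not an authoritative API description:
    to_delete, update_from, update_to = revs[0].merge_branch_idents(heads)
    # rows in to_delete are DELETEd from alembic_version; the row holding
    # update_from is UPDATEd to update_to
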
 
     def test_upgrade_across_merge_point(self):
@@ -963,9 +1040,9 @@ class MergedPathTest(MigrationTest):
                 self.up_(self.c1),  # b->c1, create new branch
                 self.up_(self.d1),
                 self.up_(self.e),  # d1/d2 -> e, merge branches
-                                   # (DELETE d2, UPDATE d1->e)
-                self.up_(self.f)
-            ]
+                # (DELETE d2, UPDATE d1->e)
+                self.up_(self.f),
+            ],
         )
 
     def test_downgrade_across_merge_point(self):
@@ -975,10 +1052,10 @@ class MergedPathTest(MigrationTest):
             [
                 self.down_(self.f),
                 self.down_(self.e),  # e -> d1 and d2, unmerge branches
-                                     # (UPDATE e->d1, INSERT d2)
+                # (UPDATE e->d1, INSERT d2)
                 self.down_(self.d1),
                 self.down_(self.c1),
                 self.down_(self.d2),
                 self.down_(self.c2),  # c2->b, delete branch
-            ]
+            ],
         )
diff --git a/tox.ini b/tox.ini
index 8f3640df85ecefc27bca743adf8a03f627ccce57..660761a91c4785257e1dea4d8b2483d8ea26e9b0 100644
--- a/tox.ini
+++ b/tox.ini
@@ -55,16 +55,14 @@ commands=
   {oracle}: python reap_oracle_dbs.py oracle_idents.txt
 
 
+# thanks to https://julien.danjou.info/the-best-flake8-extensions/
 [testenv:pep8]
-deps=flake8
-commands = python -m flake8 {posargs}
-
-
-[flake8]
-
-show-source = True
-ignore = E711,E712,E721,D,N
-# F841,F811,F401
-exclude=.venv,.git,.tox,dist,doc,*egg,build
-
-
+deps=
+      flake8
+      flake8-import-order
+      flake8-builtins
+      flake8-docstrings
+      flake8-rst-docstrings
+      # used by flake8-rst-docstrings
+      pygments
+commands = flake8 ./alembic/ ./tests/ setup.py
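
With the [flake8] block removed from tox.ini, those settings presumably move
to one of flake8's other configuration homes; setup.cfg is a common choice,
though this hunk does not show it.  A sketch carrying the old values over:

    [flake8]
    show-source = True
    ignore = E711,E712,E721,D,N
    exclude = .venv,.git,.tox,dist,doc,*egg,build

The environment itself still runs as before with: tox -e pep8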