]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Update black to 24.1.1
authorFederico Caselli <cfederico87@gmail.com>
Wed, 14 Feb 2024 19:35:24 +0000 (20:35 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 14 Feb 2024 19:37:44 +0000 (20:37 +0100)
Change-Id: Iebd9b9e866a6a58541c187e70d4f170fdf84daff

28 files changed:
.pre-commit-config.yaml
alembic/autogenerate/api.py
alembic/autogenerate/render.py
alembic/config.py
alembic/ddl/_autogen.py
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mysql.py
alembic/ddl/oracle.py
alembic/operations/base.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/runtime/environment.py
alembic/runtime/migration.py
alembic/script/base.py
alembic/script/revision.py
alembic/testing/fixtures.py
alembic/testing/suite/test_environment.py
alembic/util/langhelpers.py
alembic/util/sqla_compat.py
reap_dbs.py
setup.cfg
tests/test_autogen_indexes.py
tests/test_autogen_render.py
tests/test_batch.py
tests/test_mssql.py
tests/test_version_traversal.py
tox.ini

index f1a8b41838d716ef9fc21dbeb0b95f5c3a7675e5..ac4be8989c88690902175a350cdfbe12be75accc 100644 (file)
@@ -2,7 +2,7 @@
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
 -   repo: https://github.com/python/black
-    rev: 23.3.0
+    rev: 24.1.1
     hooks:
     -   id: black
 
index aa8f32f65359c9c04f41ea24e21131beee2d8d2a..4c039162884b78c09d4771f4e9373ed636427fa8 100644 (file)
@@ -596,9 +596,9 @@ class RevisionContext:
         migration_script = self.generated_revisions[-1]
         if not getattr(migration_script, "_needs_render", False):
             migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
-            migration_script.downgrade_ops_list[
-                -1
-            ].downgrade_token = downgrade_token
+            migration_script.downgrade_ops_list[-1].downgrade_token = (
+                downgrade_token
+            )
             migration_script._needs_render = True
         else:
             migration_script._upgrade_ops.append(
index 317a6dbed9cf6eb6514d67a82ee3ee853c22254b..61d56acfed416a77e0dbf333ef36cb88af74c894 100644 (file)
@@ -187,9 +187,11 @@ def _render_create_table_comment(
         prefix=_alembic_autogenerate_prefix(autogen_context),
         tname=op.table_name,
         comment="%r" % op.comment if op.comment is not None else None,
-        existing="%r" % op.existing_comment
-        if op.existing_comment is not None
-        else None,
+        existing=(
+            "%r" % op.existing_comment
+            if op.existing_comment is not None
+            else None
+        ),
         schema="'%s'" % op.schema if op.schema is not None else None,
         indent="    ",
     )
@@ -216,9 +218,11 @@ def _render_drop_table_comment(
     return templ.format(
         prefix=_alembic_autogenerate_prefix(autogen_context),
         tname=op.table_name,
-        existing="%r" % op.existing_comment
-        if op.existing_comment is not None
-        else None,
+        existing=(
+            "%r" % op.existing_comment
+            if op.existing_comment is not None
+            else None
+        ),
         schema="'%s'" % op.schema if op.schema is not None else None,
         indent="    ",
     )
@@ -328,9 +332,11 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:
             _get_index_rendered_expressions(index, autogen_context)
         ),
         "unique": index.unique or False,
-        "schema": (", schema=%r" % _ident(index.table.schema))
-        if index.table.schema
-        else "",
+        "schema": (
+            (", schema=%r" % _ident(index.table.schema))
+            if index.table.schema
+            else ""
+        ),
         "kwargs": ", " + ", ".join(opts) if opts else "",
     }
     return text
@@ -592,9 +598,11 @@ def _get_index_rendered_expressions(
     idx: Index, autogen_context: AutogenContext
 ) -> List[str]:
     return [
-        repr(_ident(getattr(exp, "name", None)))
-        if isinstance(exp, sa_schema.Column)
-        else _render_potential_expr(exp, autogen_context, is_index=True)
+        (
+            repr(_ident(getattr(exp, "name", None)))
+            if isinstance(exp, sa_schema.Column)
+            else _render_potential_expr(exp, autogen_context, is_index=True)
+        )
         for exp in idx.expressions
     ]
 
@@ -1075,9 +1083,11 @@ def _render_check_constraint(
         )
     return "%(prefix)sCheckConstraint(%(sqltext)s%(opts)s)" % {
         "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
-        "opts": ", " + (", ".join("%s=%s" % (k, v) for k, v in opts))
-        if opts
-        else "",
+        "opts": (
+            ", " + (", ".join("%s=%s" % (k, v) for k, v in opts))
+            if opts
+            else ""
+        ),
         "sqltext": _render_potential_expr(
             constraint.sqltext, autogen_context, wrap_in_text=False
         ),
index 4b2263fddacf0cd85e9969221ef9b3d8105cd8f4..2c52e7cd138820e1fdda5ac277ae723958d89d19 100644 (file)
@@ -221,8 +221,7 @@ class Config:
     @overload
     def get_section(
         self, name: str, default: None = ...
-    ) -> Optional[Dict[str, str]]:
-        ...
+    ) -> Optional[Dict[str, str]]: ...
 
     # "default" here could also be a TypeVar
     # _MT = TypeVar("_MT", bound=Mapping[str, str]),
@@ -230,14 +229,12 @@ class Config:
     @overload
     def get_section(
         self, name: str, default: Dict[str, str]
-    ) -> Dict[str, str]:
-        ...
+    ) -> Dict[str, str]: ...
 
     @overload
     def get_section(
         self, name: str, default: Mapping[str, str]
-    ) -> Union[Dict[str, str], Mapping[str, str]]:
-        ...
+    ) -> Union[Dict[str, str], Mapping[str, str]]: ...
 
     def get_section(
         self, name: str, default: Optional[Mapping[str, str]] = None
@@ -313,14 +310,12 @@ class Config:
             return default
 
     @overload
-    def get_main_option(self, name: str, default: str) -> str:
-        ...
+    def get_main_option(self, name: str, default: str) -> str: ...
 
     @overload
     def get_main_option(
         self, name: str, default: Optional[str] = None
-    ) -> Optional[str]:
-        ...
+    ) -> Optional[str]: ...
 
     def get_main_option(
         self, name: str, default: Optional[str] = None
index e22153c49c761451c074c11de6c7ea53d20c1149..74715b18a8bfd8b727ee14e8ed3d290de7169d7b 100644 (file)
@@ -287,18 +287,22 @@ class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
             self.target_table,
             tuple(self.target_columns),
         ) + (
-            (None if onupdate.lower() == "no action" else onupdate.lower())
-            if onupdate
-            else None,
-            (None if ondelete.lower() == "no action" else ondelete.lower())
-            if ondelete
-            else None,
+            (
+                (None if onupdate.lower() == "no action" else onupdate.lower())
+                if onupdate
+                else None
+            ),
+            (
+                (None if ondelete.lower() == "no action" else ondelete.lower())
+                if ondelete
+                else None
+            ),
             # convert initially + deferrable into one three-state value
-            "initially_deferrable"
-            if initially and initially.lower() == "deferred"
-            else "deferrable"
-            if deferrable
-            else "not deferrable",
+            (
+                "initially_deferrable"
+                if initially and initially.lower() == "deferred"
+                else "deferrable" if deferrable else "not deferrable"
+            ),
         )
 
     @util.memoized_property
index 7a85a5c198affa8f50fcfe4da126836627ae472c..690c153763944521e310a33e0e4e3c14117ef9f9 100644 (file)
@@ -40,7 +40,6 @@ _ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
 
 
 class AlterTable(DDLElement):
-
     """Represent an ALTER TABLE statement.
 
     Only the string name and optional schema name of the table
@@ -238,9 +237,11 @@ def visit_column_default(
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        "SET DEFAULT %s" % format_server_default(compiler, element.default)
-        if element.default is not None
-        else "DROP DEFAULT",
+        (
+            "SET DEFAULT %s" % format_server_default(compiler, element.default)
+            if element.default is not None
+            else "DROP DEFAULT"
+        ),
     )
 
 
index bf202a48c62c9fa33c145a407a8b4b6afe06466b..d298392327e7965e93b631d986cca9c73084182a 100644 (file)
@@ -77,7 +77,6 @@ _impls: Dict[str, Type[DefaultImpl]] = {}
 
 
 class DefaultImpl(metaclass=ImplMeta):
-
     """Provide the entrypoint for major migration operations,
     including database-specific behavioral variances.
 
@@ -425,13 +424,15 @@ class DefaultImpl(metaclass=ImplMeta):
                 self._exec(
                     sqla_compat._insert_inline(table).values(
                         **{
-                            k: sqla_compat._literal_bindparam(
-                                k, v, type_=table.c[k].type
-                            )
-                            if not isinstance(
-                                v, sqla_compat._literal_bindparam
+                            k: (
+                                sqla_compat._literal_bindparam(
+                                    k, v, type_=table.c[k].type
+                                )
+                                if not isinstance(
+                                    v, sqla_compat._literal_bindparam
+                                )
+                                else v
                             )
-                            else v
                             for k, v in row.items()
                         }
                     )
index f312173e946d117b276e06ed5aa290f18f7db61b..3482f672daee97b0a4d8a6a84503bed7d4d5ccd2 100644 (file)
@@ -94,21 +94,29 @@ class MySQLImpl(DefaultImpl):
                     column_name,
                     schema=schema,
                     newname=name if name is not None else column_name,
-                    nullable=nullable
-                    if nullable is not None
-                    else existing_nullable
-                    if existing_nullable is not None
-                    else True,
+                    nullable=(
+                        nullable
+                        if nullable is not None
+                        else (
+                            existing_nullable
+                            if existing_nullable is not None
+                            else True
+                        )
+                    ),
                     type_=type_ if type_ is not None else existing_type,
-                    default=server_default
-                    if server_default is not False
-                    else existing_server_default,
-                    autoincrement=autoincrement
-                    if autoincrement is not None
-                    else existing_autoincrement,
-                    comment=comment
-                    if comment is not False
-                    else existing_comment,
+                    default=(
+                        server_default
+                        if server_default is not False
+                        else existing_server_default
+                    ),
+                    autoincrement=(
+                        autoincrement
+                        if autoincrement is not None
+                        else existing_autoincrement
+                    ),
+                    comment=(
+                        comment if comment is not False else existing_comment
+                    ),
                 )
             )
         elif (
@@ -123,21 +131,29 @@ class MySQLImpl(DefaultImpl):
                     column_name,
                     schema=schema,
                     newname=name if name is not None else column_name,
-                    nullable=nullable
-                    if nullable is not None
-                    else existing_nullable
-                    if existing_nullable is not None
-                    else True,
+                    nullable=(
+                        nullable
+                        if nullable is not None
+                        else (
+                            existing_nullable
+                            if existing_nullable is not None
+                            else True
+                        )
+                    ),
                     type_=type_ if type_ is not None else existing_type,
-                    default=server_default
-                    if server_default is not False
-                    else existing_server_default,
-                    autoincrement=autoincrement
-                    if autoincrement is not None
-                    else existing_autoincrement,
-                    comment=comment
-                    if comment is not False
-                    else existing_comment,
+                    default=(
+                        server_default
+                        if server_default is not False
+                        else existing_server_default
+                    ),
+                    autoincrement=(
+                        autoincrement
+                        if autoincrement is not None
+                        else existing_autoincrement
+                    ),
+                    comment=(
+                        comment if comment is not False else existing_comment
+                    ),
                 )
             )
         elif server_default is not False:
@@ -368,9 +384,11 @@ def _mysql_alter_default(
     return "%s ALTER COLUMN %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         format_column_name(compiler, element.column_name),
-        "SET DEFAULT %s" % format_server_default(compiler, element.default)
-        if element.default is not None
-        else "DROP DEFAULT",
+        (
+            "SET DEFAULT %s" % format_server_default(compiler, element.default)
+            if element.default is not None
+            else "DROP DEFAULT"
+        ),
     )
 
 
index 54011740723749b50f53beaac6c75ca020e365a3..eac99124f42290163b402765b0a94e7d4f75f820 100644 (file)
@@ -141,9 +141,11 @@ def visit_column_default(
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
-        "DEFAULT %s" % format_server_default(compiler, element.default)
-        if element.default is not None
-        else "DEFAULT NULL",
+        (
+            "DEFAULT %s" % format_server_default(compiler, element.default)
+            if element.default is not None
+            else "DEFAULT NULL"
+        ),
     )
 
 
index bafe441a69ceb2bcd13f5f3f3fad1382b589e99f..bd1b170d3d684de31f4c7e972dfa93c2882d7aa1 100644 (file)
@@ -406,8 +406,7 @@ class AbstractOperations(util.ModuleClsProxy):
         return self.migration_context
 
     @overload
-    def invoke(self, operation: CreateTableOp) -> Table:
-        ...
+    def invoke(self, operation: CreateTableOp) -> Table: ...
 
     @overload
     def invoke(
@@ -427,12 +426,10 @@ class AbstractOperations(util.ModuleClsProxy):
             DropTableOp,
             ExecuteSQLOp,
         ],
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @overload
-    def invoke(self, operation: MigrateOperation) -> Any:
-        ...
+    def invoke(self, operation: MigrateOperation) -> Any: ...
 
     def invoke(self, operation: MigrateOperation) -> Any:
         """Given a :class:`.MigrateOperation`, invoke it in terms of
index 7b65191cf20fa5bc1be08c646247dee611f4f4fe..0282d5716370c47231883b17140334a450e87560 100644 (file)
@@ -1371,9 +1371,9 @@ class DropTableOp(MigrateOperation):
             info=self.info.copy() if self.info else {},
             prefixes=list(self.prefixes) if self.prefixes else [],
             schema=self.schema,
-            _constraints_included=self._reverse._constraints_included
-            if self._reverse
-            else False,
+            _constraints_included=(
+                self._reverse._constraints_included if self._reverse else False
+            ),
             **self.table_kw,
         )
         return t
index 32b26e9b9d6471c7c663e732a2cfeb35e9eb4bd6..59c1002f109c6fcde6b76e1f2910921f349ec13d 100644 (file)
@@ -223,10 +223,12 @@ class SchemaObjects:
         t = sa_schema.Table(name, m, *cols, **kw)
 
         constraints = [
-            sqla_compat._copy(elem, target_table=t)
-            if getattr(elem, "parent", None) is not t
-            and getattr(elem, "parent", None) is not None
-            else elem
+            (
+                sqla_compat._copy(elem, target_table=t)
+                if getattr(elem, "parent", None) is not t
+                and getattr(elem, "parent", None) is not None
+                else elem
+            )
             for elem in columns
             if isinstance(elem, (Constraint, Index))
         ]
index d64b2adc279761b40724b2c7c7c7f53da1e77019..a30972ec91251e4c38004ad138583789e5ffd89b 100644 (file)
@@ -108,7 +108,6 @@ CompareType = Callable[
 
 
 class EnvironmentContext(util.ModuleClsProxy):
-
     """A configurational facade made available in an ``env.py`` script.
 
     The :class:`.EnvironmentContext` acts as a *facade* to the more
@@ -342,18 +341,17 @@ class EnvironmentContext(util.ModuleClsProxy):
         return self.context_opts.get("tag", None)  # type: ignore[no-any-return]  # noqa: E501
 
     @overload
-    def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]:
-        ...
+    def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: ...
 
     @overload
-    def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]:
-        ...
+    def get_x_argument(
+        self, as_dictionary: Literal[True]
+    ) -> Dict[str, str]: ...
 
     @overload
     def get_x_argument(
         self, as_dictionary: bool = ...
-    ) -> Union[List[str], Dict[str, str]]:
-        ...
+    ) -> Union[List[str], Dict[str, str]]: ...
 
     def get_x_argument(
         self, as_dictionary: bool = False
index 95c69bc692555e8614570691b0309b866bac2946..6cfe5e23e4d429cc181bf8d175ac6f09e35d0157 100644 (file)
@@ -86,7 +86,6 @@ class _ProxyTransaction:
 
 
 class MigrationContext:
-
     """Represent the database state made available to a migration
     script.
 
@@ -218,9 +217,11 @@ class MigrationContext:
             log.info("Generating static SQL")
         log.info(
             "Will assume %s DDL.",
-            "transactional"
-            if self.impl.transactional_ddl
-            else "non-transactional",
+            (
+                "transactional"
+                if self.impl.transactional_ddl
+                else "non-transactional"
+            ),
         )
 
     @classmethod
@@ -345,9 +346,9 @@ class MigrationContext:
             # except that it will not know it's in "autocommit" and will
             # emit deprecation warnings when an autocommit action takes
             # place.
-            self.connection = (
-                self.impl.connection
-            ) = base_connection.execution_options(isolation_level="AUTOCOMMIT")
+            self.connection = self.impl.connection = (
+                base_connection.execution_options(isolation_level="AUTOCOMMIT")
+            )
 
             # sqlalchemy future mode will "autobegin" in any case, so take
             # control of that "transaction" here
@@ -1006,8 +1007,7 @@ class MigrationStep:
     if TYPE_CHECKING:
 
         @property
-        def doc(self) -> Optional[str]:
-            ...
+        def doc(self) -> Optional[str]: ...
 
     @property
     def name(self) -> str:
index 5945ca591c221279b05b07833591faa4ad4cd628..66564781258a984a9834120b0bf4cad70e39e56a 100644 (file)
@@ -56,7 +56,6 @@ _split_on_space_comma_colon = re.compile(r", *|(?: +)|\:")
 
 
 class ScriptDirectory:
-
     """Provides operations upon an Alembic script directory.
 
     This object is useful to get information as to current revisions,
@@ -732,9 +731,11 @@ class ScriptDirectory:
         if depends_on:
             with self._catch_revision_errors():
                 resolved_depends_on = [
-                    dep
-                    if dep in rev.branch_labels  # maintain branch labels
-                    else rev.revision  # resolve partial revision identifiers
+                    (
+                        dep
+                        if dep in rev.branch_labels  # maintain branch labels
+                        else rev.revision
+                    )  # resolve partial revision identifiers
                     for rev, dep in [
                         (not_none(self.revision_map.get_revision(dep)), dep)
                         for dep in util.to_list(depends_on)
@@ -808,7 +809,6 @@ class ScriptDirectory:
 
 
 class Script(revision.Revision):
-
     """Represent a single revision file in a ``versions/`` directory.
 
     The :class:`.Script` instance is returned by methods
@@ -930,9 +930,11 @@ class Script(revision.Revision):
         if head_indicators or tree_indicators:
             text += "%s%s%s" % (
                 " (head)" if self._is_real_head else "",
-                " (effective head)"
-                if self.is_head and not self._is_real_head
-                else "",
+                (
+                    " (effective head)"
+                    if self.is_head and not self._is_real_head
+                    else ""
+                ),
                 " (current)" if self._db_current_indicator else "",
             )
         if tree_indicators:
index 77a802cdcadf9c59049fdb5db1c2be95d305a1ae..c3108e985a0a013922793489dbd9b03df95b3e07 100644 (file)
@@ -56,8 +56,7 @@ class _CollectRevisionsProtocol(Protocol):
         inclusive: bool,
         implicit_base: bool,
         assert_relative_length: bool,
-    ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]:
-        ...
+    ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]: ...
 
 
 class RevisionError(Exception):
@@ -720,9 +719,11 @@ class RevisionMap:
             resolved_target = target
 
         resolved_test_against_revs = [
-            self._revision_for_ident(test_against_rev)
-            if not isinstance(test_against_rev, Revision)
-            else test_against_rev
+            (
+                self._revision_for_ident(test_against_rev)
+                if not isinstance(test_against_rev, Revision)
+                else test_against_rev
+            )
             for test_against_rev in util.to_tuple(
                 test_against_revs, default=()
             )
@@ -1016,9 +1017,9 @@ class RevisionMap:
                         # each time but it was getting complicated
                         current_heads[current_candidate_idx] = heads_to_add[0]
                         current_heads.extend(heads_to_add[1:])
-                        ancestors_by_idx[
-                            current_candidate_idx
-                        ] = get_ancestors(heads_to_add[0])
+                        ancestors_by_idx[current_candidate_idx] = (
+                            get_ancestors(heads_to_add[0])
+                        )
                         ancestors_by_idx.extend(
                             get_ancestors(head) for head in heads_to_add[1:]
                         )
@@ -1183,9 +1184,13 @@ class RevisionMap:
                         branch_label = symbol
                 # Walk down the tree to find downgrade target.
                 rev = self._walk(
-                    start=self.get_revision(symbol)
-                    if branch_label is None
-                    else self.get_revision("%s@%s" % (branch_label, symbol)),
+                    start=(
+                        self.get_revision(symbol)
+                        if branch_label is None
+                        else self.get_revision(
+                            "%s@%s" % (branch_label, symbol)
+                        )
+                    ),
                     steps=rel_int,
                     no_overwalk=assert_relative_length,
                 )
@@ -1303,9 +1308,13 @@ class RevisionMap:
                 )
             return (
                 self._walk(
-                    start=self.get_revision(symbol)
-                    if branch_label is None
-                    else self.get_revision("%s@%s" % (branch_label, symbol)),
+                    start=(
+                        self.get_revision(symbol)
+                        if branch_label is None
+                        else self.get_revision(
+                            "%s@%s" % (branch_label, symbol)
+                        )
+                    ),
                     steps=relative,
                     no_overwalk=assert_relative_length,
                 ),
@@ -1694,15 +1703,13 @@ class Revision:
 
 
 @overload
-def tuple_rev_as_scalar(rev: None) -> None:
-    ...
+def tuple_rev_as_scalar(rev: None) -> None: ...
 
 
 @overload
 def tuple_rev_as_scalar(
     rev: Union[Tuple[_T, ...], List[_T]]
-) -> Union[_T, Tuple[_T, ...], List[_T]]:
-    ...
+) -> Union[_T, Tuple[_T, ...], List[_T]]: ...
 
 
 def tuple_rev_as_scalar(
index b6cea632e9047db4c51eaefe08eb890081a049ef..3b5ce596e6f7d26fa821df347266b8bb6f65ec94 100644 (file)
@@ -274,9 +274,11 @@ class AlterColRoundTripFixture:
                 "x",
                 column.name,
                 existing_type=column.type,
-                existing_server_default=column.server_default
-                if column.server_default is not None
-                else False,
+                existing_server_default=(
+                    column.server_default
+                    if column.server_default is not None
+                    else False
+                ),
                 existing_nullable=True if column.nullable else False,
                 # existing_comment=column.comment,
                 nullable=to_.get("nullable", None),
@@ -304,9 +306,13 @@ class AlterColRoundTripFixture:
             new_col["type"],
             new_col.get("default", None),
             compare.get("type", old_col["type"]),
-            compare["server_default"].text
-            if "server_default" in compare
-            else column.server_default.arg.text
-            if column.server_default is not None
-            else None,
+            (
+                compare["server_default"].text
+                if "server_default" in compare
+                else (
+                    column.server_default.arg.text
+                    if column.server_default is not None
+                    else None
+                )
+            ),
         )
index 8c86859ae2cff2b0497593dff6fc547a2487bbe3..df2d9afbd490fb1cd3170de428eaf3190e4f2a48 100644 (file)
@@ -24,9 +24,9 @@ class MigrationTransactionTest(TestBase):
             self.context = MigrationContext.configure(
                 dialect=conn.dialect, opts=opts
             )
-            self.context.output_buffer = (
-                self.context.impl.output_buffer
-            ) = io.StringIO()
+            self.context.output_buffer = self.context.impl.output_buffer = (
+                io.StringIO()
+            )
         else:
             self.context = MigrationContext.configure(
                 connection=conn, opts=opts
index 4a5bf09a98bba393e5d61a8abf67ba011ab52b71..80d88cbcec56e280c55c395d6f00b3a72c5946f1 100644 (file)
@@ -234,20 +234,17 @@ def rev_id() -> str:
 
 
 @overload
-def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]:
-    ...
+def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]: ...
 
 
 @overload
-def to_tuple(x: None, default: Optional[_T] = ...) -> _T:
-    ...
+def to_tuple(x: None, default: Optional[_T] = ...) -> _T: ...
 
 
 @overload
 def to_tuple(
     x: Any, default: Optional[Tuple[Any, ...]] = None
-) -> Tuple[Any, ...]:
-    ...
+) -> Tuple[Any, ...]: ...
 
 
 def to_tuple(
index 8489c19fac7c163dc9053d2f52606855117d60a6..30b9b4c4e6b49e7302ce4b2cedf9ec67736db59e 100644 (file)
@@ -59,8 +59,7 @@ _CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
 
 
 class _CompilerProtocol(Protocol):
-    def __call__(self, element: Any, compiler: Any, **kw: Any) -> str:
-        ...
+    def __call__(self, element: Any, compiler: Any, **kw: Any) -> str: ...
 
 
 def _safe_int(value: str) -> Union[int, str]:
@@ -95,8 +94,7 @@ if TYPE_CHECKING:
 
     def compiles(
         element: Type[ClauseElement], *dialects: str
-    ) -> Callable[[_CompilerProtocol], _CompilerProtocol]:
-        ...
+    ) -> Callable[[_CompilerProtocol], _CompilerProtocol]: ...
 
 else:
     from sqlalchemy.ext.compiler import compiles
index ae7ff8582b03fc6a23a8c4b97bb95ac1ba25aab9..6b2215dfe24bccd9f6e8a76d70ecc5b3a05c1e8b 100644 (file)
@@ -10,6 +10,7 @@ running a kill of all detected sessions does not seem to release the
 database in process.
 
 """
+
 import logging
 import sys
 
index fa957ecac63eddcf4e744613daddf0893a2d8386..3c516430571b9d1a086ee8cf9323910ffee589b4 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -86,7 +86,7 @@ enable-extensions = G
 ignore =
     A003,
     D,
-    E203,E305,E711,E712,E721,E722,E741,
+    E203,E305,E704,E711,E712,E721,E722,E741,
     N801,N802,N806,
     RST304,RST303,RST299,RST399,
     W503,W504
index b06e7c90c2eb470ca1debe3416b878647c53f46b..d1e95e96521bbf7f270cd29decea173f3f9ce8fd 100644 (file)
@@ -642,9 +642,11 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         diffs = {
             (
                 cmd,
-                isinstance(obj, (UniqueConstraint, Index))
-                if obj.name is not None
-                else False,
+                (
+                    isinstance(obj, (UniqueConstraint, Index))
+                    if obj.name is not None
+                    else False
+                ),
             )
             for cmd, obj in diffs
         }
@@ -1800,7 +1802,6 @@ class NoUqReflectionIndexTest(NoUqReflection, AutogenerateUniqueIndexTest):
 
 
 class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
-
     """this test suite simulates the condition where:
 
     a. the dialect doesn't report unique constraints
index eeeb92ed1c3d9ad5d353aff0bc81f4ec882f073c..254b6ddd1253291c2e8986e46666410c95ea81a1 100644 (file)
@@ -53,7 +53,6 @@ from alembic.util import sqla_compat
 
 
 class AutogenRenderTest(TestBase):
-
     """test individual directives"""
 
     def setUp(self):
index 9992af2c6be371129f1ab9a440dc165ea0122a84..2806dde1b4db4eb7e724cae771b89869979538a6 100644 (file)
@@ -329,19 +329,21 @@ class BatchApplyTest(TestBase):
         )
 
         args["tname_colnames"] = ", ".join(
-            "CAST(%(schema)stname.%(name)s AS %(type)s) AS %(cast_label)s"
-            % {
-                "schema": args["schema"],
-                "name": name,
-                "type": impl.new_table.c[name].type,
-                "cast_label": name if sqla_14 else "anon_1",
-            }
-            if (
-                impl.new_table.c[name].type._type_affinity
-                is not impl.table.c[name].type._type_affinity
+            (
+                "CAST(%(schema)stname.%(name)s AS %(type)s) AS %(cast_label)s"
+                % {
+                    "schema": args["schema"],
+                    "name": name,
+                    "type": impl.new_table.c[name].type,
+                    "cast_label": name if sqla_14 else "anon_1",
+                }
+                if (
+                    impl.new_table.c[name].type._type_affinity
+                    is not impl.table.c[name].type._type_affinity
+                )
+                else "%(schema)stname.%(name)s"
+                % {"schema": args["schema"], "name": name}
             )
-            else "%(schema)stname.%(name)s"
-            % {"schema": args["schema"], "name": name}
             for name in colnames
             if name in impl.table.c
         )
index 693ab57dd1a84127063c70643f9e15baae6a9f4d..fccde2641479ac6cebe8889c0cf667c239e311ca 100644 (file)
@@ -1,4 +1,5 @@
 """Test op functions against MSSQL."""
+
 from __future__ import annotations
 
 from typing import Any
@@ -118,9 +119,9 @@ class OpTest(TestBase):
             expected_nullability = not existing_nullability
             args["nullable"] = expected_nullability
         else:
-            args[
-                "existing_nullable"
-            ] = expected_nullability = existing_nullability
+            args["existing_nullable"] = expected_nullability = (
+                existing_nullability
+            )
 
         op.alter_column("t", "c", **args)
 
index 09816dff5437524425149100fa16717d7cacc5fb..2fd07a95b604b1607ab7347dfad67cb16b220160 100644 (file)
@@ -554,7 +554,6 @@ class BranchedPathTest(MigrationTest):
 
 
 class BranchFromMergepointTest(MigrationTest):
-
     """this is a form that will come up frequently in the
     "many independent roots with cross-dependencies" case.
 
@@ -617,7 +616,6 @@ class BranchFromMergepointTest(MigrationTest):
 
 
 class BranchFrom3WayMergepointTest(MigrationTest):
-
     """this is a form that will come up frequently in the
     "many independent roots with cross-dependencies" case.
 
diff --git a/tox.ini b/tox.ini
index 1c1d942875f37a760109790440752cf91f05d122..a1265d813c563d03d6a3853b551072286728529b 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -25,7 +25,7 @@ deps=pytest>4.6
      backports.zoneinfo;python_version<"3.9"
      tzdata
      zimports
-     black==23.3.0
+     black==24.1.1
      greenlet>=1
 
 
@@ -97,7 +97,7 @@ deps=
       pydocstyle<4.0.0
       # used by flake8-rst-docstrings
       pygments
-      black==23.3.0
+      black==24.1.1
 commands =
      flake8 ./alembic/ ./tests/ setup.py docs/build/conf.py {posargs}
      black --check setup.py tests alembic