]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
distinguish between string contraint name and defined
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Mar 2023 18:34:40 +0000 (13:34 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Mar 2023 19:09:04 +0000 (14:09 -0500)
Take _NONE_NAME into account as a valid constraint name
and don't skip these constraints or consider them to be unnamed.
Thanks to typing this also revealed that previous batch versions
were also keying "_NONE_NAME" constraints as though they were named.

Fixed regression for 1.10.0 where :class:`.Constraint` objects were
suddenly required to have non-None name fields when using batch mode, which
was not previously a requirement.

Change-Id: If4a7191a00848b19cb124bc6da362f3bc6ce1472
Fixes: #1195
alembic/autogenerate/compare.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/util/sqla_compat.py
docs/build/unreleased/1195.rst [new file with mode: 0644]
tests/test_batch.py

index f50c41cc0a73254e4d93aa21c9262d19fb7fc41c..4f5126f53d635db3e21e8d9e4b77446bfb6b2e95 100644 (file)
@@ -652,7 +652,7 @@ def _compare_indexes_and_uniques(
     conn_names = {
         c.name: c
         for c in conn_unique_constraints.union(conn_indexes_sig)
-        if sqla_compat.constraint_name_defined(c.name)
+        if sqla_compat.constraint_name_string(c.name)
     }
 
     doubled_constraints = {
index fe32eec2672889423b78e3cbfe402aa9e2f55109..da2caf6d2cc3c28eedebca6a008246eb38c9bb40 100644 (file)
@@ -34,6 +34,7 @@ from ..util.sqla_compat import _remove_column_from_collection
 from ..util.sqla_compat import _resolve_for_variant
 from ..util.sqla_compat import _select
 from ..util.sqla_compat import constraint_name_defined
+from ..util.sqla_compat import constraint_name_string
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -269,7 +270,7 @@ class ApplyBatchImpl:
                 # because
                 # we have no way to determine _is_type_bound() for these.
                 pass
-            elif constraint_name_defined(const.name):
+            elif constraint_name_string(const.name):
                 self.named_constraints[const.name] = const
             else:
                 self.unnamed_constraints.append(const)
@@ -669,7 +670,10 @@ class ApplyBatchImpl:
             if self.table.primary_key in self.unnamed_constraints:
                 self.unnamed_constraints.remove(self.table.primary_key)
 
-        self.named_constraints[const.name] = const
+        if constraint_name_string(const.name):
+            self.named_constraints[const.name] = const
+        else:
+            self.unnamed_constraints.append(const)
 
     def drop_constraint(self, const: Constraint) -> None:
         if not const.name:
@@ -681,9 +685,11 @@ class ApplyBatchImpl:
                 for col_const in list(self.columns[col.name].constraints):
                     if col_const.name == const.name:
                         self.columns[col.name].constraints.remove(col_const)
-            else:
-                assert constraint_name_defined(const.name)
+            elif constraint_name_string(const.name):
                 const = self.named_constraints.pop(const.name)
+            elif const in self.unnamed_constraints:
+                self.unnamed_constraints.remove(const)
+
         except KeyError:
             if _is_type_bound(const):
                 # type-bound constraints are only included in the new
index 8e1144bb07a6224c829d4fc928b553f4b8de76f6..48384f96654f6f809f26f43efc99c529a05ec726 100644 (file)
@@ -130,7 +130,7 @@ class DropConstraintOp(MigrateOperation):
 
     def __init__(
         self,
-        constraint_name: Optional[str],
+        constraint_name: Optional[sqla_compat._ConstraintNameDefined],
         table_name: str,
         type_: Optional[str] = None,
         schema: Optional[str] = None,
@@ -255,7 +255,7 @@ class CreatePrimaryKeyOp(AddConstraintOp):
 
     def __init__(
         self,
-        constraint_name: Optional[str],
+        constraint_name: Optional[sqla_compat._ConstraintNameDefined],
         table_name: str,
         columns: Sequence[str],
         schema: Optional[str] = None,
@@ -379,7 +379,7 @@ class CreateUniqueConstraintOp(AddConstraintOp):
 
     def __init__(
         self,
-        constraint_name: Optional[str],
+        constraint_name: Optional[sqla_compat._ConstraintNameDefined],
         table_name: str,
         columns: Sequence[str],
         schema: Optional[str] = None,
@@ -513,7 +513,7 @@ class CreateForeignKeyOp(AddConstraintOp):
 
     def __init__(
         self,
-        constraint_name: Optional[str],
+        constraint_name: Optional[sqla_compat._ConstraintNameDefined],
         source_table: str,
         referent_table: str,
         local_cols: List[str],
@@ -730,7 +730,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
 
     def __init__(
         self,
-        constraint_name: Optional[str],
+        constraint_name: Optional[sqla_compat._ConstraintNameDefined],
         table_name: str,
         condition: Union[str, TextClause, ColumnElement[Any]],
         schema: Optional[str] = None,
index ba09b3bb1634cc738c58a1c8394b78a54e81fd29..0568471a76d520febeefe8b12e2344b4381de139 100644 (file)
@@ -42,7 +42,7 @@ class SchemaObjects:
 
     def primary_key_constraint(
         self,
-        name: Optional[str],
+        name: Optional[sqla_compat._ConstraintNameDefined],
         table_name: str,
         cols: Sequence[str],
         schema: Optional[str] = None,
@@ -51,14 +51,16 @@ class SchemaObjects:
         m = self.metadata()
         columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
         t = sa_schema.Table(table_name, m, *columns, schema=schema)
+        # SQLAlchemy primary key constraint name arg is wrongly typed on
+        # the SQLAlchemy side through 2.0.5 at least
         p = sa_schema.PrimaryKeyConstraint(
-            *[t.c[n] for n in cols], name=name, **dialect_kw
+            *[t.c[n] for n in cols], name=name, **dialect_kw  # type: ignore
         )
         return p
 
     def foreign_key_constraint(
         self,
-        name: Optional[str],
+        name: Optional[sqla_compat._ConstraintNameDefined],
         source: str,
         referent: str,
         local_cols: List[str],
@@ -115,7 +117,7 @@ class SchemaObjects:
 
     def unique_constraint(
         self,
-        name: Optional[str],
+        name: Optional[sqla_compat._ConstraintNameDefined],
         source: str,
         local_cols: Sequence[str],
         schema: Optional[str] = None,
@@ -136,7 +138,7 @@ class SchemaObjects:
 
     def check_constraint(
         self,
-        name: Optional[str],
+        name: Optional[sqla_compat._ConstraintNameDefined],
         source: str,
         condition: Union[str, TextClause, ColumnElement[Any]],
         schema: Optional[str] = None,
@@ -154,7 +156,7 @@ class SchemaObjects:
 
     def generic_constraint(
         self,
-        name: Optional[str],
+        name: Optional[sqla_compat._ConstraintNameDefined],
         table_name: str,
         type_: Optional[str],
         schema: Optional[str] = None,
index 738bbcbe1f489c3b2d440c989befd663aee955e5..2cc070b66531217d937179cf7080b0df2d2b8b3c 100644 (file)
@@ -27,6 +27,7 @@ from sqlalchemy.sql.elements import ColumnClause
 from sqlalchemy.sql.elements import quoted_name
 from sqlalchemy.sql.elements import TextClause
 from sqlalchemy.sql.elements import UnaryExpression
+from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME
 from sqlalchemy.sql.visitors import traverse
 from typing_extensions import TypeGuard
 
@@ -111,15 +112,28 @@ if sqla_2:
 else:
     from sqlalchemy.util import symbol as _NoneName  # type: ignore[assignment]
 
+
 _ConstraintName = Union[None, str, _NoneName]
 
+_ConstraintNameDefined = Union[str, _NoneName]
+
+
+def constraint_name_defined(
+    name: _ConstraintName,
+) -> TypeGuard[_ConstraintNameDefined]:
+    return name is _NONE_NAME or isinstance(name, (str, _NoneName))
 
-def constraint_name_defined(name: _ConstraintName) -> TypeGuard[str]:
+
+def constraint_name_string(
+    name: _ConstraintName,
+) -> TypeGuard[str]:
     return isinstance(name, str)
 
 
-def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
-    return name if constraint_name_defined(name) else None
+def constraint_name_or_none(
+    name: _ConstraintName,
+) -> Optional[str]:
+    return name if constraint_name_string(name) else None
 
 
 AUTOINCREMENT_DEFAULT = "auto"
diff --git a/docs/build/unreleased/1195.rst b/docs/build/unreleased/1195.rst
new file mode 100644 (file)
index 0000000..11542f1
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, batch, regression
+    :tickets: 1195
+
+    Fixed regression for 1.10.0 where :class:`.Constraint` objects were
+    suddenly required to have non-None name fields when using batch mode, which
+    was not previously a requirement.
index e0289aa40cac7bb4774f5194d365ef80349f1a92..3b678953e186b0a2f033405aabd7cfc027d9edf0 100644 (file)
@@ -49,6 +49,7 @@ from alembic.testing.fixtures import capture_context_buffer
 from alembic.testing.fixtures import op_fixture
 from alembic.util import CommandError
 from alembic.util import exc as alembic_exc
+from alembic.util.sqla_compat import _NONE_NAME
 from alembic.util.sqla_compat import _safe_commit_connection_transaction
 from alembic.util.sqla_compat import _select
 from alembic.util.sqla_compat import has_computed
@@ -819,6 +820,18 @@ class BatchApplyTest(TestBase):
             ddl_not_contains="CONSTRAINT uq1 UNIQUE",
         )
 
+    def test_add_ck_unnamed(self):
+        """test for #1195"""
+        impl = self._simple_fixture()
+        ck = self.op.schema_obj.check_constraint(_NONE_NAME, "tname", "y > 5")
+
+        impl.add_constraint(ck)
+        self._assert_impl(
+            impl,
+            colnames=["id", "x", "y"],
+            ddl_contains="CHECK (y > 5)",
+        )
+
     def test_add_ck(self):
         impl = self._simple_fixture()
         ck = self.op.schema_obj.check_constraint("ck1", "tname", "y > 5")
@@ -1444,6 +1457,19 @@ class BatchRoundTripTest(TestBase):
             t = Table("hasbool", self.metadata, Column("x", Integer))
             t.create(self.conn)
 
+    def test_add_constraint_type(self):
+        """test for #1195."""
+
+        with self.op.batch_alter_table("foo") as batch_op:
+            batch_op.add_column(Column("q", Boolean(create_constraint=True)))
+        insp = inspect(self.conn)
+
+        assert {
+            c["type"]._type_affinity
+            for c in insp.get_columns("foo")
+            if c["name"] == "q"
+        }.intersection([Boolean, Integer])
+
     def test_change_type_boolean_to_int(self):
         self._boolean_fixture()
         with self.op.batch_alter_table("hasbool") as batch_op: