]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Column._copy() duplicates "user defined" nullable state exactly
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Aug 2022 15:58:20 +0000 (11:58 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Aug 2022 15:58:20 +0000 (11:58 -0400)
To accommodate how mapped_column() works, after many
attempts to get this working it became clear that _copy()
should just transfer "nullable" state exactly as it was,
including the state where .nullable was set but user_defined_nullable
remains at not user set.

additionally, added a similar step to _merge() that was needed
to preserve the nullability behavior when Identity is present.

server / client default objects are not copied within column._copy()
and this should be fixed.

Fixes: #8410
Change-Id: Ib09df52b71f3e58e67e9f19b893d40a6cc4eec5c

lib/sqlalchemy/sql/schema.py
test/orm/declarative/test_typed_mapping.py
test/sql/test_metadata.py

index 4ed5b9e6b1af52636f238f77de0379dfbd711e8d..3320214a27c9e0c16206832c3dba6050fecfb7c9 100644 (file)
@@ -2241,9 +2241,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
         if isinstance(type_, SchemaEventTarget):
             type_ = type_.copy(**kw)
 
-        if self._user_defined_nullable is not NULL_UNSPECIFIED:
-            column_kwargs["nullable"] = self._user_defined_nullable
-
         # TODO: DefaultGenerator is not copied here!  it's just used again
         # with _set_parent() pointing to the old column.  see the new
         # use of _copy() in the new _merge() method
@@ -2267,6 +2264,12 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
             *args,
             **column_kwargs,
         )
+
+        # copy the state of "nullable" exactly, to accommodate for
+        # ORM flipping the .nullable flag directly
+        c.nullable = self.nullable
+        c._user_defined_nullable = self._user_defined_nullable
+
         return self._schema_item_copy(c)
 
     def _merge(self, other: Column[Any]) -> None:
@@ -2300,6 +2303,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
             and other._user_defined_nullable is NULL_UNSPECIFIED
         ):
             other.nullable = self.nullable
+            other._user_defined_nullable = self._user_defined_nullable
 
         if self.default is not None and other.default is None:
             new_default = self.default._copy()
index f31f1f4e29a1f234396ddf78393410bb5443a679..98736cf025cd838ae5c7b6ddb159af92617c0b29 100644 (file)
@@ -19,6 +19,7 @@ from sqlalchemy import DateTime
 from sqlalchemy import exc as sa_exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
+from sqlalchemy import Identity
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import Numeric
@@ -437,6 +438,14 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_false(User.__table__.c.lnl_rnnl.nullable)
         is_true(User.__table__.c.lnl_rnl.nullable)
 
+        # test #8410
+        is_false(User.__table__.c.lnnl_rndf._copy().nullable)
+        is_false(User.__table__.c.lnnl_rnnl._copy().nullable)
+        is_true(User.__table__.c.lnnl_rnl._copy().nullable)
+        is_true(User.__table__.c.lnl_rndf._copy().nullable)
+        is_false(User.__table__.c.lnl_rnnl._copy().nullable)
+        is_true(User.__table__.c.lnl_rnl._copy().nullable)
+
     def test_fwd_refs(self, decl_base: Type[DeclarativeBase]):
         class MyClass(decl_base):
             __tablename__ = "my_table"
@@ -652,6 +661,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         ("onupdate", func.foo()),
         ("server_onupdate", func.foo()),
         ("server_default", func.foo()),
+        ("server_default", Identity()),
         ("nullable", True),
         ("nullable", False),
         ("type", BigInteger()),
@@ -690,24 +700,89 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             else:
                 data: Mapped[element_ref]
 
+        data_col = Element.__table__.c.data
         if paramname in (
             "default",
             "onupdate",
             "server_default",
             "server_onupdate",
         ):
-            default = getattr(Element.__table__.c.data, paramname)
-            is_(default.arg, value)
-            is_(default.column, Element.__table__.c.data)
+            default = getattr(data_col, paramname)
+            if default.is_server_default and default.has_argument:
+                is_(default.arg, value)
+            is_(default.column, data_col)
         elif paramname == "type":
-            assert type(Element.__table__.c.data.type) is type(value)
+            assert type(data_col.type) is type(value)
+        else:
+            is_(getattr(data_col, paramname), value)
+
+            # test _copy() for #8410
+            is_(getattr(data_col._copy(), paramname), value)
+
+        sd = data_col.server_default
+        if sd is not None and isinstance(sd, Identity):
+            if paramname == "nullable" and value:
+                is_(data_col.nullable, True)
+            else:
+                is_(data_col.nullable, False)
+        elif paramname != "nullable":
+            is_(data_col.nullable, optional)
+        else:
+            is_(data_col.nullable, value)
+
+    @testing.combinations(True, False, argnames="specify_identity")
+    @testing.combinations(True, False, None, argnames="specify_nullable")
+    @testing.combinations(True, False, argnames="optional")
+    @testing.combinations(True, False, argnames="include_existing_col")
+    def test_combine_args_from_pep593_identity_nullable(
+        self,
+        decl_base: Type[DeclarativeBase],
+        specify_identity,
+        specify_nullable,
+        optional,
+        include_existing_col,
+    ):
+        intpk = Annotated[int, mapped_column(primary_key=True)]
+
+        if specify_identity:
+            args = [Identity()]
         else:
-            is_(getattr(Element.__table__.c.data, paramname), value)
+            args = []
 
-        if paramname != "nullable":
-            is_(Element.__table__.c.data.nullable, optional)
+        if specify_nullable is not None:
+            params = {"nullable": specify_nullable}
         else:
-            is_(Element.__table__.c.data.nullable, value)
+            params = {}
+
+        element_ref = Annotated[int, mapped_column(*args, **params)]
+        if optional:
+            element_ref = Optional[element_ref]
+
+        class Element(decl_base):
+            __tablename__ = "element"
+
+            id: Mapped[intpk]
+
+            if include_existing_col:
+                data: Mapped[element_ref] = mapped_column()
+            else:
+                data: Mapped[element_ref]
+
+        # test identity + _copy() for #8410
+        for col in (
+            Element.__table__.c.data,
+            Element.__table__.c.data._copy(),
+        ):
+            if specify_nullable is True:
+                is_(col.nullable, True)
+            elif specify_identity:
+                is_(col.nullable, False)
+            elif specify_nullable is False:
+                is_(col.nullable, False)
+            elif not optional:
+                is_(col.nullable, False)
+            else:
+                is_(col.nullable, True)
 
     @testing.combinations(
         ("default", lambda ctx: 10, lambda ctx: 15),
index b7913e6068da4bd16b6b8e5e6cfd9580a977f0a0..7131476be277890ac83f2440fc23ffa3087c5f35 100644 (file)
@@ -4223,31 +4223,83 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase):
 
         source._merge(target)
 
-        if isinstance(value, (Computed, Identity)):
-            default = target.server_default
-            assert isinstance(default, type(value))
-        elif isinstance(value, Sequence):
-            default = target.default
-            assert isinstance(default, type(value))
-
-        elif paramname in (
-            "default",
-            "onupdate",
-            "server_default",
-            "server_onupdate",
+        target_copy = target._copy()
+        for col in (
+            target,
+            target_copy,
         ):
-            default = getattr(target, paramname)
-            is_(default.arg, value)
-            is_(default.column, target)
-        elif paramname == "type":
-            assert type(target.type) is type(value)
+            if isinstance(value, (Computed, Identity)):
+                default = col.server_default
+                assert isinstance(default, type(value))
+                is_(default.column, col)
+            elif isinstance(value, Sequence):
+                default = col.default
+
+                # TODO: sequence mutated in place
+                is_(default.column, target_copy)
+
+                assert isinstance(default, type(value))
+
+            elif paramname in (
+                "default",
+                "onupdate",
+                "server_default",
+                "server_onupdate",
+            ):
+                default = getattr(col, paramname)
+                is_(default.arg, value)
 
-            if isinstance(target.type, Enum):
-                target.name = "data"
-                t = Table("t", MetaData(), target)
-                assert CheckConstraint in [type(c) for c in t.constraints]
+                # TODO: _copy() seems to note that it isn't copying
+                # server defaults or defaults outside of Computed, Identity,
+                # so here it's getting mutated in place.   this is a bug
+                is_(default.column, target_copy)
+
+            elif paramname == "type":
+                assert type(col.type) is type(value)
+
+                if isinstance(col.type, Enum):
+                    col.name = "data"
+                    t = Table("t", MetaData(), col)
+                    assert CheckConstraint in [type(c) for c in t.constraints]
+            else:
+                is_(getattr(col, paramname), value)
+
+    @testing.combinations(True, False, argnames="specify_identity")
+    @testing.combinations(True, False, None, argnames="specify_nullable")
+    def test_merge_column_identity(
+        self,
+        specify_identity,
+        specify_nullable,
+    ):
+        if specify_identity:
+            args = [Identity()]
         else:
-            is_(getattr(target, paramname), value)
+            args = []
+
+        if specify_nullable is not None:
+            params = {"nullable": specify_nullable}
+        else:
+            params = {}
+
+        source = Column(*args, **params)
+
+        target = Column()
+
+        source._merge(target)
+
+        # test identity + _copy() for #8410
+        for col in (
+            target,
+            target._copy(),
+        ):
+            if specify_nullable is True:
+                is_(col.nullable, True)
+            elif specify_identity:
+                is_(col.nullable, False)
+            elif specify_nullable is False:
+                is_(col.nullable, False)
+            else:
+                is_(col.nullable, True)
 
     @testing.combinations(
         ("default", lambda ctx: 10, lambda ctx: 15),