From: Mike Bayer Date: Sun, 21 Aug 2022 15:58:20 +0000 (-0400) Subject: Column._copy() duplicates "user defined" nullable state exactly X-Git-Tag: rel_2_0_0b1~99 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7842678484b9d00a64fab29cdc9e252754ac19ae;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Column._copy() duplicates "user defined" nullable state exactly To accommodate how mapped_column() works, after many attempts to get this working it became clear that _copy() should just transfer "nullable" state exactly as it was, including the state where .nullable was set but user_defined_nullable remains at not user set. additionally, added a similar step to _merge() that was needed to preserve the nullability behavior when Identity is present. server / client default objects are not copied within column._copy() and this should be fixed. Fixes: #8410 Change-Id: Ib09df52b71f3e58e67e9f19b893d40a6cc4eec5c --- diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 4ed5b9e6b1..3320214a27 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2241,9 +2241,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): if isinstance(type_, SchemaEventTarget): type_ = type_.copy(**kw) - if self._user_defined_nullable is not NULL_UNSPECIFIED: - column_kwargs["nullable"] = self._user_defined_nullable - # TODO: DefaultGenerator is not copied here! it's just used again # with _set_parent() pointing to the old column. see the new # use of _copy() in the new _merge() method @@ -2267,6 +2264,12 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): *args, **column_kwargs, ) + + # copy the state of "nullable" exactly, to accommodate for + # ORM flipping the .nullable flag directly + c.nullable = self.nullable + c._user_defined_nullable = self._user_defined_nullable + return self._schema_item_copy(c) def _merge(self, other: Column[Any]) -> None: @@ -2300,6 +2303,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): and other._user_defined_nullable is NULL_UNSPECIFIED ): other.nullable = self.nullable + other._user_defined_nullable = self._user_defined_nullable if self.default is not None and other.default is None: new_default = self.default._copy() diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index f31f1f4e29..98736cf025 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -19,6 +19,7 @@ from sqlalchemy import DateTime from sqlalchemy import exc as sa_exc from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import Identity from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import Numeric @@ -437,6 +438,14 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_false(User.__table__.c.lnl_rnnl.nullable) is_true(User.__table__.c.lnl_rnl.nullable) + # test #8410 + is_false(User.__table__.c.lnnl_rndf._copy().nullable) + is_false(User.__table__.c.lnnl_rnnl._copy().nullable) + is_true(User.__table__.c.lnnl_rnl._copy().nullable) + is_true(User.__table__.c.lnl_rndf._copy().nullable) + is_false(User.__table__.c.lnl_rnnl._copy().nullable) + is_true(User.__table__.c.lnl_rnl._copy().nullable) + def test_fwd_refs(self, decl_base: Type[DeclarativeBase]): class MyClass(decl_base): __tablename__ = "my_table" @@ -652,6 +661,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ("onupdate", func.foo()), ("server_onupdate", func.foo()), ("server_default", func.foo()), + ("server_default", Identity()), ("nullable", True), ("nullable", False), ("type", BigInteger()), @@ -690,24 +700,89 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: data: Mapped[element_ref] + data_col = Element.__table__.c.data if paramname in ( "default", "onupdate", "server_default", "server_onupdate", ): - default = getattr(Element.__table__.c.data, paramname) - is_(default.arg, value) - is_(default.column, Element.__table__.c.data) + default = getattr(data_col, paramname) + if default.is_server_default and default.has_argument: + is_(default.arg, value) + is_(default.column, data_col) elif paramname == "type": - assert type(Element.__table__.c.data.type) is type(value) + assert type(data_col.type) is type(value) + else: + is_(getattr(data_col, paramname), value) + + # test _copy() for #8410 + is_(getattr(data_col._copy(), paramname), value) + + sd = data_col.server_default + if sd is not None and isinstance(sd, Identity): + if paramname == "nullable" and value: + is_(data_col.nullable, True) + else: + is_(data_col.nullable, False) + elif paramname != "nullable": + is_(data_col.nullable, optional) + else: + is_(data_col.nullable, value) + + @testing.combinations(True, False, argnames="specify_identity") + @testing.combinations(True, False, None, argnames="specify_nullable") + @testing.combinations(True, False, argnames="optional") + @testing.combinations(True, False, argnames="include_existing_col") + def test_combine_args_from_pep593_identity_nullable( + self, + decl_base: Type[DeclarativeBase], + specify_identity, + specify_nullable, + optional, + include_existing_col, + ): + intpk = Annotated[int, mapped_column(primary_key=True)] + + if specify_identity: + args = [Identity()] else: - is_(getattr(Element.__table__.c.data, paramname), value) + args = [] - if paramname != "nullable": - is_(Element.__table__.c.data.nullable, optional) + if specify_nullable is not None: + params = {"nullable": specify_nullable} else: - is_(Element.__table__.c.data.nullable, value) + params = {} + + element_ref = Annotated[int, mapped_column(*args, **params)] + if optional: + element_ref = Optional[element_ref] + + class Element(decl_base): + __tablename__ = "element" + + id: Mapped[intpk] + + if include_existing_col: + data: Mapped[element_ref] = mapped_column() + else: + data: Mapped[element_ref] + + # test identity + _copy() for #8410 + for col in ( + Element.__table__.c.data, + Element.__table__.c.data._copy(), + ): + if specify_nullable is True: + is_(col.nullable, True) + elif specify_identity: + is_(col.nullable, False) + elif specify_nullable is False: + is_(col.nullable, False) + elif not optional: + is_(col.nullable, False) + else: + is_(col.nullable, True) @testing.combinations( ("default", lambda ctx: 10, lambda ctx: 15), diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index b7913e6068..7131476be2 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -4223,31 +4223,83 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): source._merge(target) - if isinstance(value, (Computed, Identity)): - default = target.server_default - assert isinstance(default, type(value)) - elif isinstance(value, Sequence): - default = target.default - assert isinstance(default, type(value)) - - elif paramname in ( - "default", - "onupdate", - "server_default", - "server_onupdate", + target_copy = target._copy() + for col in ( + target, + target_copy, ): - default = getattr(target, paramname) - is_(default.arg, value) - is_(default.column, target) - elif paramname == "type": - assert type(target.type) is type(value) + if isinstance(value, (Computed, Identity)): + default = col.server_default + assert isinstance(default, type(value)) + is_(default.column, col) + elif isinstance(value, Sequence): + default = col.default + + # TODO: sequence mutated in place + is_(default.column, target_copy) + + assert isinstance(default, type(value)) + + elif paramname in ( + "default", + "onupdate", + "server_default", + "server_onupdate", + ): + default = getattr(col, paramname) + is_(default.arg, value) - if isinstance(target.type, Enum): - target.name = "data" - t = Table("t", MetaData(), target) - assert CheckConstraint in [type(c) for c in t.constraints] + # TODO: _copy() seems to note that it isn't copying + # server defaults or defaults outside of Computed, Identity, + # so here it's getting mutated in place. this is a bug + is_(default.column, target_copy) + + elif paramname == "type": + assert type(col.type) is type(value) + + if isinstance(col.type, Enum): + col.name = "data" + t = Table("t", MetaData(), col) + assert CheckConstraint in [type(c) for c in t.constraints] + else: + is_(getattr(col, paramname), value) + + @testing.combinations(True, False, argnames="specify_identity") + @testing.combinations(True, False, None, argnames="specify_nullable") + def test_merge_column_identity( + self, + specify_identity, + specify_nullable, + ): + if specify_identity: + args = [Identity()] else: - is_(getattr(target, paramname), value) + args = [] + + if specify_nullable is not None: + params = {"nullable": specify_nullable} + else: + params = {} + + source = Column(*args, **params) + + target = Column() + + source._merge(target) + + # test identity + _copy() for #8410 + for col in ( + target, + target._copy(), + ): + if specify_nullable is True: + is_(col.nullable, True) + elif specify_identity: + is_(col.nullable, False) + elif specify_nullable is False: + is_(col.nullable, False) + else: + is_(col.nullable, True) @testing.combinations( ("default", lambda ctx: 10, lambda ctx: 15),