From 64a97ab900e5876b8348d8d658bcbc90c31da9c1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 6 Nov 2023 17:46:14 -0500 Subject: [PATCH] always derive type from element in annotated case Fixed issue where use of :func:`_orm.foreign` annotation on a non-initialized :func:`_orm.mapped_column` construct would produce an expression without a type, which was then not updated at initialization time of the actual column, leading to issues such as relationships not determining ``use_get`` appropriately. Fixes: #10597 Change-Id: I8339ba715ec6bd1f50888f8a424c3ac156e2364f (cherry picked from commit 432eb350a4b81ba557f14d49ebd37cf5899d5423) --- doc/build/changelog/unreleased_20/10597.rst | 10 +++++ lib/sqlalchemy/sql/elements.py | 29 ++++++++++++++ lib/sqlalchemy/sql/schema.py | 2 + .../test_tm_future_annotations_sync.py | 39 +++++++++++++++++++ test/orm/declarative/test_typed_mapping.py | 39 +++++++++++++++++++ test/sql/test_selectable.py | 32 +++++++++++++++ 6 files changed, 151 insertions(+) create mode 100644 doc/build/changelog/unreleased_20/10597.rst diff --git a/doc/build/changelog/unreleased_20/10597.rst b/doc/build/changelog/unreleased_20/10597.rst new file mode 100644 index 0000000000..9764518829 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10597.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm + :tickets: 10597 + + Fixed issue where use of :func:`_orm.foreign` annotation on a + non-initialized :func:`_orm.mapped_column` construct would produce an + expression without a type, which was then not updated at initialization + time of the actual column, leading to issues such as relationships not + determining ``use_get`` appropriately. + diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 90ee100aae..48dfd25829 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -5223,6 +5223,20 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False): return c +class _memoized_property_but_not_nulltype( + util.memoized_property["TypeEngine[_T]"] +): + """memoized property, but dont memoize NullType""" + + def __get__(self, obj, cls): + if obj is None: + return self + result = self.fget(obj) + if not result._isnull: + obj.__dict__[self.__name__] = result + return result + + class AnnotatedColumnElement(Annotated): _Annotated__element: ColumnElement[Any] @@ -5234,6 +5248,7 @@ class AnnotatedColumnElement(Annotated): "_tq_key_label", "_tq_label", "_non_anon_label", + "type", ): self.__dict__.pop(attr, None) for attr in ("name", "key", "table"): @@ -5250,6 +5265,20 @@ class AnnotatedColumnElement(Annotated): """pull 'name' from parent, if not present""" return self._Annotated__element.name + @_memoized_property_but_not_nulltype + def type(self): + """pull 'type' from parent and don't cache if null. + + type is routinely changed on existing columns within the + mapped_column() initialization process, and "type" is also consulted + during the creation of SQL expressions. Therefore it can change after + it was already retrieved. At the same time we don't want annotated + objects having overhead when expressions are produced, so continue + to memoize, but only when we have a non-null type. + + """ + return self._Annotated__element.type + @util.memoized_property def table(self): """pull 'table' from parent, if not present""" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index c464d7eb0e..d4e3f4cff5 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2204,6 +2204,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): identity: Optional[Identity] def _set_type(self, type_: TypeEngine[Any]) -> None: + assert self.type._isnull or type_ is self.type + self.type = type_ if isinstance(self.type, SchemaEventTarget): self.type._set_parent_with_dispatch(self) diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index ec5f5e8209..e61900418e 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -62,10 +62,12 @@ from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import DynamicMapped +from sqlalchemy.orm import foreign from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import relationship +from sqlalchemy.orm import remote from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped @@ -177,6 +179,43 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(MyClass.__table__.c.data.type, typ) is_true(MyClass.__table__.c.id.primary_key) + @testing.variation("style", ["none", "lambda_", "string", "direct"]) + def test_foreign_annotation_propagates_correctly(self, decl_base, style): + """test #10597""" + + class Parent(decl_base): + __tablename__ = "parent" + id: Mapped[int] = mapped_column(primary_key=True) + + class Child(decl_base): + __tablename__ = "child" + + name: Mapped[str] = mapped_column(primary_key=True) + + if style.none: + parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id")) + else: + parent_id: Mapped[int] = mapped_column() + + if style.lambda_: + parent: Mapped[Parent] = relationship( + primaryjoin=lambda: remote(Parent.id) + == foreign(Child.parent_id), + ) + elif style.string: + parent: Mapped[Parent] = relationship( + primaryjoin="remote(Parent.id) == " + "foreign(Child.parent_id)", + ) + elif style.direct: + parent: Mapped[Parent] = relationship( + primaryjoin=remote(Parent.id) == foreign(parent_id), + ) + elif style.none: + parent: Mapped[Parent] = relationship() + + assert Child.__mapper__.attrs.parent.strategy.use_get + @testing.combinations( (BIGINT(),), (BIGINT,), diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 6b8becf9c0..8da83ccb9d 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -53,10 +53,12 @@ from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import DynamicMapped +from sqlalchemy.orm import foreign from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import relationship +from sqlalchemy.orm import remote from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped @@ -168,6 +170,43 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(MyClass.__table__.c.data.type, typ) is_true(MyClass.__table__.c.id.primary_key) + @testing.variation("style", ["none", "lambda_", "string", "direct"]) + def test_foreign_annotation_propagates_correctly(self, decl_base, style): + """test #10597""" + + class Parent(decl_base): + __tablename__ = "parent" + id: Mapped[int] = mapped_column(primary_key=True) + + class Child(decl_base): + __tablename__ = "child" + + name: Mapped[str] = mapped_column(primary_key=True) + + if style.none: + parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id")) + else: + parent_id: Mapped[int] = mapped_column() + + if style.lambda_: + parent: Mapped[Parent] = relationship( + primaryjoin=lambda: remote(Parent.id) + == foreign(Child.parent_id), + ) + elif style.string: + parent: Mapped[Parent] = relationship( + primaryjoin="remote(Parent.id) == " + "foreign(Child.parent_id)", + ) + elif style.direct: + parent: Mapped[Parent] = relationship( + primaryjoin=remote(Parent.id) == foreign(parent_id), + ) + elif style.none: + parent: Mapped[Parent] = relationship() + + assert Child.__mapper__.attrs.parent.strategy.use_get + @testing.combinations( (BIGINT(),), (BIGINT,), diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index a146a94c60..d3b7b47841 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -41,6 +41,7 @@ from sqlalchemy.sql import elements from sqlalchemy.sql import LABEL_STYLE_DISAMBIGUATE_ONLY from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.sql import operators +from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import visitors @@ -3023,6 +3024,37 @@ class AnnotationsTest(fixtures.TestBase): eq_(whereclause.left._annotations, {"foo": "bar"}) eq_(whereclause.right._annotations, {"foo": "bar"}) + @testing.variation("use_col_ahead_of_time", [True, False]) + def test_set_type_on_column(self, use_col_ahead_of_time): + """test related to #10597""" + + col = Column() + + col_anno = col._annotate({"foo": "bar"}) + + if use_col_ahead_of_time: + expr = col_anno == bindparam("foo") + + # this could only be fixed if we put some kind of a container + # that receives the type directly rather than using NullType; + # like a PendingType or something + + is_(expr.right.type._type_affinity, sqltypes.NullType) + + assert "type" not in col_anno.__dict__ + + col.name = "name" + col._set_type(Integer()) + + eq_(col_anno.name, "name") + is_(col_anno.type._type_affinity, Integer) + + expr = col_anno == bindparam("foo") + + is_(expr.right.type._type_affinity, Integer) + + assert "type" in col_anno.__dict__ + @testing.combinations(True, False, None) def test_setup_inherit_cache(self, inherit_cache_value): if inherit_cache_value is None: -- 2.47.2