]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
always derive type from element in annotated case
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Nov 2023 22:46:14 +0000 (17:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Nov 2023 00:37:49 +0000 (19:37 -0500)
Fixed issue where use of :func:`_orm.foreign` annotation on a
non-initialized :func:`_orm.mapped_column` construct would produce an
expression without a type, which was then not updated at initialization
time of the actual column, leading to issues such as relationships not
determining ``use_get`` appropriately.

Fixes: #10597
Change-Id: I8339ba715ec6bd1f50888f8a424c3ac156e2364f

doc/build/changelog/unreleased_20/10597.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_20/10597.rst b/doc/build/changelog/unreleased_20/10597.rst
new file mode 100644 (file)
index 0000000..9764518
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10597
+
+    Fixed issue where use of :func:`_orm.foreign` annotation on a
+    non-initialized :func:`_orm.mapped_column` construct would produce an
+    expression without a type, which was then not updated at initialization
+    time of the actual column, leading to issues such as relationships not
+    determining ``use_get`` appropriately.
+
index 90ee100aae0f13eb0c929fb7d23f977cfe5187c0..48dfd25829a30213c8e6b267491f1246693052ca 100644 (file)
@@ -5223,6 +5223,20 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False):
     return c
 
 
+class _memoized_property_but_not_nulltype(
+    util.memoized_property["TypeEngine[_T]"]
+):
+    """memoized property, but dont memoize NullType"""
+
+    def __get__(self, obj, cls):
+        if obj is None:
+            return self
+        result = self.fget(obj)
+        if not result._isnull:
+            obj.__dict__[self.__name__] = result
+        return result
+
+
 class AnnotatedColumnElement(Annotated):
     _Annotated__element: ColumnElement[Any]
 
@@ -5234,6 +5248,7 @@ class AnnotatedColumnElement(Annotated):
             "_tq_key_label",
             "_tq_label",
             "_non_anon_label",
+            "type",
         ):
             self.__dict__.pop(attr, None)
         for attr in ("name", "key", "table"):
@@ -5250,6 +5265,20 @@ class AnnotatedColumnElement(Annotated):
         """pull 'name' from parent, if not present"""
         return self._Annotated__element.name
 
+    @_memoized_property_but_not_nulltype
+    def type(self):
+        """pull 'type' from parent and don't cache if null.
+
+        type is routinely changed on existing columns within the
+        mapped_column() initialization process, and "type" is also consulted
+        during the creation of SQL expressions.  Therefore it can change after
+        it was already retrieved.  At the same time we don't want annotated
+        objects having overhead when expressions are produced, so continue
+        to memoize, but only when we have a non-null type.
+
+        """
+        return self._Annotated__element.type
+
     @util.memoized_property
     def table(self):
         """pull 'table' from parent, if not present"""
index c464d7eb0ea91d621f226ef131cf812d904ce526..d4e3f4cff515e52c3567781f7ea46ea6fbcdca49 100644 (file)
@@ -2204,6 +2204,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
     identity: Optional[Identity]
 
     def _set_type(self, type_: TypeEngine[Any]) -> None:
+        assert self.type._isnull or type_ is self.type
+
         self.type = type_
         if isinstance(self.type, SchemaEventTarget):
             self.type._set_parent_with_dispatch(self)
index ec5f5e8209775a280d9d169e35a4c778dbcb8729..e61900418e2e15b93ae800d25fe4913ba18ebc60 100644 (file)
@@ -62,10 +62,12 @@ from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import DynamicMapped
+from sqlalchemy.orm import foreign
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedAsDataclass
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import remote
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import undefer
 from sqlalchemy.orm import WriteOnlyMapped
@@ -177,6 +179,43 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_(MyClass.__table__.c.data.type, typ)
         is_true(MyClass.__table__.c.id.primary_key)
 
+    @testing.variation("style", ["none", "lambda_", "string", "direct"])
+    def test_foreign_annotation_propagates_correctly(self, decl_base, style):
+        """test #10597"""
+
+        class Parent(decl_base):
+            __tablename__ = "parent"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        class Child(decl_base):
+            __tablename__ = "child"
+
+            name: Mapped[str] = mapped_column(primary_key=True)
+
+            if style.none:
+                parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
+            else:
+                parent_id: Mapped[int] = mapped_column()
+
+            if style.lambda_:
+                parent: Mapped[Parent] = relationship(
+                    primaryjoin=lambda: remote(Parent.id)
+                    == foreign(Child.parent_id),
+                )
+            elif style.string:
+                parent: Mapped[Parent] = relationship(
+                    primaryjoin="remote(Parent.id) == "
+                    "foreign(Child.parent_id)",
+                )
+            elif style.direct:
+                parent: Mapped[Parent] = relationship(
+                    primaryjoin=remote(Parent.id) == foreign(parent_id),
+                )
+            elif style.none:
+                parent: Mapped[Parent] = relationship()
+
+        assert Child.__mapper__.attrs.parent.strategy.use_get
+
     @testing.combinations(
         (BIGINT(),),
         (BIGINT,),
index 6b8becf9c02be7e68aa09972b6c52142732d8232..8da83ccb9d6cf25a4558a1d68315255ed7451a6a 100644 (file)
@@ -53,10 +53,12 @@ from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import DynamicMapped
+from sqlalchemy.orm import foreign
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedAsDataclass
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import remote
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import undefer
 from sqlalchemy.orm import WriteOnlyMapped
@@ -168,6 +170,43 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_(MyClass.__table__.c.data.type, typ)
         is_true(MyClass.__table__.c.id.primary_key)
 
+    @testing.variation("style", ["none", "lambda_", "string", "direct"])
+    def test_foreign_annotation_propagates_correctly(self, decl_base, style):
+        """test #10597"""
+
+        class Parent(decl_base):
+            __tablename__ = "parent"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        class Child(decl_base):
+            __tablename__ = "child"
+
+            name: Mapped[str] = mapped_column(primary_key=True)
+
+            if style.none:
+                parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
+            else:
+                parent_id: Mapped[int] = mapped_column()
+
+            if style.lambda_:
+                parent: Mapped[Parent] = relationship(
+                    primaryjoin=lambda: remote(Parent.id)
+                    == foreign(Child.parent_id),
+                )
+            elif style.string:
+                parent: Mapped[Parent] = relationship(
+                    primaryjoin="remote(Parent.id) == "
+                    "foreign(Child.parent_id)",
+                )
+            elif style.direct:
+                parent: Mapped[Parent] = relationship(
+                    primaryjoin=remote(Parent.id) == foreign(parent_id),
+                )
+            elif style.none:
+                parent: Mapped[Parent] = relationship()
+
+        assert Child.__mapper__.attrs.parent.strategy.use_get
+
     @testing.combinations(
         (BIGINT(),),
         (BIGINT,),
index a146a94c6003eb7fbc63af0cbe1950b85120d673..d3b7b47841fee70f9762b5d5256ce0f41ecca7f6 100644 (file)
@@ -41,6 +41,7 @@ from sqlalchemy.sql import elements
 from sqlalchemy.sql import LABEL_STYLE_DISAMBIGUATE_ONLY
 from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy.sql import operators
+from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors
@@ -3023,6 +3024,37 @@ class AnnotationsTest(fixtures.TestBase):
         eq_(whereclause.left._annotations, {"foo": "bar"})
         eq_(whereclause.right._annotations, {"foo": "bar"})
 
+    @testing.variation("use_col_ahead_of_time", [True, False])
+    def test_set_type_on_column(self, use_col_ahead_of_time):
+        """test related to #10597"""
+
+        col = Column()
+
+        col_anno = col._annotate({"foo": "bar"})
+
+        if use_col_ahead_of_time:
+            expr = col_anno == bindparam("foo")
+
+            # this could only be fixed if we put some kind of a container
+            # that receives the type directly rather than using NullType;
+            # like a PendingType or something
+
+            is_(expr.right.type._type_affinity, sqltypes.NullType)
+
+        assert "type" not in col_anno.__dict__
+
+        col.name = "name"
+        col._set_type(Integer())
+
+        eq_(col_anno.name, "name")
+        is_(col_anno.type._type_affinity, Integer)
+
+        expr = col_anno == bindparam("foo")
+
+        is_(expr.right.type._type_affinity, Integer)
+
+        assert "type" in col_anno.__dict__
+
     @testing.combinations(True, False, None)
     def test_setup_inherit_cache(self, inherit_cache_value):
         if inherit_cache_value is None: