]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
expand entity_isa to include simple "isa" in poly case
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Jun 2024 18:50:25 +0000 (14:50 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2024 12:09:37 +0000 (08:09 -0400)
Fixed issue where the :func:`_orm.selectinload` and
:func:`_orm.subqueryload` loader options would fail to take effect when
made against an inherited subclass that itself included a subclass-specific
:paramref:`_orm.Mapper.with_polymorphic` setting.

Fixes: #11446
Change-Id: I2df3ebedbe4aa9da58af99d7729e5f3052ad6abc
(cherry picked from commit 63a903b918343ca312aaded93b7e9af7a88fa3a8)

doc/build/changelog/unreleased_20/11446.rst [new file with mode: 0644]
lib/sqlalchemy/orm/util.py
test/orm/inheritance/test_assorted_poly.py

diff --git a/doc/build/changelog/unreleased_20/11446.rst b/doc/build/changelog/unreleased_20/11446.rst
new file mode 100644 (file)
index 0000000..747230b
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11446
+
+    Fixed issue where the :func:`_orm.selectinload` and
+    :func:`_orm.subqueryload` loader options would fail to take effect when
+    made against an inherited subclass that itself included a subclass-specific
+    :paramref:`_orm.Mapper.with_polymorphic` setting.
index ad2b69ce3135bf0e7fe405e944e3dbe053ccb721..0ab3536dddc34a44da4c7afc1fb124ca240cb8aa 100644 (file)
@@ -2154,7 +2154,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool:
             mapper
         )
     elif given.with_polymorphic_mappers:
-        return mapper in given.with_polymorphic_mappers
+        return mapper in given.with_polymorphic_mappers or given.isa(mapper)
     else:
         return given.isa(mapper)
 
index 49d90f6c4372c2ce14813866aebb542d76cde013..ab06dbaea3d4dc07ce1ce8c1cdafa8faf30d2371 100644 (file)
@@ -32,6 +32,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.orm.interfaces import MANYTOONE
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -3148,3 +3149,177 @@ class MultiOfTypeContainsEagerTest(fixtures.DeclarativeMappedTest):
             head,
             UnitHead(managers=expected_managers),
         )
+
+
+@testing.combinations(
+    (2,),
+    (3,),
+    id_="s",
+    argnames="num_levels",
+)
+@testing.combinations(
+    ("with_poly_star",),
+    ("inline",),
+    ("selectin",),
+    ("none",),
+    id_="s",
+    argnames="wpoly_type",
+)
+class SubclassWithPolyEagerLoadTest(fixtures.DeclarativeMappedTest):
+    """test #11446"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class B(Base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey("a.id"))
+
+        class A(Base):
+            __tablename__ = "a"
+
+            id = Column(Integer, primary_key=True)
+            type = Column(String(10))
+            bs = relationship("B")
+
+            if cls.wpoly_type == "selectin":
+                __mapper_args__ = {"polymorphic_on": "type"}
+            elif cls.wpoly_type == "inline":
+                __mapper_args__ = {"polymorphic_on": "type"}
+            elif cls.wpoly_type == "with_poly_star":
+                __mapper_args__ = {
+                    "with_polymorphic": "*",
+                    "polymorphic_on": "type",
+                }
+            else:
+                __mapper_args__ = {"polymorphic_on": "type"}
+
+        class ASub(A):
+            __tablename__ = "asub"
+            id = Column(ForeignKey("a.id"), primary_key=True)
+            sub_data = Column(String(10))
+
+            if cls.wpoly_type == "selectin":
+                __mapper_args__ = {
+                    "polymorphic_load": "selectin",
+                    "polymorphic_identity": "asub",
+                }
+            elif cls.wpoly_type == "inline":
+                __mapper_args__ = {
+                    "polymorphic_load": "inline",
+                    "polymorphic_identity": "asub",
+                }
+            elif cls.wpoly_type == "with_poly_star":
+                __mapper_args__ = {
+                    "with_polymorphic": "*",
+                    "polymorphic_identity": "asub",
+                }
+            else:
+                __mapper_args__ = {"polymorphic_identity": "asub"}
+
+        if cls.num_levels == 3:
+
+            class ASubSub(ASub):
+                __tablename__ = "asubsub"
+                id = Column(ForeignKey("asub.id"), primary_key=True)
+                sub_sub_data = Column(String(10))
+
+                if cls.wpoly_type == "selectin":
+                    __mapper_args__ = {
+                        "polymorphic_load": "selectin",
+                        "polymorphic_identity": "asubsub",
+                    }
+                elif cls.wpoly_type == "inline":
+                    __mapper_args__ = {
+                        "polymorphic_load": "inline",
+                        "polymorphic_identity": "asubsub",
+                    }
+                elif cls.wpoly_type == "with_poly_star":
+                    __mapper_args__ = {
+                        "with_polymorphic": "*",
+                        "polymorphic_identity": "asubsub",
+                    }
+                else:
+                    __mapper_args__ = {"polymorphic_identity": "asubsub"}
+
+    @classmethod
+    def insert_data(cls, connection):
+        if cls.num_levels == 3:
+            ASubSub, B = cls.classes("ASubSub", "B")
+
+            with Session(connection) as sess:
+                sess.add_all(
+                    [
+                        ASubSub(
+                            sub_data="sub",
+                            sub_sub_data="subsub",
+                            bs=[B(), B(), B()],
+                        )
+                        for i in range(3)
+                    ]
+                )
+
+                sess.commit()
+        else:
+            ASub, B = cls.classes("ASub", "B")
+
+            with Session(connection) as sess:
+                sess.add_all(
+                    [
+                        ASub(sub_data="sub", bs=[B(), B(), B()])
+                        for i in range(3)
+                    ]
+                )
+                sess.commit()
+
+    @testing.variation("query_from", ["aliased_class", "class_", "parent"])
+    @testing.combinations(selectinload, subqueryload, argnames="loader_fn")
+    def test_thing(self, query_from, loader_fn):
+
+        A = self.classes.A
+
+        if self.num_levels == 2:
+            target = self.classes.ASub
+        elif self.num_levels == 3:
+            target = self.classes.ASubSub
+
+        if query_from.aliased_class:
+            asub_alias = aliased(target)
+            query = select(asub_alias).options(loader_fn(asub_alias.bs))
+        elif query_from.class_:
+            query = select(target).options(loader_fn(A.bs))
+        elif query_from.parent:
+            query = select(A).options(loader_fn(A.bs))
+
+        s = fixture_session()
+
+        # NOTE: this is likely a different bug - setting
+        # polymorphic_load to "inline" and loading from the parent does not
+        # descend to the ASubSub subclass; however "selectin" setting
+        # **does**.   this is inconsistent
+        if (
+            query_from.parent
+            and self.wpoly_type == "inline"
+            and self.num_levels == 3
+        ):
+            # this should ideally be "2"
+            expected_q = 5
+
+        elif query_from.parent and self.wpoly_type == "none":
+            expected_q = 5
+        elif query_from.parent and self.wpoly_type == "selectin":
+            expected_q = 3
+        else:
+            expected_q = 2
+
+        with self.assert_statement_count(testing.db, expected_q):
+            for obj in s.scalars(query):
+                # test both that with_polymorphic loaded
+                eq_(obj.sub_data, "sub")
+                if self.num_levels == 3:
+                    eq_(obj.sub_sub_data, "subsub")
+
+                # as well as the collection eagerly loaded
+                assert obj.bs