]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
group together with_polymorphic for single inh criteria
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Mar 2025 16:31:10 +0000 (11:31 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Mar 2025 20:13:14 +0000 (15:13 -0500)
The behavior of :func:`_orm.with_polymorphic` when used with a single
inheritance mapping has been changed such that its behavior should match as
closely as possible to that of an equivalent joined inheritance mapping.
Specifically this means that the base class specified in the
:func:`_orm.with_polymorphic` construct will be the basemost class that is
loaded, as well as all descendant classes of that basemost class.
The change includes that the descendant classes named will no longer be
exclusively indicated in "WHERE polymorphic_col IN" criteria; instead, the
whole hierarchy starting with the given basemost class will be loaded.  If
the query indicates that rows should only be instances of a specific
subclass within the polymorphic hierarchy, an error is raised if an
incompatible superclass is loaded in the result since it cannot be made to
match the requested class; this behavior is the same as what joined
inheritance has done for many years. The change also allows a single result
set to include column-level results from multiple sibling classes at once
which was not previously possible with single table inheritance.

Fixes: #12395
Change-Id: I9307b236a6de8c47e452fb8f982098c54edb811a

doc/build/changelog/unreleased_21/12395.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/test_single.py

diff --git a/doc/build/changelog/unreleased_21/12395.rst b/doc/build/changelog/unreleased_21/12395.rst
new file mode 100644 (file)
index 0000000..8515db0
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12395
+
+    The behavior of :func:`_orm.with_polymorphic` when used with a single
+    inheritance mapping has been changed such that its behavior should match as
+    closely as possible to that of an equivalent joined inheritance mapping.
+    Specifically this means that the base class specified in the
+    :func:`_orm.with_polymorphic` construct will be the basemost class that is
+    loaded, as well as all descendant classes of that basemost class.
+    The change includes that the descendant classes named will no longer be
+    exclusively indicated in "WHERE polymorphic_col IN" criteria; instead, the
+    whole hierarchy starting with the given basemost class will be loaded.  If
+    the query indicates that rows should only be instances of a specific
+    subclass within the polymorphic hierarchy, an error is raised if an
+    incompatible superclass is loaded in the result since it cannot be made to
+    match the requested class; this behavior is the same as what joined
+    inheritance has done for many years. The change also allows a single result
+    set to include column-level results from multiple sibling classes at once
+    which was not previously possible with single table inheritance.
index 158a81712b6ba84a0c568e7381a03d376840585c..cfd0ed0f49c7d4612a0cef7c381864ff64c5b85b 100644 (file)
@@ -8,6 +8,7 @@
 
 from __future__ import annotations
 
+import collections
 import itertools
 from typing import Any
 from typing import cast
@@ -2481,31 +2482,83 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
                     ext_info._adapter if ext_info.is_aliased_class else None,
                 )
 
-        search = set(self.extra_criteria_entities.values())
+        _where_criteria_to_add = ()
 
-        for ext_info, adapter in search:
+        merged_single_crit = collections.defaultdict(
+            lambda: (util.OrderedSet(), set())
+        )
+
+        for ext_info, adapter in util.OrderedSet(
+            self.extra_criteria_entities.values()
+        ):
             if ext_info in self._join_entities:
                 continue
 
-            single_crit = ext_info.mapper._single_table_criterion
-
-            if self.compile_options._for_refresh_state:
-                additional_entity_criteria = []
+            # assemble single table inheritance criteria.
+            if (
+                ext_info.is_aliased_class
+                and ext_info._base_alias()._is_with_polymorphic
+            ):
+                # for a with_polymorphic(), we always include the full
+                # hierarchy from what's given as the base class for the wpoly.
+                # this is new in 2.1 for #12395 so that it matches the behavior
+                # of joined inheritance.
+                hierarchy_root = ext_info._base_alias()
             else:
-                additional_entity_criteria = self._get_extra_criteria(ext_info)
+                hierarchy_root = ext_info
 
-            if single_crit is not None:
-                additional_entity_criteria += (single_crit,)
+            single_crit_component = (
+                hierarchy_root.mapper._single_table_criteria_component
+            )
 
-            current_adapter = self._get_current_adapter()
-            for crit in additional_entity_criteria:
+            if single_crit_component is not None:
+                polymorphic_on, criteria = single_crit_component
+
+                polymorphic_on = polymorphic_on._annotate(
+                    {
+                        "parententity": hierarchy_root,
+                        "parentmapper": hierarchy_root.mapper,
+                    }
+                )
+
+                list_of_single_crits, adapters = merged_single_crit[
+                    (hierarchy_root, polymorphic_on)
+                ]
+                list_of_single_crits.update(criteria)
                 if adapter:
-                    crit = adapter.traverse(crit)
+                    adapters.add(adapter)
 
-                if current_adapter:
-                    crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
-                    crit = current_adapter(crit, False)
+            # assemble "additional entity criteria", which come from
+            # with_loader_criteria() options
+            if not self.compile_options._for_refresh_state:
+                additional_entity_criteria = self._get_extra_criteria(ext_info)
+                _where_criteria_to_add += tuple(
+                    adapter.traverse(crit) if adapter else crit
+                    for crit in additional_entity_criteria
+                )
+
+        # merge together single table inheritance criteria keyed to
+        # top-level mapper / aliasedinsp (which may be a with_polymorphic())
+        for (ext_info, polymorphic_on), (
+            merged_crit,
+            adapters,
+        ) in merged_single_crit.items():
+            new_crit = polymorphic_on.in_(merged_crit)
+            for adapter in adapters:
+                new_crit = adapter.traverse(new_crit)
+            _where_criteria_to_add += (new_crit,)
+
+        current_adapter = self._get_current_adapter()
+        if current_adapter:
+            # finally run all the criteria through the "main" adapter, if we
+            # have one, and concatenate to final WHERE criteria
+            for crit in _where_criteria_to_add:
+                crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
+                crit = current_adapter(crit, False)
                 self._where_criteria += (crit,)
+        else:
+            # else just concatenate our criteria to the final WHERE criteria
+            self._where_criteria += _where_criteria_to_add
 
 
 def _column_descriptions(
@@ -2539,7 +2592,7 @@ def _column_descriptions(
 
 
 def _legacy_filter_by_entity_zero(
-    query_or_augmented_select: Union[Query[Any], Select[Unpack[TupleAny]]]
+    query_or_augmented_select: Union[Query[Any], Select[Unpack[TupleAny]]],
 ) -> Optional[_InternalEntityType[Any]]:
     self = query_or_augmented_select
     if self._setup_joins:
@@ -2554,7 +2607,7 @@ def _legacy_filter_by_entity_zero(
 
 
 def _entity_from_pre_ent_zero(
-    query_or_augmented_select: Union[Query[Any], Select[Unpack[TupleAny]]]
+    query_or_augmented_select: Union[Query[Any], Select[Unpack[TupleAny]]],
 ) -> Optional[_InternalEntityType[Any]]:
     self = query_or_augmented_select
     if not self._raw_columns:
index 3c6821d365683ff99f329fe1ecac3d81c0d23e71..6fb46a2bd81c34c1c42fda38d15cada9f03db0ba 100644 (file)
@@ -2626,17 +2626,29 @@ class Mapper(
             )
 
     @HasMemoized.memoized_attribute
-    def _single_table_criterion(self):
+    def _single_table_criteria_component(self):
         if self.single and self.inherits and self.polymorphic_on is not None:
-            return self.polymorphic_on._annotate(
-                {"parententity": self, "parentmapper": self}
-            ).in_(
-                [
-                    m.polymorphic_identity
-                    for m in self.self_and_descendants
-                    if not m.polymorphic_abstract
-                ]
+
+            hierarchy = tuple(
+                m.polymorphic_identity
+                for m in self.self_and_descendants
+                if not m.polymorphic_abstract
             )
+
+            return (
+                self.polymorphic_on._annotate(
+                    {"parententity": self, "parentmapper": self}
+                ),
+                hierarchy,
+            )
+        else:
+            return None
+
+    @HasMemoized.memoized_attribute
+    def _single_table_criterion(self):
+        component = self._single_table_criteria_component
+        if component is not None:
+            return component[0].in_(component[1])
         else:
             return None
 
index bfdf0b7bcfa3a56dedba4a25a4c8ea9f7c549df0..0f15ac4a511c0300e9544f01ae06bc182a0b124c 100644 (file)
@@ -125,6 +125,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
         cls.mapper_registry.map_imperatively(
             Employee,
             employees,
+            polymorphic_identity="employee",
             polymorphic_on=employees.c.type,
             properties={
                 "reports": relationship(Report, back_populates="employee")
@@ -186,7 +187,10 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
         assert row.employee_id == e1.employee_id
 
     def test_discrim_bound_param_cloned_ok(self):
-        """Test #6824"""
+        """Test #6824
+
+        note this changes a bit with #12395"""
+
         Manager = self.classes.Manager
 
         subq1 = select(Manager.employee_id).label("foo")
@@ -196,7 +200,8 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
             "SELECT (SELECT employees.employee_id FROM employees "
             "WHERE employees.type IN (__[POSTCOMPILE_type_1])) AS foo, "
             "(SELECT employees.employee_id FROM employees "
-            "WHERE employees.type IN (__[POSTCOMPILE_type_1])) AS bar",
+            "WHERE employees.type IN (__[POSTCOMPILE_type_2])) AS bar",
+            checkparams={"type_1": ["manager"], "type_2": ["manager"]},
         )
 
     def test_multi_qualification(self):
@@ -274,6 +279,16 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
         # so no result.
         eq_(session.query(Manager.employee_id, Engineer.employee_id).all(), [])
 
+        # however, with #12395, a with_polymorphic will merge the IN
+        # together
+        wp = with_polymorphic(Employee, [Manager, Engineer])
+        eq_(
+            session.query(
+                wp.Manager.employee_id, wp.Engineer.employee_id
+            ).all(),
+            [(m1id, m1id), (e1id, e1id), (e2id, e2id)],
+        )
+
         eq_(scalar(session.query(JuniorEngineer.employee_id)), [e2id])
 
     def test_bundle_qualification(self):
@@ -312,6 +327,16 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
             [],
         )
 
+        # however, with #12395, a with_polymorphic will merge the IN
+        # together
+        wp = with_polymorphic(Employee, [Manager, Engineer])
+        eq_(
+            session.query(
+                Bundle("name", wp.Manager.employee_id, wp.Engineer.employee_id)
+            ).all(),
+            [((m1id, m1id),), ((e1id, e1id),), ((e2id, e2id),)],
+        )
+
         eq_(
             scalar(session.query(Bundle("name", JuniorEngineer.employee_id))),
             [e2id],
@@ -831,6 +856,291 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
         assert len(rq.join(Report.employee.of_type(Engineer)).all()) == 0
 
 
+class WPolySingleJoinedParityTest:
+    """a suite to test that with_polymorphic behaves identically
+    with joined or single inheritance as of 2.1, issue #12395
+
+    """
+
+    @classmethod
+    def insert_data(cls, connection):
+        Employee, Manager, Engineer, Boss, JuniorEngineer = cls.classes(
+            "Employee", "Manager", "Engineer", "Boss", "JuniorEngineer"
+        )
+        with Session(connection) as session:
+            session.add(Employee(name="Employee 1"))
+            session.add(Manager(name="Manager 1", manager_data="manager data"))
+            session.add(
+                Engineer(name="Engineer 1", engineer_info="engineer_info")
+            )
+            session.add(
+                JuniorEngineer(
+                    name="Junior Engineer 1",
+                    engineer_info="junior info",
+                    junior_name="junior name",
+                )
+            )
+            session.add(Boss(name="Boss 1", manager_data="boss data"))
+
+            session.commit()
+
+    @testing.variation("wpoly_type", ["star", "classes"])
+    def test_with_polymorphic_sibling_classes_base(
+        self, wpoly_type: testing.Variation
+    ):
+        Employee, Manager, Engineer, JuniorEngineer, Boss = self.classes(
+            "Employee", "Manager", "Engineer", "JuniorEngineer", "Boss"
+        )
+
+        if wpoly_type.star:
+            wp = with_polymorphic(Employee, "*")
+        elif wpoly_type.classes:
+            wp = with_polymorphic(
+                Employee, [Manager, Engineer, JuniorEngineer]
+            )
+        else:
+            wpoly_type.fail()
+
+        stmt = select(wp).order_by(wp.id)
+        session = fixture_session()
+        eq_(
+            session.scalars(stmt).all(),
+            [
+                Employee(name="Employee 1"),
+                Manager(name="Manager 1", manager_data="manager data"),
+                Engineer(engineer_info="engineer_info"),
+                JuniorEngineer(engineer_info="junior info"),
+                Boss(name="Boss 1", manager_data="boss data"),
+            ],
+        )
+
+        # this raises, because we get rows that are not Manager or
+        # JuniorEngineer
+
+        stmt = select(wp, wp.Manager, wp.JuniorEngineer).order_by(wp.id)
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"Row with identity key \(<.*Employee'>, .*\) can't be loaded "
+            r"into an object; the polymorphic discriminator column "
+            r"'employee.type' refers to Mapper\[Employee\(.*\)\], "
+            r"which is "
+            r"not a sub-mapper of the requested "
+            r"Mapper\[Manager\(.*\)\]",
+        ):
+            session.scalars(stmt).all()
+
+    @testing.variation("wpoly_type", ["star", "classes"])
+    def test_with_polymorphic_sibling_classes_middle(
+        self, wpoly_type: testing.Variation
+    ):
+        Employee, Manager, Engineer, JuniorEngineer = self.classes(
+            "Employee", "Manager", "Engineer", "JuniorEngineer"
+        )
+
+        if wpoly_type.star:
+            wp = with_polymorphic(Engineer, "*")
+        elif wpoly_type.classes:
+            wp = with_polymorphic(Engineer, [Engineer, JuniorEngineer])
+        else:
+            wpoly_type.fail()
+
+        stmt = select(wp).order_by(wp.id)
+
+        session = fixture_session()
+        eq_(
+            session.scalars(stmt).all(),
+            [
+                Engineer(engineer_info="engineer_info"),
+                JuniorEngineer(engineer_info="junior info"),
+            ],
+        )
+
+        # this raises, because we get rows that are not JuniorEngineer
+
+        stmt = select(wp.JuniorEngineer).order_by(wp.id)
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"Row with identity key \(<.*Employee'>, .*\) can't be loaded "
+            r"into an object; the polymorphic discriminator column "
+            r"'employee.type' refers to Mapper\[Engineer\(.*\)\], "
+            r"which is "
+            r"not a sub-mapper of the requested "
+            r"Mapper\[JuniorEngineer\(.*\)\]",
+        ):
+            session.scalars(stmt).all()
+
+    @testing.variation("wpoly_type", ["star", "classes"])
+    def test_with_polymorphic_sibling_columns(
+        self, wpoly_type: testing.Variation
+    ):
+        Employee, Manager, Engineer, JuniorEngineer = self.classes(
+            "Employee", "Manager", "Engineer", "JuniorEngineer"
+        )
+
+        if wpoly_type.star:
+            wp = with_polymorphic(Employee, "*")
+        elif wpoly_type.classes:
+            wp = with_polymorphic(Employee, [Manager, Engineer])
+        else:
+            wpoly_type.fail()
+
+        stmt = select(
+            wp.name, wp.Manager.manager_data, wp.Engineer.engineer_info
+        ).order_by(wp.id)
+
+        session = fixture_session()
+
+        eq_(
+            session.execute(stmt).all(),
+            [
+                ("Employee 1", None, None),
+                ("Manager 1", "manager data", None),
+                ("Engineer 1", None, "engineer_info"),
+                ("Junior Engineer 1", None, "junior info"),
+                ("Boss 1", "boss data", None),
+            ],
+        )
+
+    @testing.variation("wpoly_type", ["star", "classes"])
+    def test_with_polymorphic_sibling_columns_middle(
+        self, wpoly_type: testing.Variation
+    ):
+        Employee, Manager, Engineer, JuniorEngineer = self.classes(
+            "Employee", "Manager", "Engineer", "JuniorEngineer"
+        )
+
+        if wpoly_type.star:
+            wp = with_polymorphic(Engineer, "*")
+        elif wpoly_type.classes:
+            wp = with_polymorphic(Engineer, [JuniorEngineer])
+        else:
+            wpoly_type.fail()
+
+        stmt = select(wp.name, wp.engineer_info, wp.JuniorEngineer.junior_name)
+
+        session = fixture_session()
+
+        eq_(
+            session.execute(stmt).all(),
+            [
+                ("Engineer 1", "engineer_info", None),
+                ("Junior Engineer 1", "junior info", "junior name"),
+            ],
+        )
+
+
+class JoinedWPolyParityTest(
+    WPolySingleJoinedParityTest, fixtures.DeclarativeMappedTest
+):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Employee(ComparableEntity, Base):
+            __tablename__ = "employee"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            name: Mapped[str]
+            type: Mapped[str]
+
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "employee",
+            }
+
+        class Manager(Employee):
+            __tablename__ = "manager"
+            id = mapped_column(
+                Integer, ForeignKey("employee.id"), primary_key=True
+            )
+            manager_data: Mapped[str] = mapped_column(nullable=True)
+
+            __mapper_args__ = {
+                "polymorphic_identity": "manager",
+            }
+
+        class Boss(Manager):
+            __tablename__ = "boss"
+            id = mapped_column(
+                Integer, ForeignKey("manager.id"), primary_key=True
+            )
+
+            __mapper_args__ = {
+                "polymorphic_identity": "boss",
+            }
+
+        class Engineer(Employee):
+            __tablename__ = "engineer"
+            id = mapped_column(
+                Integer, ForeignKey("employee.id"), primary_key=True
+            )
+            engineer_info: Mapped[str] = mapped_column(nullable=True)
+
+            __mapper_args__ = {
+                "polymorphic_identity": "engineer",
+            }
+
+        class JuniorEngineer(Engineer):
+            __tablename__ = "juniorengineer"
+            id = mapped_column(
+                Integer, ForeignKey("engineer.id"), primary_key=True
+            )
+            junior_name: Mapped[str] = mapped_column(nullable=True)
+            __mapper_args__ = {
+                "polymorphic_identity": "juniorengineer",
+                "polymorphic_load": "inline",
+            }
+
+
+class SingleWPolyParityTest(
+    WPolySingleJoinedParityTest, fixtures.DeclarativeMappedTest
+):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Employee(ComparableEntity, Base):
+            __tablename__ = "employee"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            name: Mapped[str]
+            type: Mapped[str]
+
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "employee",
+            }
+
+        class Manager(Employee):
+            manager_data: Mapped[str] = mapped_column(nullable=True)
+
+            __mapper_args__ = {
+                "polymorphic_identity": "manager",
+                "polymorphic_load": "inline",
+            }
+
+        class Boss(Manager):
+
+            __mapper_args__ = {
+                "polymorphic_identity": "boss",
+                "polymorphic_load": "inline",
+            }
+
+        class Engineer(Employee):
+            engineer_info: Mapped[str] = mapped_column(nullable=True)
+
+            __mapper_args__ = {
+                "polymorphic_identity": "engineer",
+                "polymorphic_load": "inline",
+            }
+
+        class JuniorEngineer(Engineer):
+            junior_name: Mapped[str] = mapped_column(nullable=True)
+
+            __mapper_args__ = {
+                "polymorphic_identity": "juniorengineer",
+                "polymorphic_load": "inline",
+            }
+
+
 class RelationshipFromSingleTest(
     testing.AssertsCompiledSQL, fixtures.MappedTest
 ):
@@ -1917,8 +2227,7 @@ class SingleFromPolySelectableTest(
             "engineer.engineer_info AS engineer_engineer_info, "
             "engineer.manager_id AS engineer_manager_id "
             "FROM employee JOIN engineer ON employee.id = engineer.id) "
-            "AS anon_1 "
-            "WHERE anon_1.employee_type IN (__[POSTCOMPILE_type_1])",
+            "AS anon_1",
         )
 
     def test_query_wpoly_single_inh_subclass(self):