git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
include columns from superclasses that indicate "selectin"
author Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Feb 2023 14:31:36 +0000 (09:31 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Feb 2023 06:35:26 +0000 (01:35 -0500)
Added support for the :paramref:`_orm.Mapper.polymorphic_load` parameter to
be applied to each mapper in an inheritance hierarchy more than one level
deep, allowing columns to load for all classes in the hierarchy that
indicate ``"selectin"`` using a single statement, rather than ignoring
elements on those intermediary classes that nonetheless indicate they also
would participate in ``"selectin"`` loading and were not part of the
base-most SELECT statement.

Fixes: #9373
Change-Id: If8dcba0f0191f6c2818ecd15870bccfdf5ce1112

doc/build/changelog/unreleased_20/9373.rst [new file with mode: 0644]
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/test_poly_loading.py

diff --git a/doc/build/changelog/unreleased_20/9373.rst b/doc/build/changelog/unreleased_20/9373.rst
new file mode 100644 (file)
index 0000000..fb726ac
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9373
+
+    Added support for the :paramref:`_orm.Mapper.polymorphic_load` parameter to
+    be applied to each mapper in an inheritance hierarchy more than one level
+    deep, allowing columns to load for all classes in the hierarchy that
+    indicate ``"selectin"`` using a single statement, rather than ignoring
+    elements on those intermediary classes that nonetheless indicate they also
+    would participate in ``"selectin"`` loading and were not part of the
+    base-most SELECT statement.
index 54b96c215f1061a7412c9b851ee99ccec6cae25d..7974d94c5ae7d99e6c5754f1a2c860bbcd9af65d 100644 (file)
@@ -972,17 +972,17 @@ def _instance_processor(
 
     if not refresh_state and _polymorphic_from is not None:
         key = ("loader", path.path)
+
         if key in context.attributes and context.attributes[key].strategy == (
             ("selectinload_polymorphic", True),
         ):
-            selectin_load_via = mapper._should_selectin_load(
-                context.attributes[key].local_opts["entities"],
-                _polymorphic_from,
-            )
+            option_entities = context.attributes[key].local_opts["entities"]
         else:
-            selectin_load_via = mapper._should_selectin_load(
-                None, _polymorphic_from
-            )
+            option_entities = None
+        selectin_load_via = mapper._should_selectin_load(
+            option_entities,
+            _polymorphic_from,
+        )
 
         if selectin_load_via and selectin_load_via is not _polymorphic_from:
             # only_load_props goes w/ refresh_state only, and in a refresh
@@ -990,8 +990,13 @@ def _instance_processor(
             # loading does not apply
             assert only_load_props is None
 
-            callable_ = _load_subclass_via_in(context, path, selectin_load_via)
-
+            callable_ = _load_subclass_via_in(
+                context,
+                path,
+                selectin_load_via,
+                _polymorphic_from,
+                option_entities,
+            )
             PostLoad.callable_for_path(
                 context,
                 load_path,
@@ -1212,17 +1217,42 @@ def _instance_processor(
     return _instance
 
 
-def _load_subclass_via_in(context, path, entity):
+def _load_subclass_via_in(
+    context, path, entity, polymorphic_from, option_entities
+):
     mapper = entity.mapper
 
+    # TODO: polymorphic_from seems to be a Mapper in all cases.
+    # this is likely not needed, but as we don't have typing in loading.py
+    # yet, err on the safe side
+    polymorphic_from_mapper = polymorphic_from.mapper
+    not_against_basemost = polymorphic_from_mapper.inherits is not None
+
     zero_idx = len(mapper.base_mapper.primary_key) == 1
 
-    if entity.is_aliased_class:
-        q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity)
+    if entity.is_aliased_class or not_against_basemost:
+        q, enable_opt, disable_opt = mapper._subclass_load_via_in(
+            entity, polymorphic_from
+        )
     else:
         q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper
 
     def do_load(context, path, states, load_only, effective_entity):
+        if not option_entities:
+            # filter out states for those that would have selectinloaded
+            # from another loader
+            # TODO: we are currently ignoring the case where the
+            # "selectin_polymorphic" option is used, as this is much more
+            # complex / specific / very uncommon API use
+            states = [
+                (s, v)
+                for s, v in states
+                if s.mapper._would_selectin_load_only_from_given_mapper(mapper)
+            ]
+
+            if not states:
+                return
+
         orig_query = context.query
 
         options = (enable_opt,) + orig_query._with_options + (disable_opt,)
index c0ff2ed10e601a39e27a47abb9b302fbe10c4133..2ae6dadcd8e8db0178362ee6eb2743ca724de555 100644 (file)
@@ -3698,6 +3698,65 @@ class Mapper(
                 if m is mapper:
                     break
 
+    @HasMemoized.memoized_attribute
+    def _would_selectinload_combinations_cache(self):
+        return {}
+
+    def _would_selectin_load_only_from_given_mapper(self, super_mapper):
+        """return True if this mapper would "selectin" polymorphic load based
+        on the given super mapper, and not from a setting from a subclass.
+
+        given::
+
+            class A:
+                ...
+
+            class B(A):
+                __mapper_args__ = {"polymorphic_load": "selectin"}
+
+            class C(B):
+                ...
+
+            class D(B):
+                __mapper_args__ = {"polymorphic_load": "selectin"}
+
+        ``inspect(C)._would_selectin_load_only_from_given_mapper(inspect(B))``
+        returns True, because C does selectin loading because of B's setting.
+
+        OTOH, ``inspect(D)
+        ._would_selectin_load_only_from_given_mapper(inspect(B))``
+        returns False, because D does selectin loading because of its own
+        setting; when we are doing a selectin poly load from B, we want to
+        filter out D because it would already have its own selectin poly load
+        set up separately.
+
+        Added as part of #9373.
+
+        """
+        cache = self._would_selectinload_combinations_cache
+
+        try:
+            return cache[super_mapper]
+        except KeyError:
+            pass
+
+        # assert that given object is a supermapper, meaning we already
+        # strong reference it directly or indirectly.  this allows us
+        # to not worry that we are creating new strongrefs to unrelated
+        # mappers or other objects.
+        assert self.isa(super_mapper)
+
+        mapper = super_mapper
+        for m in self._iterate_to_target_viawpoly(mapper):
+            if m.polymorphic_load == "selectin":
+                retval = m is super_mapper
+                break
+        else:
+            retval = False
+
+        cache[super_mapper] = retval
+        return retval
+
     def _should_selectin_load(self, enabled_via_opt, polymorphic_from):
         if not enabled_via_opt:
             # common case, takes place for all polymorphic loads
@@ -3721,7 +3780,7 @@ class Mapper(
         return None
 
     @util.preload_module("sqlalchemy.orm.strategy_options")
-    def _subclass_load_via_in(self, entity):
+    def _subclass_load_via_in(self, entity, polymorphic_from):
         """Assemble a that can load the columns local to
         this subclass as a SELECT with IN.
 
@@ -3739,6 +3798,16 @@ class Mapper(
         disable_opt = strategy_options.Load(entity)
         enable_opt = strategy_options.Load(entity)
 
+        classes_to_include = {self}
+        m: Optional[Mapper[Any]] = self.inherits
+        while (
+            m is not None
+            and m is not polymorphic_from
+            and m.polymorphic_load == "selectin"
+        ):
+            classes_to_include.add(m)
+            m = m.inherits
+
         for prop in self.attrs:
 
             # skip prop keys that are not instrumented on the mapped class.
@@ -3747,7 +3816,7 @@ class Mapper(
             if prop.key not in self.class_manager:
                 continue
 
-            if prop.parent is self or prop in keep_props:
+            if prop.parent in classes_to_include or prop in keep_props:
                 # "enable" options, to turn on the properties that we want to
                 # load by default (subject to options from the query)
                 if not isinstance(prop, StrategizedProperty):
@@ -3811,7 +3880,8 @@ class Mapper(
 
     @HasMemoized.memoized_attribute
     def _subclass_load_via_in_mapper(self):
-        return self._subclass_load_via_in(self)
+        # the default is loading this mapper against the basemost mapper
+        return self._subclass_load_via_in(self, self.base_mapper)
 
     def cascade_iterator(
         self,
index 9086be3c4ad7e50ddf594c1bf6b6f08ca64605a1..869ee0a8e032da55ddcf97f1e5937a1a703a5d69 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import union
+from sqlalchemy.orm import aliased
 from sqlalchemy.orm import backref
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import composite
@@ -28,6 +29,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertsql import AllOf
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.assertsql import Conditional
 from sqlalchemy.testing.assertsql import EachOf
 from sqlalchemy.testing.assertsql import Or
 from sqlalchemy.testing.entities import ComparableEntity
@@ -372,6 +374,238 @@ class TestGeometries(GeometryFixtureBase):
         with self.assert_statement_count(testing.db, 0):
             eq_(result, [d(d_data="d1"), e(e_data="e1")])
 
+    @testing.fixture
+    def threelevel_all_selectin_fixture(self):
+        self._fixture_from_geometry(
+            {
+                "a": {
+                    "subclasses": {
+                        "b": {"polymorphic_load": "selectin"},
+                        "c": {
+                            "subclasses": {
+                                "d": {
+                                    "polymorphic_load": "selectin",
+                                },
+                                "e": {
+                                    "polymorphic_load": "selectin",
+                                },
+                                "f": {},
+                            },
+                            "polymorphic_load": "selectin",
+                        },
+                    }
+                }
+            }
+        )
+
+    def test_threelevel_all_selectin_l1_load_l3(
+        self, threelevel_all_selectin_fixture
+    ):
+        """test for #9373 - load base to receive level 3 endpoints"""
+
+        a, b, c, d, e = self.classes("a", "b", "c", "d", "e")
+        sess = fixture_session()
+        sess.add_all(
+            [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")]
+        )
+        sess.commit()
+
+        for i in range(3):
+            sess.close()
+
+            q = sess.query(a)
+
+            result = self.assert_sql_execution(
+                testing.db,
+                q.all,
+                CompiledSQL(
+                    "SELECT a.id AS a_id, a.type AS a_type, "
+                    "a.a_data AS a_a_data FROM a",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, "
+                    "a.type AS a_type, c.c_data AS c_c_data, "
+                    "d.d_data AS d_d_data "
+                    "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id "
+                    "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+                    "ORDER BY a.id",
+                    [{"primary_keys": [1]}],
+                ),
+                CompiledSQL(
+                    "SELECT e.id AS e_id, c.id AS c_id, a.id AS a_id, "
+                    "a.type AS a_type, c.c_data AS c_c_data, "
+                    "e.e_data AS e_e_data "
+                    "FROM a JOIN c ON a.id = c.id JOIN e ON c.id = e.id "
+                    "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+                    "ORDER BY a.id",
+                    [{"primary_keys": [2]}],
+                ),
+            )
+            with self.assert_statement_count(testing.db, 0):
+                eq_(
+                    result,
+                    [
+                        d(c_data="cd1", d_data="d1"),
+                        e(c_data="ce1", e_data="e1"),
+                    ],
+                )
+
+    def test_threelevel_partial_selectin_l1_load_l3(
+        self, threelevel_all_selectin_fixture
+    ):
+        """test for #9373 - load base to receive level 3 endpoints"""
+
+        a, b, c, d, f = self.classes("a", "b", "c", "d", "f")
+        sess = fixture_session()
+        sess.add_all(
+            [d(c_data="cd1", d_data="d1"), f(c_data="ce1", f_data="e1")]
+        )
+        sess.commit()
+
+        for i in range(3):
+            sess.close()
+            q = sess.query(a)
+
+            result = self.assert_sql_execution(
+                testing.db,
+                q.all,
+                CompiledSQL(
+                    "SELECT a.id AS a_id, a.type AS a_type, "
+                    "a.a_data AS a_a_data FROM a",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, "
+                    "a.type AS a_type, c.c_data AS c_c_data, "
+                    "d.d_data AS d_d_data "
+                    "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id "
+                    "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+                    "ORDER BY a.id",
+                    [{"primary_keys": [1]}],
+                ),
+                # only loads pk 2 - this is the filtering inside of do_load
+                CompiledSQL(
+                    "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, "
+                    "c.c_data AS c_c_data "
+                    "FROM a JOIN c ON a.id = c.id "
+                    "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+                    "ORDER BY a.id",
+                    [{"primary_keys": [2]}],
+                ),
+                # no more SQL; if we hit pk 1 again, it would re-do the d here
+            )
+
+            with self.sql_execution_asserter(testing.db) as asserter_:
+                eq_(
+                    result,
+                    [
+                        d(c_data="cd1", d_data="d1"),
+                        f(c_data="ce1", f_data="e1"),
+                    ],
+                )
+
+            # f was told not to load its attrs, so they load here
+            asserter_.assert_(
+                CompiledSQL(
+                    "SELECT f.f_data AS f_f_data FROM f WHERE :param_1 = f.id",
+                    [{"param_1": 2}],
+                ),
+            )
+
+    def test_threelevel_all_selectin_l1_load_l2(
+        self, threelevel_all_selectin_fixture
+    ):
+        """test for #9373 - load base to receive level 2 endpoint"""
+        a, b, c, d, e = self.classes("a", "b", "c", "d", "e")
+        sess = fixture_session()
+        sess.add_all([c(c_data="c1", a_data="a1")])
+        sess.commit()
+
+        q = sess.query(a)
+
+        result = self.assert_sql_execution(
+            testing.db,
+            q.all,
+            CompiledSQL(
+                "SELECT a.id AS a_id, a.type AS a_type, "
+                "a.a_data AS a_a_data FROM a",
+                {},
+            ),
+            CompiledSQL(
+                "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, "
+                "c.c_data AS c_c_data FROM a JOIN c ON a.id = c.id "
+                "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id",
+                {"primary_keys": [1]},
+            ),
+        )
+        with self.assert_statement_count(testing.db, 0):
+            eq_(
+                result,
+                [c(c_data="c1", a_data="a1")],
+            )
+
+    @testing.variation("use_aliased_class", [True, False])
+    def test_threelevel_all_selectin_l2_load_l3(
+        self, threelevel_all_selectin_fixture, use_aliased_class
+    ):
+        """test for #9373 - load level 2 endpoint to receive level 3
+        endpoints"""
+        a, b, c, d, e = self.classes("a", "b", "c", "d", "e")
+        sess = fixture_session()
+        sess.add_all(
+            [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")]
+        )
+        sess.commit()
+
+        if use_aliased_class:
+            q = sess.query(aliased(c, flat=True))
+        else:
+            q = sess.query(c)
+        result = self.assert_sql_execution(
+            testing.db,
+            q.all,
+            Conditional(
+                bool(use_aliased_class),
+                [
+                    CompiledSQL(
+                        "SELECT c_1.id AS c_1_id, a_1.id AS a_1_id, "
+                        "a_1.type AS a_1_type, a_1.a_data AS a_1_a_data, "
+                        "c_1.c_data AS c_1_c_data "
+                        "FROM a AS a_1 JOIN c AS c_1 ON a_1.id = c_1.id",
+                        {},
+                    )
+                ],
+                [
+                    CompiledSQL(
+                        "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, "
+                        "a.a_data AS a_a_data, c.c_data AS c_c_data "
+                        "FROM a JOIN c ON a.id = c.id",
+                        {},
+                    )
+                ],
+            ),
+            CompiledSQL(
+                "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, "
+                "a.type AS a_type, d.d_data AS d_d_data "
+                "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id "
+                "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id",
+                [{"primary_keys": [1]}],
+            ),
+            CompiledSQL(
+                "SELECT e.id AS e_id, c.id AS c_id, a.id AS a_id, "
+                "a.type AS a_type, e.e_data AS e_e_data "
+                "FROM a JOIN c ON a.id = c.id JOIN e ON c.id = e.id "
+                "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id",
+                [{"primary_keys": [2]}],
+            ),
+        )
+        with self.assert_statement_count(testing.db, 0):
+            eq_(
+                result,
+                [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")],
+            )
+
     def test_threelevel_selectin_to_inline_options(self):
         self._fixture_from_geometry(
             {