From 40b00498e62d3bf10f75874852bab6d6e0e3a09a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 26 Feb 2023 09:31:36 -0500 Subject: [PATCH] include columns from superclasses that indicate "selectin" Added support for the :paramref:`_orm.Mapper.polymorphic_load` parameter to be applied to each mapper in an inheritance hierarchy more than one level deep, allowing columns to load for all classes in the hierarchy that indicate ``"selectin"`` using a single statement, rather than ignoring elements on those intermediary classes that nonetheless indicate they also would participate in ``"selectin"`` loading and were not part of the base-most SELECT statement. Fixes: #9373 Change-Id: If8dcba0f0191f6c2818ecd15870bccfdf5ce1112 --- doc/build/changelog/unreleased_20/9373.rst | 11 + lib/sqlalchemy/orm/loading.py | 54 +++-- lib/sqlalchemy/orm/mapper.py | 76 ++++++- test/orm/inheritance/test_poly_loading.py | 234 +++++++++++++++++++++ 4 files changed, 360 insertions(+), 15 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9373.rst diff --git a/doc/build/changelog/unreleased_20/9373.rst b/doc/build/changelog/unreleased_20/9373.rst new file mode 100644 index 0000000000..fb726accbe --- /dev/null +++ b/doc/build/changelog/unreleased_20/9373.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, orm + :tickets: 9373 + + Added support for the :paramref:`_orm.Mapper.polymorphic_load` parameter to + be applied to each mapper in an inheritance hierarchy more than one level + deep, allowing columns to load for all classes in the hierarchy that + indicate ``"selectin"`` using a single statement, rather than ignoring + elements on those intermediary classes that nonetheless indicate they also + would participate in ``"selectin"`` loading and were not part of the + base-most SELECT statement. diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 54b96c215f..7974d94c5a 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -972,17 +972,17 @@ def _instance_processor( if not refresh_state and _polymorphic_from is not None: key = ("loader", path.path) + if key in context.attributes and context.attributes[key].strategy == ( ("selectinload_polymorphic", True), ): - selectin_load_via = mapper._should_selectin_load( - context.attributes[key].local_opts["entities"], - _polymorphic_from, - ) + option_entities = context.attributes[key].local_opts["entities"] else: - selectin_load_via = mapper._should_selectin_load( - None, _polymorphic_from - ) + option_entities = None + selectin_load_via = mapper._should_selectin_load( + option_entities, + _polymorphic_from, + ) if selectin_load_via and selectin_load_via is not _polymorphic_from: # only_load_props goes w/ refresh_state only, and in a refresh @@ -990,8 +990,13 @@ def _instance_processor( # loading does not apply assert only_load_props is None - callable_ = _load_subclass_via_in(context, path, selectin_load_via) - + callable_ = _load_subclass_via_in( + context, + path, + selectin_load_via, + _polymorphic_from, + option_entities, + ) PostLoad.callable_for_path( context, load_path, @@ -1212,17 +1217,42 @@ def _instance_processor( return _instance -def _load_subclass_via_in(context, path, entity): +def _load_subclass_via_in( + context, path, entity, polymorphic_from, option_entities +): mapper = entity.mapper + # TODO: polymorphic_from seems to be a Mapper in all cases. + # this is likely not needed, but as we dont have typing in loading.py + # yet, err on the safe side + polymorphic_from_mapper = polymorphic_from.mapper + not_against_basemost = polymorphic_from_mapper.inherits is not None + zero_idx = len(mapper.base_mapper.primary_key) == 1 - if entity.is_aliased_class: - q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity) + if entity.is_aliased_class or not_against_basemost: + q, enable_opt, disable_opt = mapper._subclass_load_via_in( + entity, polymorphic_from + ) else: q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper def do_load(context, path, states, load_only, effective_entity): + if not option_entities: + # filter out states for those that would have selectinloaded + # from another loader + # TODO: we are currently ignoring the case where the + # "selectin_polymorphic" option is used, as this is much more + # complex / specific / very uncommon API use + states = [ + (s, v) + for s, v in states + if s.mapper._would_selectin_load_only_from_given_mapper(mapper) + ] + + if not states: + return + orig_query = context.query options = (enable_opt,) + orig_query._with_options + (disable_opt,) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c0ff2ed10e..2ae6dadcd8 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -3698,6 +3698,65 @@ class Mapper( if m is mapper: break + @HasMemoized.memoized_attribute + def _would_selectinload_combinations_cache(self): + return {} + + def _would_selectin_load_only_from_given_mapper(self, super_mapper): + """return True if this mapper would "selectin" polymorphic load based + on the given super mapper, and not from a setting from a subclass. + + given:: + + class A: + ... + + class B(A): + __mapper_args__ = {"polymorphic_load": "selectin"} + + class C(B): + ... + + class D(B): + __mapper_args__ = {"polymorphic_load": "selectin"} + + ``inspect(C)._would_selectin_load_only_from_given_mapper(inspect(B))`` + returns True, because C does selectin loading because of B's setting. + + OTOH, ``inspect(D) + ._would_selectin_load_only_from_given_mapper(inspect(B))`` + returns False, because D does selectin loading because of its own + setting; when we are doing a selectin poly load from B, we want to + filter out D because it would already have its own selectin poly load + set up separately. + + Added as part of #9373. + + """ + cache = self._would_selectinload_combinations_cache + + try: + return cache[super_mapper] + except KeyError: + pass + + # assert that given object is a supermapper, meaning we already + # strong reference it directly or indirectly. this allows us + # to not worry that we are creating new strongrefs to unrelated + # mappers or other objects. + assert self.isa(super_mapper) + + mapper = super_mapper + for m in self._iterate_to_target_viawpoly(mapper): + if m.polymorphic_load == "selectin": + retval = m is super_mapper + break + else: + retval = False + + cache[super_mapper] = retval + return retval + def _should_selectin_load(self, enabled_via_opt, polymorphic_from): if not enabled_via_opt: # common case, takes place for all polymorphic loads @@ -3721,7 +3780,7 @@ class Mapper( return None @util.preload_module("sqlalchemy.orm.strategy_options") - def _subclass_load_via_in(self, entity): + def _subclass_load_via_in(self, entity, polymorphic_from): """Assemble a that can load the columns local to this subclass as a SELECT with IN. @@ -3739,6 +3798,16 @@ class Mapper( disable_opt = strategy_options.Load(entity) enable_opt = strategy_options.Load(entity) + classes_to_include = {self} + m: Optional[Mapper[Any]] = self.inherits + while ( + m is not None + and m is not polymorphic_from + and m.polymorphic_load == "selectin" + ): + classes_to_include.add(m) + m = m.inherits + for prop in self.attrs: # skip prop keys that are not instrumented on the mapped class. @@ -3747,7 +3816,7 @@ class Mapper( if prop.key not in self.class_manager: continue - if prop.parent is self or prop in keep_props: + if prop.parent in classes_to_include or prop in keep_props: # "enable" options, to turn on the properties that we want to # load by default (subject to options from the query) if not isinstance(prop, StrategizedProperty): @@ -3811,7 +3880,8 @@ class Mapper( @HasMemoized.memoized_attribute def _subclass_load_via_in_mapper(self): - return self._subclass_load_via_in(self) + # the default is loading this mapper against the basemost mapper + return self._subclass_load_via_in(self, self.base_mapper) def cascade_iterator( self, diff --git a/test/orm/inheritance/test_poly_loading.py b/test/orm/inheritance/test_poly_loading.py index 9086be3c4a..869ee0a8e0 100644 --- a/test/orm/inheritance/test_poly_loading.py +++ b/test/orm/inheritance/test_poly_loading.py @@ -7,6 +7,7 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import union +from sqlalchemy.orm import aliased from sqlalchemy.orm import backref from sqlalchemy.orm import column_property from sqlalchemy.orm import composite @@ -28,6 +29,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.assertsql import EachOf from sqlalchemy.testing.assertsql import Or from sqlalchemy.testing.entities import ComparableEntity @@ -372,6 +374,238 @@ class TestGeometries(GeometryFixtureBase): with self.assert_statement_count(testing.db, 0): eq_(result, [d(d_data="d1"), e(e_data="e1")]) + @testing.fixture + def threelevel_all_selectin_fixture(self): + self._fixture_from_geometry( + { + "a": { + "subclasses": { + "b": {"polymorphic_load": "selectin"}, + "c": { + "subclasses": { + "d": { + "polymorphic_load": "selectin", + }, + "e": { + "polymorphic_load": "selectin", + }, + "f": {}, + }, + "polymorphic_load": "selectin", + }, + } + } + } + ) + + def test_threelevel_all_selectin_l1_load_l3( + self, threelevel_all_selectin_fixture + ): + """test for #9373 - load base to receive level 3 endpoints""" + + a, b, c, d, e = self.classes("a", "b", "c", "d", "e") + sess = fixture_session() + sess.add_all( + [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")] + ) + sess.commit() + + for i in range(3): + sess.close() + + q = sess.query(a) + + result = self.assert_sql_execution( + testing.db, + q.all, + CompiledSQL( + "SELECT a.id AS a_id, a.type AS a_type, " + "a.a_data AS a_a_data FROM a", + {}, + ), + CompiledSQL( + "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, " + "a.type AS a_type, c.c_data AS c_c_data, " + "d.d_data AS d_d_data " + "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) " + "ORDER BY a.id", + [{"primary_keys": [1]}], + ), + CompiledSQL( + "SELECT e.id AS e_id, c.id AS c_id, a.id AS a_id, " + "a.type AS a_type, c.c_data AS c_c_data, " + "e.e_data AS e_e_data " + "FROM a JOIN c ON a.id = c.id JOIN e ON c.id = e.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) " + "ORDER BY a.id", + [{"primary_keys": [2]}], + ), + ) + with self.assert_statement_count(testing.db, 0): + eq_( + result, + [ + d(c_data="cd1", d_data="d1"), + e(c_data="ce1", e_data="e1"), + ], + ) + + def test_threelevel_partial_selectin_l1_load_l3( + self, threelevel_all_selectin_fixture + ): + """test for #9373 - load base to receive level 3 endpoints""" + + a, b, c, d, f = self.classes("a", "b", "c", "d", "f") + sess = fixture_session() + sess.add_all( + [d(c_data="cd1", d_data="d1"), f(c_data="ce1", f_data="e1")] + ) + sess.commit() + + for i in range(3): + sess.close() + q = sess.query(a) + + result = self.assert_sql_execution( + testing.db, + q.all, + CompiledSQL( + "SELECT a.id AS a_id, a.type AS a_type, " + "a.a_data AS a_a_data FROM a", + {}, + ), + CompiledSQL( + "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, " + "a.type AS a_type, c.c_data AS c_c_data, " + "d.d_data AS d_d_data " + "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) " + "ORDER BY a.id", + [{"primary_keys": [1]}], + ), + # only loads pk 2 - this is the filtering inside of do_load + CompiledSQL( + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " + "c.c_data AS c_c_data " + "FROM a JOIN c ON a.id = c.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) " + "ORDER BY a.id", + [{"primary_keys": [2]}], + ), + # no more SQL; if we hit pk 1 again, it would re-do the d here + ) + + with self.sql_execution_asserter(testing.db) as asserter_: + eq_( + result, + [ + d(c_data="cd1", d_data="d1"), + f(c_data="ce1", f_data="e1"), + ], + ) + + # f was told not to load its attrs, so they load here + asserter_.assert_( + CompiledSQL( + "SELECT f.f_data AS f_f_data FROM f WHERE :param_1 = f.id", + [{"param_1": 2}], + ), + ) + + def test_threelevel_all_selectin_l1_load_l2( + self, threelevel_all_selectin_fixture + ): + """test for #9373 - load base to receive level 2 endpoint""" + a, b, c, d, e = self.classes("a", "b", "c", "d", "e") + sess = fixture_session() + sess.add_all([c(c_data="c1", a_data="a1")]) + sess.commit() + + q = sess.query(a) + + result = self.assert_sql_execution( + testing.db, + q.all, + CompiledSQL( + "SELECT a.id AS a_id, a.type AS a_type, " + "a.a_data AS a_a_data FROM a", + {}, + ), + CompiledSQL( + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " + "c.c_data AS c_c_data FROM a JOIN c ON a.id = c.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id", + {"primary_keys": [1]}, + ), + ) + with self.assert_statement_count(testing.db, 0): + eq_( + result, + [c(c_data="c1", a_data="a1")], + ) + + @testing.variation("use_aliased_class", [True, False]) + def test_threelevel_all_selectin_l2_load_l3( + self, threelevel_all_selectin_fixture, use_aliased_class + ): + """test for #9373 - load level 2 endpoing to receive level 3 + endpoints""" + a, b, c, d, e = self.classes("a", "b", "c", "d", "e") + sess = fixture_session() + sess.add_all( + [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")] + ) + sess.commit() + + if use_aliased_class: + q = sess.query(aliased(c, flat=True)) + else: + q = sess.query(c) + result = self.assert_sql_execution( + testing.db, + q.all, + Conditional( + bool(use_aliased_class), + [ + CompiledSQL( + "SELECT c_1.id AS c_1_id, a_1.id AS a_1_id, " + "a_1.type AS a_1_type, a_1.a_data AS a_1_a_data, " + "c_1.c_data AS c_1_c_data " + "FROM a AS a_1 JOIN c AS c_1 ON a_1.id = c_1.id", + {}, + ) + ], + [ + CompiledSQL( + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " + "a.a_data AS a_a_data, c.c_data AS c_c_data " + "FROM a JOIN c ON a.id = c.id", + {}, + ) + ], + ), + CompiledSQL( + "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, " + "a.type AS a_type, d.d_data AS d_d_data " + "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id", + [{"primary_keys": [1]}], + ), + CompiledSQL( + "SELECT e.id AS e_id, c.id AS c_id, a.id AS a_id, " + "a.type AS a_type, e.e_data AS e_e_data " + "FROM a JOIN c ON a.id = c.id JOIN e ON c.id = e.id " + "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id", + [{"primary_keys": [2]}], + ), + ) + with self.assert_statement_count(testing.db, 0): + eq_( + result, + [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")], + ) + def test_threelevel_selectin_to_inline_options(self): self._fixture_from_geometry( { -- 2.47.2