if not refresh_state and _polymorphic_from is not None:
key = ("loader", path.path)
+
if key in context.attributes and context.attributes[key].strategy == (
("selectinload_polymorphic", True),
):
- selectin_load_via = mapper._should_selectin_load(
- context.attributes[key].local_opts["entities"],
- _polymorphic_from,
- )
+ option_entities = context.attributes[key].local_opts["entities"]
else:
- selectin_load_via = mapper._should_selectin_load(
- None, _polymorphic_from
- )
+ option_entities = None
+ selectin_load_via = mapper._should_selectin_load(
+ option_entities,
+ _polymorphic_from,
+ )
if selectin_load_via and selectin_load_via is not _polymorphic_from:
# only_load_props goes w/ refresh_state only, and in a refresh
# loading does not apply
assert only_load_props is None
- callable_ = _load_subclass_via_in(context, path, selectin_load_via)
-
+ callable_ = _load_subclass_via_in(
+ context,
+ path,
+ selectin_load_via,
+ _polymorphic_from,
+ option_entities,
+ )
PostLoad.callable_for_path(
context,
load_path,
return _instance
-def _load_subclass_via_in(context, path, entity):
+def _load_subclass_via_in(
+ context, path, entity, polymorphic_from, option_entities
+):
mapper = entity.mapper
+ # TODO: polymorphic_from seems to be a Mapper in all cases.
+ # this is likely not needed, but as we dont have typing in loading.py
+ # yet, err on the safe side
+ polymorphic_from_mapper = polymorphic_from.mapper
+ not_against_basemost = polymorphic_from_mapper.inherits is not None
+
zero_idx = len(mapper.base_mapper.primary_key) == 1
- if entity.is_aliased_class:
- q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity)
+ if entity.is_aliased_class or not_against_basemost:
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in(
+ entity, polymorphic_from
+ )
else:
q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper
def do_load(context, path, states, load_only, effective_entity):
+ if not option_entities:
+ # filter out states for those that would have selectinloaded
+ # from another loader
+ # TODO: we are currently ignoring the case where the
+ # "selectin_polymorphic" option is used, as this is much more
+ # complex / specific / very uncommon API use
+ states = [
+ (s, v)
+ for s, v in states
+ if s.mapper._would_selectin_load_only_from_given_mapper(mapper)
+ ]
+
+ if not states:
+ return
+
orig_query = context.query
options = (enable_opt,) + orig_query._with_options + (disable_opt,)
if m is mapper:
break
+ @HasMemoized.memoized_attribute
+ def _would_selectinload_combinations_cache(self):
+ return {}
+
+ def _would_selectin_load_only_from_given_mapper(self, super_mapper):
+ """return True if this mapper would "selectin" polymorphic load based
+ on the given super mapper, and not from a setting from a subclass.
+
+ given::
+
+ class A:
+ ...
+
+ class B(A):
+ __mapper_args__ = {"polymorphic_load": "selectin"}
+
+ class C(B):
+ ...
+
+ class D(B):
+ __mapper_args__ = {"polymorphic_load": "selectin"}
+
+ ``inspect(C)._would_selectin_load_only_from_given_mapper(inspect(B))``
+ returns True, because C does selectin loading because of B's setting.
+
+ OTOH, ``inspect(D)
+ ._would_selectin_load_only_from_given_mapper(inspect(B))``
+ returns False, because D does selectin loading because of its own
+ setting; when we are doing a selectin poly load from B, we want to
+ filter out D because it would already have its own selectin poly load
+ set up separately.
+
+ Added as part of #9373.
+
+ """
+ cache = self._would_selectinload_combinations_cache
+
+ try:
+ return cache[super_mapper]
+ except KeyError:
+ pass
+
+ # assert that given object is a supermapper, meaning we already
+ # strong reference it directly or indirectly. this allows us
+ # to not worry that we are creating new strongrefs to unrelated
+ # mappers or other objects.
+ assert self.isa(super_mapper)
+
+ mapper = super_mapper
+ for m in self._iterate_to_target_viawpoly(mapper):
+ if m.polymorphic_load == "selectin":
+ retval = m is super_mapper
+ break
+ else:
+ retval = False
+
+ cache[super_mapper] = retval
+ return retval
+
def _should_selectin_load(self, enabled_via_opt, polymorphic_from):
if not enabled_via_opt:
# common case, takes place for all polymorphic loads
return None
@util.preload_module("sqlalchemy.orm.strategy_options")
- def _subclass_load_via_in(self, entity):
+ def _subclass_load_via_in(self, entity, polymorphic_from):
"""Assemble a that can load the columns local to
this subclass as a SELECT with IN.
disable_opt = strategy_options.Load(entity)
enable_opt = strategy_options.Load(entity)
+ classes_to_include = {self}
+ m: Optional[Mapper[Any]] = self.inherits
+ while (
+ m is not None
+ and m is not polymorphic_from
+ and m.polymorphic_load == "selectin"
+ ):
+ classes_to_include.add(m)
+ m = m.inherits
+
for prop in self.attrs:
# skip prop keys that are not instrumented on the mapped class.
if prop.key not in self.class_manager:
continue
- if prop.parent is self or prop in keep_props:
+ if prop.parent in classes_to_include or prop in keep_props:
# "enable" options, to turn on the properties that we want to
# load by default (subject to options from the query)
if not isinstance(prop, StrategizedProperty):
@HasMemoized.memoized_attribute
def _subclass_load_via_in_mapper(self):
- return self._subclass_load_via_in(self)
+ # the default is loading this mapper against the basemost mapper
+ return self._subclass_load_via_in(self, self.base_mapper)
def cascade_iterator(
self,
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import union
+from sqlalchemy.orm import aliased
from sqlalchemy.orm import backref
from sqlalchemy.orm import column_property
from sqlalchemy.orm import composite
from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.assertsql import AllOf
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.assertsql import Conditional
from sqlalchemy.testing.assertsql import EachOf
from sqlalchemy.testing.assertsql import Or
from sqlalchemy.testing.entities import ComparableEntity
with self.assert_statement_count(testing.db, 0):
eq_(result, [d(d_data="d1"), e(e_data="e1")])
+ @testing.fixture
+ def threelevel_all_selectin_fixture(self):
+ self._fixture_from_geometry(
+ {
+ "a": {
+ "subclasses": {
+ "b": {"polymorphic_load": "selectin"},
+ "c": {
+ "subclasses": {
+ "d": {
+ "polymorphic_load": "selectin",
+ },
+ "e": {
+ "polymorphic_load": "selectin",
+ },
+ "f": {},
+ },
+ "polymorphic_load": "selectin",
+ },
+ }
+ }
+ }
+ )
+
+ def test_threelevel_all_selectin_l1_load_l3(
+ self, threelevel_all_selectin_fixture
+ ):
+ """test for #9373 - load base to receive level 3 endpoints"""
+
+ a, b, c, d, e = self.classes("a", "b", "c", "d", "e")
+ sess = fixture_session()
+ sess.add_all(
+ [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")]
+ )
+ sess.commit()
+
+ for i in range(3):
+ sess.close()
+
+ q = sess.query(a)
+
+ result = self.assert_sql_execution(
+ testing.db,
+ q.all,
+ CompiledSQL(
+ "SELECT a.id AS a_id, a.type AS a_type, "
+ "a.a_data AS a_a_data FROM a",
+ {},
+ ),
+ CompiledSQL(
+ "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, "
+ "a.type AS a_type, c.c_data AS c_c_data, "
+ "d.d_data AS d_d_data "
+ "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id "
+ "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+ "ORDER BY a.id",
+ [{"primary_keys": [1]}],
+ ),
+ CompiledSQL(
+ "SELECT e.id AS e_id, c.id AS c_id, a.id AS a_id, "
+ "a.type AS a_type, c.c_data AS c_c_data, "
+ "e.e_data AS e_e_data "
+ "FROM a JOIN c ON a.id = c.id JOIN e ON c.id = e.id "
+ "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+ "ORDER BY a.id",
+ [{"primary_keys": [2]}],
+ ),
+ )
+ with self.assert_statement_count(testing.db, 0):
+ eq_(
+ result,
+ [
+ d(c_data="cd1", d_data="d1"),
+ e(c_data="ce1", e_data="e1"),
+ ],
+ )
+
+ def test_threelevel_partial_selectin_l1_load_l3(
+ self, threelevel_all_selectin_fixture
+ ):
+ """test for #9373 - load base to receive level 3 endpoints"""
+
+ a, b, c, d, f = self.classes("a", "b", "c", "d", "f")
+ sess = fixture_session()
+ sess.add_all(
+ [d(c_data="cd1", d_data="d1"), f(c_data="ce1", f_data="e1")]
+ )
+ sess.commit()
+
+ for i in range(3):
+ sess.close()
+ q = sess.query(a)
+
+ result = self.assert_sql_execution(
+ testing.db,
+ q.all,
+ CompiledSQL(
+ "SELECT a.id AS a_id, a.type AS a_type, "
+ "a.a_data AS a_a_data FROM a",
+ {},
+ ),
+ CompiledSQL(
+ "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, "
+ "a.type AS a_type, c.c_data AS c_c_data, "
+ "d.d_data AS d_d_data "
+ "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id "
+ "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+ "ORDER BY a.id",
+ [{"primary_keys": [1]}],
+ ),
+ # only loads pk 2 - this is the filtering inside of do_load
+ CompiledSQL(
+ "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, "
+ "c.c_data AS c_c_data "
+ "FROM a JOIN c ON a.id = c.id "
+ "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+ "ORDER BY a.id",
+ [{"primary_keys": [2]}],
+ ),
+ # no more SQL; if we hit pk 1 again, it would re-do the d here
+ )
+
+ with self.sql_execution_asserter(testing.db) as asserter_:
+ eq_(
+ result,
+ [
+ d(c_data="cd1", d_data="d1"),
+ f(c_data="ce1", f_data="e1"),
+ ],
+ )
+
+ # f was told not to load its attrs, so they load here
+ asserter_.assert_(
+ CompiledSQL(
+ "SELECT f.f_data AS f_f_data FROM f WHERE :param_1 = f.id",
+ [{"param_1": 2}],
+ ),
+ )
+
+ def test_threelevel_all_selectin_l1_load_l2(
+ self, threelevel_all_selectin_fixture
+ ):
+ """test for #9373 - load base to receive level 2 endpoint"""
+ a, b, c, d, e = self.classes("a", "b", "c", "d", "e")
+ sess = fixture_session()
+ sess.add_all([c(c_data="c1", a_data="a1")])
+ sess.commit()
+
+ q = sess.query(a)
+
+ result = self.assert_sql_execution(
+ testing.db,
+ q.all,
+ CompiledSQL(
+ "SELECT a.id AS a_id, a.type AS a_type, "
+ "a.a_data AS a_a_data FROM a",
+ {},
+ ),
+ CompiledSQL(
+ "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, "
+ "c.c_data AS c_c_data FROM a JOIN c ON a.id = c.id "
+ "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id",
+ {"primary_keys": [1]},
+ ),
+ )
+ with self.assert_statement_count(testing.db, 0):
+ eq_(
+ result,
+ [c(c_data="c1", a_data="a1")],
+ )
+
+ @testing.variation("use_aliased_class", [True, False])
+ def test_threelevel_all_selectin_l2_load_l3(
+ self, threelevel_all_selectin_fixture, use_aliased_class
+ ):
+ """test for #9373 - load level 2 endpoing to receive level 3
+ endpoints"""
+ a, b, c, d, e = self.classes("a", "b", "c", "d", "e")
+ sess = fixture_session()
+ sess.add_all(
+ [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")]
+ )
+ sess.commit()
+
+ if use_aliased_class:
+ q = sess.query(aliased(c, flat=True))
+ else:
+ q = sess.query(c)
+ result = self.assert_sql_execution(
+ testing.db,
+ q.all,
+ Conditional(
+ bool(use_aliased_class),
+ [
+ CompiledSQL(
+ "SELECT c_1.id AS c_1_id, a_1.id AS a_1_id, "
+ "a_1.type AS a_1_type, a_1.a_data AS a_1_a_data, "
+ "c_1.c_data AS c_1_c_data "
+ "FROM a AS a_1 JOIN c AS c_1 ON a_1.id = c_1.id",
+ {},
+ )
+ ],
+ [
+ CompiledSQL(
+ "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, "
+ "a.a_data AS a_a_data, c.c_data AS c_c_data "
+ "FROM a JOIN c ON a.id = c.id",
+ {},
+ )
+ ],
+ ),
+ CompiledSQL(
+ "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, "
+ "a.type AS a_type, d.d_data AS d_d_data "
+ "FROM a JOIN c ON a.id = c.id JOIN d ON c.id = d.id "
+ "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id",
+ [{"primary_keys": [1]}],
+ ),
+ CompiledSQL(
+ "SELECT e.id AS e_id, c.id AS c_id, a.id AS a_id, "
+ "a.type AS a_type, e.e_data AS e_e_data "
+ "FROM a JOIN c ON a.id = c.id JOIN e ON c.id = e.id "
+ "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) ORDER BY a.id",
+ [{"primary_keys": [2]}],
+ ),
+ )
+ with self.assert_statement_count(testing.db, 0):
+ eq_(
+ result,
+ [d(c_data="cd1", d_data="d1"), e(c_data="ce1", e_data="e1")],
+ )
+
def test_threelevel_selectin_to_inline_options(self):
self._fixture_from_geometry(
{