From: Mike Bayer Date: Wed, 10 Jan 2018 04:03:40 +0000 (-0500) Subject: Limit select in loading for correct types X-Git-Tag: rel_1_2_1~4^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a216625bd03313e85f8063c2c875730e15edc4a4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Limit select in loading for correct types Fixed bug in new "selectin" relationship loader where the loader could try to load a non-existent relationship when loading a collection of polymorphic objects, where only some of the mappers include that relationship, typically when :meth:`.PropComparator.of_type` is being used. This generalizes the mapper limiting that was present in _load_subclass_via_in() to be part of the PostLoad object itself, and is used by both polymorphic selectin loading and relationship selectin loading. Change-Id: I31416550e27bc8374b673860f57d9dcf96abe87d Fixes: #4156 --- diff --git a/doc/build/changelog/unreleased_12/4156.rst b/doc/build/changelog/unreleased_12/4156.rst new file mode 100644 index 0000000000..4511302e3d --- /dev/null +++ b/doc/build/changelog/unreleased_12/4156.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 4156 + + Fixed bug in new "selectin" relationship loader where the loader could try + to load a non-existent relationship when loading a collection of + polymorphic objects, where only some of the mappers include that + relationship, typically when :meth:`.PropComparator.of_type` is being used. diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index a23cafac2c..8a20bf0dd7 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -394,7 +394,8 @@ def _instance_processor( callable_ = _load_subclass_via_in(context, path, selectin_load_via) PostLoad.callable_for_path( - context, load_path, selectin_load_via, + context, load_path, selectin_load_via.mapper, + selectin_load_via, callable_, selectin_load_via) post_load = PostLoad.for_context(context, load_path, only_load_props) @@ -574,7 +575,6 @@ def _load_subclass_via_in(context, path, entity): primary_keys=[ state.key[1][0] if zero_idx else state.key[1] for state, load_attrs in states - if state.mapper.isa(mapper) ] ).all() @@ -738,16 +738,25 @@ class PostLoad(object): self.load_keys = None def add_state(self, state, overwrite): + # the states for a polymorphic load here are all shared + # within a single PostLoad object among multiple subtypes. + # Filtering of callables on a per-subclass basis needs to be done at + # the invocation level self.states[state] = overwrite def invoke(self, context, path): if not self.states: return path = path_registry.PathRegistry.coerce(path) - for key, loader, arg, kw in self.loaders.values(): + for token, limit_to_mapper, loader, arg, kw in self.loaders.values(): + states = [ + (state, overwrite) + for state, overwrite + in self.states.items() + if state.manager.mapper.isa(limit_to_mapper) + ] loader( - context, path, self.states.items(), - self.load_keys, *arg, **kw) + context, path, states, self.load_keys, *arg, **kw) self.states.clear() @classmethod @@ -764,12 +773,13 @@ class PostLoad(object): @classmethod def callable_for_path( - cls, context, path, attr_key, loader_callable, *arg, **kw): + cls, context, path, limit_to_mapper, token, + loader_callable, *arg, **kw): if path.path in context.post_load_paths: pl = context.post_load_paths[path.path] else: pl = context.post_load_paths[path.path] = PostLoad() - pl.loaders[attr_key] = (attr_key, loader_callable, arg, kw) + pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw) def load_scalar_attributes(mapper, state, attribute_names): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index a57b66045c..c3eae1e912 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1883,7 +1883,7 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): return loading.PostLoad.callable_for_path( - context, selectin_path, self.key, + context, selectin_path, self.parent, self.key, self._load_for_path, effective_entity) @util.dependencies("sqlalchemy.ext.baked") diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index 6f10260cca..ff1d0d40f1 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -5,7 +5,7 @@ from sqlalchemy import Integer, String, ForeignKey, bindparam from sqlalchemy.orm import selectinload, selectinload_all, \ mapper, relationship, clear_mappers, create_session, \ aliased, joinedload, deferred, undefer,\ - Session, subqueryload + Session, subqueryload, defaultload from sqlalchemy.testing import assert_raises, \ assert_raises_message from sqlalchemy.testing.assertsql import CompiledSQL @@ -1334,6 +1334,149 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic): ) +class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest): + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Company(Base): + __tablename__ = 'company' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + employees = relationship('Employee', order_by="Employee.id") + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + name = Column(String(50)) + company_id = Column(ForeignKey('company.id')) + + __mapper_args__ = { + 'polymorphic_on': 'type', + 'with_polymorphic': '*', + } + + class Programmer(Employee): + __tablename__ = 'programmer' + id = Column(ForeignKey('employee.id'), primary_key=True) + languages = relationship('Language') + + __mapper_args__ = { + 'polymorphic_identity': 'programmer', + } + + class Manager(Employee): + __tablename__ = 'manager' + id = Column(ForeignKey('employee.id'), primary_key=True) + golf_swing_id = Column(ForeignKey("golf_swing.id")) + golf_swing = relationship("GolfSwing") + + __mapper_args__ = { + 'polymorphic_identity': 'manager', + } + + class Language(Base): + __tablename__ = 'language' + id = Column(Integer, primary_key=True) + programmer_id = Column( + Integer, + ForeignKey('programmer.id'), + nullable=False, + ) + name = Column(String(50)) + + class GolfSwing(Base): + __tablename__ = 'golf_swing' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + @classmethod + def insert_data(cls): + Company, Programmer, Manager, GolfSwing, Language = cls.classes( + "Company", "Programmer", "Manager", "GolfSwing", "Language") + c1 = Company( + id=1, + name='Foobar Corp', + employees=[Programmer( + id=1, + name='p1', + languages=[Language(id=1, name='Python')], + ), Manager( + id=2, + name='m1', + golf_swing=GolfSwing(name="fore") + )], + ) + c2 = Company( + id=2, + name='bat Corp', + employees=[ + Manager( + id=3, + name='m2', + golf_swing=GolfSwing(name="clubs"), + ), Programmer( + id=4, + name='p2', + languages=[Language(id=2, name="Java")] + )], + ) + sess = Session() + sess.add_all([c1, c2]) + sess.commit() + + def test_one_to_many(self): + + Company, Programmer, Manager, GolfSwing, Language = self.classes( + "Company", "Programmer", "Manager", "GolfSwing", "Language") + sess = Session() + company = sess.query(Company).filter( + Company.id == 1, + ).options( + selectinload(Company.employees.of_type(Programmer)). + selectinload(Programmer.languages), + ).one() + + def go(): + eq_(company.employees[0].languages[0].name, "Python") + + self.assert_sql_count(testing.db, go, 0) + + def test_many_to_one(self): + Company, Programmer, Manager, GolfSwing, Language = self.classes( + "Company", "Programmer", "Manager", "GolfSwing", "Language") + sess = Session() + company = sess.query(Company).filter( + Company.id == 2, + ).options( + selectinload(Company.employees.of_type(Manager)). + selectinload(Manager.golf_swing), + ).one() + + def go(): + eq_(company.employees[0].golf_swing.name, "clubs") + + self.assert_sql_count(testing.db, go, 0) + + def test_both(self): + Company, Programmer, Manager, GolfSwing, Language = self.classes( + "Company", "Programmer", "Manager", "GolfSwing", "Language") + sess = Session() + rows = sess.query(Company).options( + selectinload(Company.employees.of_type(Manager)). + selectinload(Manager.golf_swing), + defaultload(Company.employees.of_type(Programmer)). + selectinload(Programmer.languages), + ).order_by(Company.id).all() + + def go(): + eq_(rows[0].employees[0].languages[0].name, "Python") + eq_(rows[1].employees[0].golf_swing.name, "clubs") + + self.assert_sql_count(testing.db, go, 0) + + class ChunkingTest(fixtures.DeclarativeMappedTest): """test IN chunking.