]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Limit select in loading for correct types
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Jan 2018 04:03:40 +0000 (23:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 12 Jan 2018 17:59:09 +0000 (12:59 -0500)
Fixed bug in new "selectin" relationship loader where the loader could try
to load a non-existent relationship when loading a collection of
polymorphic objects, where only some of the mappers include that
relationship, typically when :meth:`.PropComparator.of_type` is being used.

This generalizes the mapper limiting that was present
in _load_subclass_via_in() to be part of the PostLoad object
itself, and is used by both polymorphic selectin loading and
relationship selectin loading.

Change-Id: I31416550e27bc8374b673860f57d9dcf96abe87d
Fixes: #4156
doc/build/changelog/unreleased_12/4156.rst [new file with mode: 0644]
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_selectin_relations.py

diff --git a/doc/build/changelog/unreleased_12/4156.rst b/doc/build/changelog/unreleased_12/4156.rst
new file mode 100644 (file)
index 0000000..4511302
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 4156
+
+    Fixed bug in new "selectin" relationship loader where the loader could try
+    to load a non-existent relationship when loading a collection of
+    polymorphic objects, where only some of the mappers include that
+    relationship, typically when :meth:`.PropComparator.of_type` is being used.
index a23cafac2cc178dc95647177f13b6bc2d75d5f59..8a20bf0dd775567a47efdd91ed81d4d1acd0e2a5 100644 (file)
@@ -394,7 +394,8 @@ def _instance_processor(
             callable_ = _load_subclass_via_in(context, path, selectin_load_via)
 
             PostLoad.callable_for_path(
-                context, load_path, selectin_load_via,
+                context, load_path, selectin_load_via.mapper,
+                selectin_load_via,
                 callable_, selectin_load_via)
 
     post_load = PostLoad.for_context(context, load_path, only_load_props)
@@ -574,7 +575,6 @@ def _load_subclass_via_in(context, path, entity):
             primary_keys=[
                 state.key[1][0] if zero_idx else state.key[1]
                 for state, load_attrs in states
-                if state.mapper.isa(mapper)
             ]
         ).all()
 
@@ -738,16 +738,25 @@ class PostLoad(object):
         self.load_keys = None
 
     def add_state(self, state, overwrite):
+        # the states for a polymorphic load here are all shared
+        # within a single PostLoad object among multiple subtypes.
+        # Filtering of callables on a per-subclass basis needs to be done at
+        # the invocation level
         self.states[state] = overwrite
 
     def invoke(self, context, path):
         if not self.states:
             return
         path = path_registry.PathRegistry.coerce(path)
-        for key, loader, arg, kw in self.loaders.values():
+        for token, limit_to_mapper, loader, arg, kw in self.loaders.values():
+            states = [
+                (state, overwrite)
+                for state, overwrite
+                in self.states.items()
+                if state.manager.mapper.isa(limit_to_mapper)
+            ]
             loader(
-                context, path, self.states.items(),
-                self.load_keys, *arg, **kw)
+                context, path, states, self.load_keys, *arg, **kw)
         self.states.clear()
 
     @classmethod
@@ -764,12 +773,13 @@ class PostLoad(object):
 
     @classmethod
     def callable_for_path(
-            cls, context, path, attr_key, loader_callable, *arg, **kw):
+            cls, context, path, limit_to_mapper, token,
+            loader_callable, *arg, **kw):
         if path.path in context.post_load_paths:
             pl = context.post_load_paths[path.path]
         else:
             pl = context.post_load_paths[path.path] = PostLoad()
-        pl.loaders[attr_key] = (attr_key, loader_callable, arg, kw)
+        pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw)
 
 
 def load_scalar_attributes(mapper, state, attribute_names):
index a57b66045c1164d218f890d912643b49c54a939b..c3eae1e9125c8b4d9cd8ee526342062a3c64fd94 100644 (file)
@@ -1883,7 +1883,7 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
                 return
 
         loading.PostLoad.callable_for_path(
-            context, selectin_path, self.key,
+            context, selectin_path, self.parent, self.key,
             self._load_for_path, effective_entity)
 
     @util.dependencies("sqlalchemy.ext.baked")
index 6f10260cca93c3fb56a84051533a6b32460ab72c..ff1d0d40f1cb9f4861c3c0d9d66deda795409043 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy import Integer, String, ForeignKey, bindparam
 from sqlalchemy.orm import selectinload, selectinload_all, \
     mapper, relationship, clear_mappers, create_session, \
     aliased, joinedload, deferred, undefer,\
-    Session, subqueryload
+    Session, subqueryload, defaultload
 from sqlalchemy.testing import assert_raises, \
     assert_raises_message
 from sqlalchemy.testing.assertsql import CompiledSQL
@@ -1334,6 +1334,149 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic):
         )
 
 
+class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Company(Base):
+            __tablename__ = 'company'
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+            employees = relationship('Employee', order_by="Employee.id")
+
+        class Employee(Base):
+            __tablename__ = 'employee'
+            id = Column(Integer, primary_key=True)
+            type = Column(String(50))
+            name = Column(String(50))
+            company_id = Column(ForeignKey('company.id'))
+
+            __mapper_args__ = {
+                'polymorphic_on': 'type',
+                'with_polymorphic': '*',
+            }
+
+        class Programmer(Employee):
+            __tablename__ = 'programmer'
+            id = Column(ForeignKey('employee.id'), primary_key=True)
+            languages = relationship('Language')
+
+            __mapper_args__ = {
+                'polymorphic_identity': 'programmer',
+            }
+
+        class Manager(Employee):
+            __tablename__ = 'manager'
+            id = Column(ForeignKey('employee.id'), primary_key=True)
+            golf_swing_id = Column(ForeignKey("golf_swing.id"))
+            golf_swing = relationship("GolfSwing")
+
+            __mapper_args__ = {
+                'polymorphic_identity': 'manager',
+            }
+
+        class Language(Base):
+            __tablename__ = 'language'
+            id = Column(Integer, primary_key=True)
+            programmer_id = Column(
+                Integer,
+                ForeignKey('programmer.id'),
+                nullable=False,
+            )
+            name = Column(String(50))
+
+        class GolfSwing(Base):
+            __tablename__ = 'golf_swing'
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+
+    @classmethod
+    def insert_data(cls):
+        Company, Programmer, Manager, GolfSwing, Language = cls.classes(
+            "Company", "Programmer", "Manager", "GolfSwing", "Language")
+        c1 = Company(
+            id=1,
+            name='Foobar Corp',
+            employees=[Programmer(
+                id=1,
+                name='p1',
+                languages=[Language(id=1, name='Python')],
+            ), Manager(
+                id=2,
+                name='m1',
+                golf_swing=GolfSwing(name="fore")
+            )],
+        )
+        c2 = Company(
+            id=2,
+            name='bat Corp',
+            employees=[
+                Manager(
+                    id=3,
+                    name='m2',
+                    golf_swing=GolfSwing(name="clubs"),
+                ), Programmer(
+                    id=4,
+                    name='p2',
+                    languages=[Language(id=2, name="Java")]
+                )],
+        )
+        sess = Session()
+        sess.add_all([c1, c2])
+        sess.commit()
+
+    def test_one_to_many(self):
+
+        Company, Programmer, Manager, GolfSwing, Language = self.classes(
+            "Company", "Programmer", "Manager", "GolfSwing", "Language")
+        sess = Session()
+        company = sess.query(Company).filter(
+            Company.id == 1,
+        ).options(
+            selectinload(Company.employees.of_type(Programmer)).
+            selectinload(Programmer.languages),
+        ).one()
+
+        def go():
+            eq_(company.employees[0].languages[0].name, "Python")
+
+        self.assert_sql_count(testing.db, go, 0)
+
+    def test_many_to_one(self):
+        Company, Programmer, Manager, GolfSwing, Language = self.classes(
+            "Company", "Programmer", "Manager", "GolfSwing", "Language")
+        sess = Session()
+        company = sess.query(Company).filter(
+            Company.id == 2,
+        ).options(
+            selectinload(Company.employees.of_type(Manager)).
+            selectinload(Manager.golf_swing),
+        ).one()
+
+        def go():
+            eq_(company.employees[0].golf_swing.name, "clubs")
+
+        self.assert_sql_count(testing.db, go, 0)
+
+    def test_both(self):
+        Company, Programmer, Manager, GolfSwing, Language = self.classes(
+            "Company", "Programmer", "Manager", "GolfSwing", "Language")
+        sess = Session()
+        rows = sess.query(Company).options(
+            selectinload(Company.employees.of_type(Manager)).
+            selectinload(Manager.golf_swing),
+            defaultload(Company.employees.of_type(Programmer)).
+            selectinload(Programmer.languages),
+        ).order_by(Company.id).all()
+
+        def go():
+            eq_(rows[0].employees[0].languages[0].name, "Python")
+            eq_(rows[1].employees[0].golf_swing.name, "clubs")
+
+        self.assert_sql_count(testing.db, go, 0)
+
+
 class ChunkingTest(fixtures.DeclarativeMappedTest):
     """test IN chunking.