]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed a bug involving contains_eager(), which would apply itself
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Jul 2009 17:17:22 +0000 (17:17 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Jul 2009 17:17:22 +0000 (17:17 +0000)
to a secondary (i.e. lazy) load in a particular rare case,
producing cartesian products.   improved the targeting
of query.options() on secondary loads overall [ticket:1461].

CHANGES
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_mapper.py

diff --git a/CHANGES b/CHANGES
index faee0d3dd281d62d8315284a16ca6970fc493f85..cf2f8150d162e57403749d8f1dbdae0a9f2753b6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -21,7 +21,12 @@ CHANGES
     
     - Fixed bug introduced in 0.5.4 whereby Composite types
       fail when default-holding columns are flushed.
-      
+    
+    - Fixed a bug involving contains_eager(), which would apply itself
+      to a secondary (i.e. lazy) load in a particular rare case,
+      producing cartesian products.   improved the targeting
+      of query.options() on secondary loads overall [ticket:1461].
+        
     - Fixed another 0.5.4 bug whereby mutable attributes (i.e. PickleType)
       wouldn't be deserialized correctly when the whole object
       was serialized.  [ticket:1426]
index 7b840a50da4782e2d2f08d0d9f0ac304045562d6..6af8dde9ff219eb4ec5ea1b64e717f9be5ebc067 100644 (file)
@@ -934,7 +934,7 @@ def contains_eager(*keys, **kwargs):
     if kwargs:
         raise exceptions.ArgumentError("Invalid kwargs for contains_eager: %r" % kwargs.keys())
 
-    return (strategies.EagerLazyOption(keys, lazy=False), strategies.LoadEagerFromAliasOption(keys, alias=alias))
+    return (strategies.EagerLazyOption(keys, lazy=False, _only_on_lead=True), strategies.LoadEagerFromAliasOption(keys, alias=alias))
 
 @sa_util.accepts_a_list_as_starargs(list_deprecation='pending')
 def defer(*keys):
index 0ac771305833137686b1ac126a43cc4023c4a445..9a9ebfcab2b0ef85e6707373f2d016b1a252ce2d 100644 (file)
@@ -682,13 +682,14 @@ class PropertyOption(MapperOption):
             searchfor = mapper
         else:
             searchfor = _class_to_mapper(mapper).base_mapper
-
+        
         for ent in query._mapper_entities:
             if ent.path_entity is searchfor:
                 return ent
         else:
             if raiseerr:
-                raise sa_exc.ArgumentError("Can't find entity %s in Query.  Current list: %r" % (searchfor, [str(m.path_entity) for m in query._entities]))
+                raise sa_exc.ArgumentError("Can't find entity %s in Query.  Current list: %r" 
+                    % (searchfor, [str(m.path_entity) for m in query._entities]))
             else:
                 return None
 
@@ -718,8 +719,10 @@ class PropertyOption(MapperOption):
         entity = None
         l = []
 
+        # _current_path implies we're in a secondary load
+        # with an existing path
         current_path = list(query._current_path)
-
+            
         if self.mapper:
             entity = self.__find_entity(query, self.mapper, raiseerr)
             mapper = entity.mapper
@@ -752,7 +755,7 @@ class PropertyOption(MapperOption):
                 if current_path and key == current_path[1]:
                     current_path = current_path[2:]
                     continue
-
+                    
                 if prop is None:
                     return []
 
@@ -764,7 +767,12 @@ class PropertyOption(MapperOption):
                     path_element = mapper = getattr(prop, 'mapper', None)
                 if path_element:
                     path_element = path_element.base_mapper
-
+        
+        # if current_path tokens remain, then
+        # we didn't have an exact path match.
+        if current_path:
+            return []
+            
         return l
 
 class AttributeExtension(object):
index 20cbb8f4dcdeb09300203d25cbd6779626a297f2..ebb576a71601602f35a00a2c1bc67b40baef661e 100644 (file)
@@ -776,11 +776,16 @@ class EagerLoader(AbstractRelationLoader):
 log.class_logger(EagerLoader)
 
 class EagerLazyOption(StrategizedOption):
-    def __init__(self, key, lazy=True, chained=False, mapper=None):
+    def __init__(self, key, lazy=True, chained=False, mapper=None, _only_on_lead=False):
         super(EagerLazyOption, self).__init__(key, mapper)
         self.lazy = lazy
         self.chained = chained
+        self._only_on_lead = _only_on_lead
         
+    def process_query_conditionally(self, query):
+        if not self._only_on_lead:
+            StrategizedOption.process_query_conditionally(self, query)
+            
     def is_chained(self):
         return not self.lazy and self.chained
         
@@ -800,6 +805,10 @@ class LoadEagerFromAliasOption(PropertyOption):
                 m, alias, is_aliased_class = mapperutil._entity_info(alias)
         self.alias = alias
 
+    def process_query_conditionally(self, query):
+        # dont run this option on a secondary load
+        pass
+        
     def process_query_property(self, query, paths):
         if self.alias:
             if isinstance(self.alias, basestring):
index 025b96424df8758b82ac110dd6077b86c95ee19b..edde2a7565dd04ed28988bab248ed4ee1e24d177 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapp
 from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property
 from sqlalchemy.test.testing import eq_, AssertsCompiledSQL
 from test.orm import _base, _fixtures
+from sqlalchemy.test.assertsql import AllOf, CompiledSQL
 
 
 class MapperTest(_fixtures.FixtureTest):
@@ -1631,6 +1632,166 @@ class DeferredTest(_fixtures.FixtureTest):
         self.sql_count_(0, go)
         eq_(item.description, 'item 4')
 
+
+class SecondaryOptionsTest(_base.MappedTest):
+    """test that the contains_eager() option doesn't bleed into a secondary load."""
+
+    run_inserts = 'once'
+
+    run_deletes = None
+    
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("base", metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('type', String(50), nullable=False)
+        )
+        Table("child1", metadata,
+            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+            Column('child2id', Integer, ForeignKey('child2.id'), nullable=False)
+        )
+        Table("child2", metadata,
+            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+        )
+        Table('related', metadata,
+            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+        )
+        
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        class Base(_base.ComparableEntity):
+            pass
+        class Child1(Base):
+            pass
+        class Child2(Base):
+            pass
+        class Related(_base.ComparableEntity):
+            pass
+        mapper(Base, base, polymorphic_on=base.c.type, properties={
+            'related':relation(Related, uselist=False)
+        })
+        mapper(Child1, child1, inherits=Base, polymorphic_identity='child1', properties={
+            'child2':relation(Child2, primaryjoin=child1.c.child2id==base.c.id, foreign_keys=child1.c.child2id)
+        })
+        mapper(Child2, child2, inherits=Base, polymorphic_identity='child2')
+        mapper(Related, related)
+        
+    @classmethod
+    @testing.resolve_artifact_names
+    def insert_data(cls):
+        base.insert().execute([
+            {'id':1, 'type':'child1'},
+            {'id':2, 'type':'child1'},
+            {'id':3, 'type':'child1'},
+            {'id':4, 'type':'child2'},
+            {'id':5, 'type':'child2'},
+            {'id':6, 'type':'child2'},
+        ])
+        child2.insert().execute([
+            {'id':4},
+            {'id':5},
+            {'id':6},
+        ])
+        child1.insert().execute([
+            {'id':1, 'child2id':4},
+            {'id':2, 'child2id':5},
+            {'id':3, 'child2id':6},
+        ])
+        related.insert().execute([
+            {'id':1},
+            {'id':2},
+            {'id':3},
+            {'id':4},
+            {'id':5},
+            {'id':6},
+        ])
+        
+    @testing.resolve_artifact_names
+    def test_contains_eager(self):
+        sess = create_session()
+        
+        
+        child1s = sess.query(Child1).join(Child1.related).options(sa.orm.contains_eager(Child1.related)).order_by(Child1.id)
+
+        def go():
+            eq_(
+                child1s.all(),
+                [Child1(id=1, related=Related(id=1)), Child1(id=2, related=Related(id=2)), Child1(id=3, related=Related(id=3))]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+        
+        c1 = child1s[0]
+        self.assert_sql_execution(
+            testing.db, 
+            lambda: c1.child2, 
+            CompiledSQL(
+                "SELECT base.id AS base_id, child2.id AS child2_id, base.type AS base_type "
+                "FROM base JOIN child2 ON base.id = child2.id "
+                "WHERE base.id = :param_1",
+                {'param_1':4}
+            )
+        )
+
+    @testing.resolve_artifact_names
+    def test_eagerload_on_other(self):
+        sess = create_session()
+
+        child1s = sess.query(Child1).join(Child1.related).options(sa.orm.eagerload(Child1.related)).order_by(Child1.id)
+
+        def go():
+            eq_(
+                child1s.all(),
+                [Child1(id=1, related=Related(id=1)), Child1(id=2, related=Related(id=2)), Child1(id=3, related=Related(id=3))]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+        
+        c1 = child1s[0]
+
+        self.assert_sql_execution(
+            testing.db, 
+            lambda: c1.child2, 
+            CompiledSQL(
+            "SELECT base.id AS base_id, child2.id AS child2_id, base.type AS base_type "
+            "FROM base JOIN child2 ON base.id = child2.id WHERE base.id = :param_1",
+
+#   eagerload- this shouldn't happen
+#            "SELECT base.id AS base_id, child2.id AS child2_id, base.type AS base_type, "
+#            "related_1.id AS related_1_id FROM base JOIN child2 ON base.id = child2.id "
+#            "LEFT OUTER JOIN related AS related_1 ON base.id = related_1.id WHERE base.id = :param_1",
+                {'param_1':4}
+            )
+        )
+
+    @testing.resolve_artifact_names
+    def test_eagerload_on_same(self):
+        sess = create_session()
+
+        child1s = sess.query(Child1).join(Child1.related).options(sa.orm.eagerload(Child1.child2, Child2.related)).order_by(Child1.id)
+
+        def go():
+            eq_(
+                child1s.all(),
+                [Child1(id=1, related=Related(id=1)), Child1(id=2, related=Related(id=2)), Child1(id=3, related=Related(id=3))]
+            )
+        self.assert_sql_count(testing.db, go, 4)
+        
+        c1 = child1s[0]
+
+        # this *does* eagerload
+        self.assert_sql_execution(
+            testing.db, 
+            lambda: c1.child2, 
+            CompiledSQL(
+                "SELECT base.id AS base_id, child2.id AS child2_id, base.type AS base_type, "
+                "related_1.id AS related_1_id FROM base JOIN child2 ON base.id = child2.id "
+                "LEFT OUTER JOIN related AS related_1 ON base.id = related_1.id WHERE base.id = :param_1",
+                {'param_1':4}
+            )
+        )
+        
+
 class DeferredPopulationTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):