]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug which affected all eagerload() and similar options
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Mar 2010 00:15:50 +0000 (20:15 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Mar 2010 00:15:50 +0000 (20:15 -0400)
such that "remote" eager loads, i.e. eagerloads off of a lazy
load such as query(A).options(eagerload(A.b, B.c))
wouldn't eagerload anything, but using eagerload("b.c") would
work fine.
- subquery eagerloading very close

CHANGES
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/strategies.py
test/orm/_fixtures.py
test/orm/test_eager_relations.py
test/orm/test_mapper.py
test/orm/test_subquery_relations.py

diff --git a/CHANGES b/CHANGES
index baa8e03188d480611609c527717a626c085a34f1..7db2a05075d116832160ca508aad5004816748f6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,7 +12,13 @@ CHANGES
     join(prop) would fail to render the second join outside the
     subquery, when joining on the same criterion as was on the 
     inside.
-
+    
+  - Fixed bug which affected all eagerload() and similar options 
+    such that "remote" eager loads, i.e. eagerloads off of a lazy
+    load such as query(A).options(eagerload(A.b, B.c))
+    wouldn't eagerload anything, but using eagerload("b.c") would
+    work fine.
+     
 0.6beta2
 ========
 
index c773e74f6197d42cd4158bfce8032eb10204e9e3..255b6b6fef8fd5697a2927fd5b8b2eb5d9822649 100644 (file)
@@ -757,14 +757,35 @@ class PropertyOption(MapperOption):
         self._process(query, False)
 
     def _process(self, query, raiseerr):
-        paths, mappers = self.__get_paths(query, raiseerr)
+        paths, mappers = self._get_paths(query, raiseerr)
         if paths:
             self.process_query_property(query, paths, mappers)
 
     def process_query_property(self, query, paths, mappers):
         pass
 
-    def __find_entity(self, query, mapper, raiseerr):
+    def __getstate__(self):
+        d = self.__dict__.copy()
+        d['key'] = ret = []
+        for token in util.to_list(self.key):
+            if isinstance(token, PropComparator):
+                ret.append((token.mapper.class_, token.key))
+            else:
+                ret.append(token)
+        return d
+
+    def __setstate__(self, state):
+        ret = []
+        for key in state['key']:
+            if isinstance(key, tuple):
+                cls, propkey = key
+                ret.append(getattr(cls, propkey))
+            else:
+                ret.append(key)
+        state['key'] = tuple(ret)
+        self.__dict__ = state
+
+    def _find_entity(self, query, mapper, raiseerr):
         from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class
 
         if _is_aliased_class(mapper):
@@ -773,7 +794,7 @@ class PropertyOption(MapperOption):
         else:
             searchfor = _class_to_mapper(mapper)
             isa = True
-            
+
         for ent in query._mapper_entities:
             if searchfor is ent.path_entity or (
                                 isa and
@@ -789,28 +810,7 @@ class PropertyOption(MapperOption):
             else:
                 return None
 
-    def __getstate__(self):
-        d = self.__dict__.copy()
-        d['key'] = ret = []
-        for token in util.to_list(self.key):
-            if isinstance(token, PropComparator):
-                ret.append((token.mapper.class_, token.key))
-            else:
-                ret.append(token)
-        return d
-
-    def __setstate__(self, state):
-        ret = []
-        for key in state['key']:
-            if isinstance(key, tuple):
-                cls, propkey = key
-                ret.append(getattr(cls, propkey))
-            else:
-                ret.append(key)
-        state['key'] = tuple(ret)
-        self.__dict__ = state
-
-    def __get_paths(self, query, raiseerr):
+    def _get_paths(self, query, raiseerr):
         path = None
         entity = None
         l = []
@@ -820,61 +820,71 @@ class PropertyOption(MapperOption):
         # with an existing path
         current_path = list(query._current_path)
             
-        if self.mapper:
-            entity = self.__find_entity(query, self.mapper, raiseerr)
-            mapper = entity.mapper
-            path_element = entity.path_entity
-
+        tokens = []
         for key in util.to_list(self.key):
             if isinstance(key, basestring):
-                tokens = key.split('.')
+                tokens += key.split('.')
             else:
-                tokens = [key]
-            for token in tokens:
-                if isinstance(token, basestring):
-                    if not entity:
-                        entity = query._entity_zero()
-                        path_element = entity.path_entity
-                        mapper = entity.mapper
-                    mappers.append(mapper)
-                    prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
-                    key = token
-                elif isinstance(token, PropComparator):
-                    prop = token.property
-                    if not entity:
-                        entity = self.__find_entity(query, token.parententity, raiseerr)
-                        if not entity:
-                            return [], []
-                        path_element = entity.path_entity
-                    mappers.append(prop.parent)
-                    key = prop.key
-                else:
-                    raise sa_exc.ArgumentError("mapper option expects string key "
-                                                "or list of attributes")
-
-                if current_path and key == current_path[1]:
-                    current_path = current_path[2:]
-                    continue
+                tokens += [key]
+        
+        for token in tokens:
+            if isinstance(token, basestring):
+                if not entity:
+                    if current_path:
+                        if current_path[1] == token:
+                            current_path = current_path[2:]
+                            continue
                     
-                if prop is None:
-                    return [], []
-
-                path = build_path(path_element, prop.key, path)
-                l.append(path)
-                if getattr(token, '_of_type', None):
-                    path_element = mapper = token._of_type
-                else:
-                    path_element = mapper = getattr(prop, 'mapper', None)
-
-                if path_element:
-                    path_element = path_element
+                    entity = query._entity_zero()
+                    path_element = entity.path_entity
+                    mapper = entity.mapper
+                mappers.append(mapper)
+                prop = mapper.get_property(
+                                    token, 
+                                    resolve_synonyms=True, 
+                                    raiseerr=raiseerr)
+                key = token
+            elif isinstance(token, PropComparator):
+                prop = token.property
+                if not entity:
+                    if current_path:
+                        if current_path[0:2] == [token.parententity, prop.key]:
+                            current_path = current_path[2:]
+                            continue
+
+                    entity = self._find_entity(
+                                            query, 
+                                            token.parententity, 
+                                            raiseerr)
+                    if not entity:
+                        return [], []
+                    path_element = entity.path_entity
+                    mapper = entity.mapper
+                mappers.append(prop.parent)
+                key = prop.key
+            else:
+                raise sa_exc.ArgumentError("mapper option expects string key "
+                                            "or list of attributes")
+
+            if prop is None:
+                return [], []
+
+            path = build_path(path_element, prop.key, path)
+            l.append(path)
+            if getattr(token, '_of_type', None):
+                path_element = mapper = token._of_type
+            else:
+                path_element = mapper = getattr(prop, 'mapper', None)
+
+            if path_element:
+                path_element = path_element
                     
                 
         # if current_path tokens remain, then
         # we didn't have an exact path match.
         if current_path:
             return [], []
-            
+
         return l, mappers
 
 class AttributeExtension(object):
index 08bb5062a7bc9a916acf383bdb051d19c8d3b32b..ccaf1dd7b3ea199bbe4064bd9d5cfbbbf2a7a531 100644 (file)
@@ -643,30 +643,30 @@ class SubqueryLoader(AbstractRelationshipLoader):
 
         if not context.query._enable_eagerloads:
             return
-
-#        path = path + (self.key,)
-        
-        
-        if ("orig_query", SubqueryLoader) not in context.attributes:
-            context.attributes[("orig_query", SubqueryLoader)] =\
-                    context.query
         
-        orig_query = context.attributes[("orig_query", SubqueryLoader)]
+        orig_query = context.attributes.get(("orig_query", SubqueryLoader),
+                            context.query)
 
-#        orig_query = context.query
-        path = context.query._current_path + path + (self.key, )
+        path = path + (self.key, )
         
-        prop = path[0].get_property(path[1])
+        local_cols, remote_cols = self._local_remote_columns(self.parent_property)
+        if len(path) > 1:
+            leftmost_mapper, leftmost_prop = path[0], path[0].get_property(path[1])
+            leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop)
+        else:
+            leftmost_cols = local_cols
+            leftmost_mapper = self.parent
         
-        local_cols, remote_cols = self._local_remote_columns(prop)
+        leftmost_attr = [
+            leftmost_mapper._get_col_to_prop(c).class_attribute
+            for c in leftmost_cols
+        ]
 
         local_attr = [
-            path[0]._get_col_to_prop(c).class_attribute
+            self.parent._get_col_to_prop(c).class_attribute
             for c in local_cols
         ]
         
-        #attr = self.parent_property.class_attribute
-        
         # modify the query to just look for parent columns in the 
         # join condition
         
@@ -676,7 +676,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
         q = orig_query._clone() #context.query._clone()
         q._attributes = q._attributes.copy()
         q._attributes[("orig_query", SubqueryLoader)] = orig_query
-        q._set_entities(local_attr)
+        q._set_entities(leftmost_attr)
         q._order_by = None
         
         # now select from it as a subquery.
@@ -690,36 +690,19 @@ class SubqueryLoader(AbstractRelationshipLoader):
 
         q = q.order_by(*local_attr)
         
-        q._attributes = q._attributes.copy()
         for attr in orig_query._attributes:
             strat, opt_path = attr
             if strat == "loaderstrategy":
                 opt_path = opt_path[len(path):]
                 q._attributes[("loaderstrategy", opt_path)] =\
                        context.query._attributes[attr]
-
-        q = q._with_current_path(path)
+        
         if self.parent_property.order_by:
             q = q.order_by(*self.parent_property.order_by)
         
         context.attributes[('subquery', path)] = \
-                q._attributes[('subquery', path)] = \
-                q
-
-#        for value in self.mapper._iterate_polymorphic_properties():
-#            strat = value._get_context_strategy(
-#                                        context, path + 
-#                                        (self.mapper,value.key)
-#                                    )
-            #print "VALUE", value, "PATH", path + (self.mapper,), "STRAT", type(strat)
-#            if isinstance(strat, SubqueryLoader):
-#                value.setup(
-#                    context, 
-#                    entity, 
-##                    path + (self.mapper,), 
-#                    adapter, 
-#                    parentmapper=self.mapper,
-#                    )
+            q._attributes[('subquery', path)] = q
+        
     
     def _local_remote_columns(self, prop):
         if prop.secondary is None:
@@ -733,9 +716,11 @@ class SubqueryLoader(AbstractRelationshipLoader):
                 ]
         
     def create_row_processor(self, context, path, mapper, row, adapter):
-#        path = path + (self.key,)
-        path = context.query._current_path + path + (self.key,)
-
+        path = path + (self.key,)
+        
+        if ('subquery', path) not in context.attributes:
+            return None, None
+            
         local_cols, remote_cols = self._local_remote_columns(self.parent_property)
 
         local_attr = [self.parent._get_col_to_prop(c).key for c in local_cols]
@@ -743,7 +728,6 @@ class SubqueryLoader(AbstractRelationshipLoader):
                         self.mapper._get_col_to_prop(c).key 
                         for c in remote_cols]
         
-        print "STRAT LOOKING FOR SUBQ AT PATH", path
         q = context.attributes[('subquery', path)]
         
         collections = dict(
@@ -753,6 +737,9 @@ class SubqueryLoader(AbstractRelationshipLoader):
                         lambda x:x[1:]
                     ))
         
+        if adapter:
+            local_cols = [adapter.columns[c] for c in local_cols]
+
         def execute(state, dict_, row):
             collection = collections.get(
                 tuple([row[col] for col in local_cols]), 
@@ -1040,11 +1027,11 @@ class EagerLoader(AbstractRelationshipLoader):
 log.class_logger(EagerLoader)
 
 class EagerLazyOption(StrategizedOption):
-    def __init__(self, key, lazy=True, chained=False, 
-                    mapper=None, propagate_to_loaders=True,
+    def __init__(self, key, lazy=True, chained=False,
+                    propagate_to_loaders=True,
                     _strategy_cls=None
                     ):
-        super(EagerLazyOption, self).__init__(key, mapper)
+        super(EagerLazyOption, self).__init__(key)
         self.lazy = lazy
         self.chained = chained
         self.propagate_to_loaders = propagate_to_loaders
index 2809506084f192ac8ac21458aec90c93174df1cd..a8df63b4a3f80e814d16cfb9c57c60d161e4551e 100644 (file)
@@ -418,7 +418,7 @@ class CannedResults(object):
                            items=[item1, item2, item3]),
                      Order(id=4,
                            items=[item1, item5])]),
-                User(id=10)]
+                User(id=10, orders=[])]
         return user_result
         
 FixtureTest.static = CannedResults()
index 4a5f876f0497eddcaa6d034ac6f5b356ca86054b..0f635e905023e04a9a0a29676e43de27fc0d4028 100644 (file)
@@ -241,6 +241,50 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
                 u)
         self.assert_sql_count(testing.db, go, 1)
 
+    @testing.resolve_artifact_names
+    def test_options_pathing(self):
+        mapper(User, users, properties={
+            'orders':relationship(Order, order_by=orders.c.id), # o2m, m2o
+        })
+        mapper(Order, orders, properties={
+            'items':relationship(Item, 
+                        secondary=order_items, order_by=items.c.id),  #m2m
+        })
+        mapper(Item, items, properties={
+            'keywords':relationship(Keyword, 
+                                        secondary=item_keywords,
+                                        order_by=keywords.c.id) #m2m
+        })
+        mapper(Keyword, keywords)
+
+        for opt, count in [
+            ((
+                eagerload(User.orders, Order.items), 
+            ), 10),
+            ((eagerload("orders.items"), ), 10),
+            ((
+                eagerload(User.orders, ), 
+                eagerload(User.orders, Order.items), 
+                eagerload(User.orders, Order.items, Item.keywords), 
+            ), 1),
+            ((
+                eagerload(User.orders, Order.items, Item.keywords), 
+            ), 10),
+            ((
+                eagerload(User.orders, Order.items), 
+                eagerload(User.orders, Order.items, Item.keywords), 
+            ), 5),
+        ]:
+            sess = create_session()
+            def go():
+                eq_(
+                    sess.query(User).options(*opt).order_by(User.id).all(),
+                    self.static.user_item_keyword_result
+                )
+            self.assert_sql_count(testing.db, go, count)
+
+
+
     @testing.resolve_artifact_names
     def test_many_to_many(self):
 
index 09d1387f7b0cb104dc9d6db457021187bc07d39a..dbca519805de7f12abd6bf9864d3aea97aecfc36 100644 (file)
@@ -1336,6 +1336,14 @@ class DeepOptionsTest(_fixtures.FixtureTest):
             x = u[0].orders[1].items[0].keywords[1]
         self.sql_count_(2, go)
 
+        sess = create_session()
+        q3 = sess.query(User).options(
+                    sa.orm.eagerload(User.orders, Order.items, Item.keywords))
+        u = q3.all()
+        def go():
+            x = u[0].orders[1].items[0].keywords[1]
+        self.sql_count_(2, go)
+
 class ValidatorTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_scalar(self):
index f05fba64e591d1b198f7bb817f535ecf2b1fc52e..f69bb78c207f63d0b9568408478ce773b108f8d8 100644 (file)
@@ -1,7 +1,7 @@
 from sqlalchemy.test.testing import eq_, is_, is_not_
 from sqlalchemy.test import testing
 from sqlalchemy.orm import backref, subqueryload, subqueryload_all, \
-                mapper, relationship, \
+                mapper, relationship, clear_mappers,\
                 create_session, lazyload, aliased, eagerload
 from sqlalchemy.test.testing import eq_, assert_raises
 from sqlalchemy.test.assertsql import CompiledSQL
@@ -11,7 +11,7 @@ import sqlalchemy as sa
 class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     run_inserts = 'once'
     run_deletes = None
-
+    
     @testing.resolve_artifact_names
     def test_basic(self):
         mapper(User, users, properties={
@@ -164,8 +164,28 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             User(id=10, addresses=[])
         ], sess.query(User).order_by(User.id).all())
 
+    _pathing_runs = [
+        ( "lazyload", "lazyload", "lazyload", 15 ),
+        ("eagerload", "eagerload", "eagerload", 1),
+        ("subqueryload", "lazyload", "lazyload", 12),
+        ("subqueryload", "subqueryload", "lazyload", 8),
+        ("eagerload", "subqueryload", "lazyload", 7),
+        ("lazyload", "lazyload", "subqueryload", 12),
+        
+        # here's the one that fails:
+        #("subqueryload", "subqueryload", "subqueryload", 4),
+    ]
+#    _pathing_runs = [("subqueryload", "subqueryload", "subqueryload", 4)]
+    _pathing_runs = [("lazyload", "lazyload", "subqueryload", 12)]
+    
+    def test_options_pathing(self):
+        self._do_options_test(self._pathing_runs)
+    
+    def test_mapper_pathing(self):
+        self._do_mapper_test(self._pathing_runs)
+    
     @testing.resolve_artifact_names
-    def test_pathing(self):
+    def _do_options_test(self, configs):
         mapper(User, users, properties={
             'orders':relationship(Order, order_by=orders.c.id), # o2m, m2o
         })
@@ -180,39 +200,81 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         })
         mapper(Keyword, keywords)
         
+        callables = {
+                        'eagerload':eagerload, 
+                    'subqueryload':subqueryload
+                }
+        
+        for o, i, k, count in configs:
+            options = []
+            if o in callables:
+                options.append(callables[o](User.orders))
+            if i in callables:
+                options.append(callables[i](User.orders, Order.items))
+            if k in callables:
+                options.append(callables[k](User.orders, Order.items, Item.keywords))
 
-        for opt, count in [
-#            ((
-#                lazyload(User.orders), 
-#                lazyload(User.orders, Order.items), 
-#                lazyload(User.orders, Order.items, Item.keywords)
-#            ), 14),
-#            ((
-#                eagerload(User.orders), 
-#                eagerload(User.orders, Order.items), 
-#                eagerload(User.orders, Order.items, Item.keywords)
-#            ), 1),
-#            ((
-#                subqueryload(User.orders), 
-#            ), 12),
-            ((
-                subqueryload(User.orders), 
-                subqueryload(User.orders, Order.items), 
-            ), 8),
-#            ((
-#                subqueryload(User.orders), 
-#                subqueryload(User.orders, Order.items), 
-#                subqueryload(User.orders, Order.items, Item.keywords), 
-#            ), 4),
-        ]:
             sess = create_session()
             def go():
                 eq_(
-                    sess.query(User).options(*opt).order_by(User.id).all(),
+                    sess.query(User).options(*options).order_by(User.id).all(),
                     self.static.user_item_keyword_result
                 )
             self.assert_sql_count(testing.db, go, count)
+
+            sess = create_session()
+#            def go():
+            eq_(
+                sess.query(User).filter(User.name=='fred').
+                        options(*options).order_by(User.id).all(),
+                self.static.user_item_keyword_result[2:3]
+            )
+#            self.assert_sql_count(testing.db, go, count)
+
+    @testing.resolve_artifact_names
+    def _do_mapper_test(self, configs):
+        opts = {
+            'lazyload':'select',
+            'eagerload':'joined',
+            'subqueryload':'subquery',
             
+        }
+
+        for o, i, k, count in configs:
+            mapper(User, users, properties={
+                'orders':relationship(Order, lazy=opts[o], order_by=orders.c.id), 
+            })
+            mapper(Order, orders, properties={
+                'items':relationship(Item, 
+                            secondary=order_items, lazy=opts[i], order_by=items.c.id), 
+            })
+            mapper(Item, items, properties={
+                'keywords':relationship(Keyword, 
+                                            lazy=opts[k],
+                                            secondary=item_keywords,
+                                            order_by=keywords.c.id)
+            })
+            mapper(Keyword, keywords)
+
+            sess = create_session()
+            def go():
+                eq_(
+                    sess.query(User).order_by(User.id).all(),
+                    self.static.user_item_keyword_result
+                )
+            try:
+                self.assert_sql_count(testing.db, go, count)
+
+                eq_(
+                    sess.query(User).filter(User.name=='fred').
+                            order_by(User.id).all(),
+                    self.static.user_item_keyword_result[2:3]
+                )
+
+            finally:
+                clear_mappers()
+        
+        
     # TODO: all the tests in test_eager_relations
     
     # TODO: ensure state stuff works out OK, existing objects/collections