]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
and here's where it gets *fun* ! so much for being easy
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 22 Mar 2010 20:54:58 +0000 (16:54 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 22 Mar 2010 20:54:58 +0000 (16:54 -0400)
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
test/orm/_fixtures.py
test/orm/test_eager_relations.py
test/orm/test_subquery_relations.py

index 03ebb97c4f57975b1af6b3770f11037a849094c3..c773e74f6197d42cd4158bfce8032eb10204e9e3 100644 (file)
@@ -646,7 +646,7 @@ class StrategizedProperty(MapperProperty):
     
     """
 
-    def __get_context_strategy(self, context, path):
+    def _get_context_strategy(self, context, path):
         cls = context.attributes.get(("loaderstrategy", _reduce_path(path)), None)
         if cls:
             try:
@@ -668,11 +668,11 @@ class StrategizedProperty(MapperProperty):
         return strategy
 
     def setup(self, context, entity, path, adapter, **kwargs):
-        self.__get_context_strategy(context, path + (self.key,)).\
+        self._get_context_strategy(context, path + (self.key,)).\
                     setup_query(context, entity, path, adapter, **kwargs)
 
     def create_row_processor(self, context, path, mapper, row, adapter):
-        return self.__get_context_strategy(context, path + (self.key,)).\
+        return self._get_context_strategy(context, path + (self.key,)).\
                     create_row_processor(context, path, mapper, row, adapter)
 
     def do_init(self):
@@ -775,12 +775,17 @@ class PropertyOption(MapperOption):
             isa = True
             
         for ent in query._mapper_entities:
-            if searchfor is ent.path_entity or (isa and searchfor.common_parent(ent.path_entity)):
+            if searchfor is ent.path_entity or (
+                                isa and
+                                searchfor.common_parent(ent.path_entity)):
                 return ent
         else:
             if raiseerr:
-                raise sa_exc.ArgumentError("Can't find entity %s in Query.  Current list: %r" 
-                    % (searchfor, [str(m.path_entity) for m in query._entities]))
+                raise sa_exc.ArgumentError(
+                    "Can't find entity %s in Query.  Current list: %r" 
+                    % (searchfor, [
+                                str(m.path_entity) for m in query._entities
+                                ]))
             else:
                 return None
 
@@ -921,7 +926,7 @@ class StrategizedOption(PropertyOption):
         return False
 
     def process_query_property(self, query, paths, mappers):
-        # __get_context_strategy may receive the path in terms of
+        # _get_context_strategy may receive the path in terms of
         # a base mapper - e.g.  options(eagerload_all(Company.employees, Engineer.machines))
         # in the polymorphic tests leads to "(Person, 'machines')" in 
         # the path due to the mechanics of how the eager strategy builds
index 7ae1194c1f7f8036ac10dfd01afbc1a3845d6987..f067172174ca5b638e1f4798733f7b3e22065ba6 100644 (file)
@@ -619,7 +619,8 @@ class Query(object):
         those being selected.
 
         """
-        fromclause = self.with_labels().enable_eagerloads(False).statement.correlate(None)
+        fromclause = self.with_labels().enable_eagerloads(False).\
+                                    statement.correlate(None)
         q = self._from_selectable(fromclause)
         if entities:
             q._set_entities(entities)
index 828530d7ac1bb2577b6883c62cee4281c60aa41e..08bb5062a7bc9a916acf383bdb051d19c8d3b32b 100644 (file)
@@ -644,28 +644,38 @@ class SubqueryLoader(AbstractRelationshipLoader):
         if not context.query._enable_eagerloads:
             return
 
-        path = path + (self.key,)
+#        path = path + (self.key,)
+        
         
-        local_cols, remote_cols = self._local_remote_columns
+        if ("orig_query", SubqueryLoader) not in context.attributes:
+            context.attributes[("orig_query", SubqueryLoader)] =\
+                    context.query
+        
+        orig_query = context.attributes[("orig_query", SubqueryLoader)]
+
+#        orig_query = context.query
+        path = context.query._current_path + path + (self.key, )
+        
+        prop = path[0].get_property(path[1])
+        
+        local_cols, remote_cols = self._local_remote_columns(prop)
 
         local_attr = [
-            self.parent._get_col_to_prop(c).class_attribute
+            path[0]._get_col_to_prop(c).class_attribute
             for c in local_cols
         ]
         
-        attr = self.parent_property.class_attribute
+        #attr = self.parent_property.class_attribute
         
         # modify the query to just look for parent columns in the 
         # join condition
         
-        # TODO: what happens to options() in the parent query ?  
-        # are they going
-        # to get in the way here ?
-        
         # set the original query to only look
         # for the significant columns, not order
         # by anything.
-        q = context.query._clone()
+        q = orig_query._clone() #context.query._clone()
+        q._attributes = q._attributes.copy()
+        q._attributes[("orig_query", SubqueryLoader)] = orig_query
         q._set_entities(local_attr)
         q._order_by = None
         
@@ -674,37 +684,66 @@ class SubqueryLoader(AbstractRelationshipLoader):
         
         # and join to the related thing we want
         # to load.
-        q = q.join(attr)
-                                                    
+        for mapper, key in [(path[i], path[i+1]) for i in xrange(0, len(path), 2)]:
+            prop = mapper.get_property(key)
+            q = q.join(prop.class_attribute)
+
         q = q.order_by(*local_attr)
         
+        q._attributes = q._attributes.copy()
+        for attr in orig_query._attributes:
+            strat, opt_path = attr
+            if strat == "loaderstrategy":
+                opt_path = opt_path[len(path):]
+                q._attributes[("loaderstrategy", opt_path)] =\
+                       context.query._attributes[attr]
+
+        q = q._with_current_path(path)
         if self.parent_property.order_by:
             q = q.order_by(*self.parent_property.order_by)
         
-        context.attributes[('subquery', path)] = q
+        context.attributes[('subquery', path)] = \
+                q._attributes[('subquery', path)] = \
+                q
+
+#        for value in self.mapper._iterate_polymorphic_properties():
+#            strat = value._get_context_strategy(
+#                                        context, path + 
+#                                        (self.mapper,value.key)
+#                                    )
+            #print "VALUE", value, "PATH", path + (self.mapper,), "STRAT", type(strat)
+#            if isinstance(strat, SubqueryLoader):
+#                value.setup(
+#                    context, 
+#                    entity, 
+##                    path + (self.mapper,), 
+#                    adapter, 
+#                    parentmapper=self.mapper,
+#                    )
     
-    @property
-    def _local_remote_columns(self):
-        if self.parent_property.secondary is None:
-            return zip(*self.parent_property.local_remote_pairs)
+    def _local_remote_columns(self, prop):
+        if prop.secondary is None:
+            return zip(*prop.local_remote_pairs)
         else:
             return \
-                [p[0] for p in self.parent_property.synchronize_pairs],\
+                [p[0] for p in prop.synchronize_pairs],\
                 [
-                    p[0] for p in self.parent_property.
+                    p[0] for p in prop.
                                         secondary_synchronize_pairs
                 ]
         
     def create_row_processor(self, context, path, mapper, row, adapter):
-        path = path + (self.key,)
+#        path = path + (self.key,)
+        path = context.query._current_path + path + (self.key,)
 
-        local_cols, remote_cols = self._local_remote_columns
+        local_cols, remote_cols = self._local_remote_columns(self.parent_property)
 
         local_attr = [self.parent._get_col_to_prop(c).key for c in local_cols]
         remote_attr = [
                         self.mapper._get_col_to_prop(c).key 
                         for c in remote_cols]
-
+        
+        print "STRAT LOOKING FOR SUBQ AT PATH", path
         q = context.attributes[('subquery', path)]
         
         collections = dict(
@@ -713,7 +752,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
                         q, 
                         lambda x:x[1:]
                     ))
-
+        
         def execute(state, dict_, row):
             collection = collections.get(
                 tuple([row[col] for col in local_cols]), 
index e9d6ac16565c4a55efa17471cae2a27a99ab6405..2809506084f192ac8ac21458aec90c93174df1cd 100644 (file)
@@ -378,5 +378,48 @@ class CannedResults(object):
                  keywords=[]),
             Item(id=5,
                  keywords=[])]
+    
+    @property
+    def user_item_keyword_result(self):
+        item1, item2, item3, item4, item5 = \
+             Item(id=1,
+                  keywords=[
+                    Keyword(name='red'),
+                    Keyword(name='big'),
+                    Keyword(name='round')]),\
+             Item(id=2,
+                  keywords=[
+                    Keyword(name='red'),
+                    Keyword(name='small'),
+                    Keyword(name='square')]),\
+             Item(id=3,
+                  keywords=[
+                    Keyword(name='green'),
+                    Keyword(name='big'),
+                    Keyword(name='round')]),\
+             Item(id=4,
+                  keywords=[]),\
+             Item(id=5,
+                  keywords=[])
+
+        user_result = [
+                User(id=7,
+                   orders=[
+                     Order(id=1,
+                           items=[item1, item2, item3]),
+                     Order(id=3,
+                           items=[item3, item4, item5]),
+                     Order(id=5,
+                           items=[item5])]),
+                User(id=8, orders=[]),
+                User(id=9,
+                   orders=[
+                     Order(id=2,
+                           items=[item1, item2, item3]),
+                     Order(id=4,
+                           items=[item1, item5])]),
+                User(id=10)]
+        return user_result
+        
 FixtureTest.static = CannedResults()
 
index 925a5b09fba7b5bbc7a2aa556030eaf0f3992202..4a5f876f0497eddcaa6d034ac6f5b356ca86054b 100644 (file)
@@ -440,10 +440,6 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     @testing.resolve_artifact_names
     def test_limit(self):
         """Limit operations combined with lazy-load relationships."""
-        User, Item, Address, Order = self.classes.get_all(
-            'User', 'Item', 'Address', 'Order')
-        users, items, order_items, orders, addresses = self.tables.get_all(
-            'users', 'items', 'order_items', 'orders', 'addresses')
 
         mapper(Item, items)
         mapper(Order, orders, properties={
index ec01a94a991b6b7cd3b0aa15d50e72d1d277c93c..f05fba64e591d1b198f7bb817f535ecf2b1fc52e 100644 (file)
@@ -1,7 +1,8 @@
 from sqlalchemy.test.testing import eq_, is_, is_not_
 from sqlalchemy.test import testing
-from sqlalchemy.orm import backref, subqueryload, subqueryload_all
-from sqlalchemy.orm import mapper, relationship, create_session, lazyload, aliased
+from sqlalchemy.orm import backref, subqueryload, subqueryload_all, \
+                mapper, relationship, \
+                create_session, lazyload, aliased, eagerload
 from sqlalchemy.test.testing import eq_, assert_raises
 from sqlalchemy.test.assertsql import CompiledSQL
 from test.orm import _base, _fixtures
@@ -14,7 +15,9 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     @testing.resolve_artifact_names
     def test_basic(self):
         mapper(User, users, properties={
-            'addresses':relationship(mapper(Address, addresses), order_by=Address.id)
+            'addresses':relationship(
+                            mapper(Address, addresses), 
+                            order_by=Address.id)
         })
         sess = create_session()
         
@@ -22,7 +25,8 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         
         def go():
             eq_(
-                    [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])],
+                    [User(id=7, addresses=[
+                            Address(id=1, email_address='jack@bean.com')])],
                     q.filter(User.id==7).all()
             )
         
@@ -85,7 +89,9 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         mapper(User, users, properties = {
             'addresses':relationship(mapper(Address, addresses), 
                             lazy='subquery', 
-                            order_by=[addresses.c.email_address, addresses.c.id]),
+                            order_by=[
+                                    addresses.c.email_address,
+                                    addresses.c.id]),
         })
         q = create_session().query(User)
         eq_([
@@ -110,11 +116,14 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
         mapper(Address, addresses)
         mapper(User, users, properties = dict(
-            addresses = relationship(Address, lazy='subquery', order_by=addresses.c.id),
+            addresses = relationship(Address, 
+                                        lazy='subquery',
+                                        order_by=addresses.c.id),
         ))
 
         q = create_session().query(User)
-        l = q.filter(User.id==Address.user_id).order_by(Address.email_address).all()
+        l = q.filter(User.id==Address.user_id).\
+            order_by(Address.email_address).all()
 
         eq_([
             User(id=8, addresses=[
@@ -135,7 +144,9 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         mapper(Address, addresses)
         mapper(User, users, properties = dict(
             addresses = relationship(Address, lazy='subquery',
-                                 order_by=[sa.desc(addresses.c.email_address)]),
+                                 order_by=[
+                                    sa.desc(addresses.c.email_address)
+                                ]),
         ))
         sess = create_session()
         eq_([
@@ -153,6 +164,55 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             User(id=10, addresses=[])
         ], sess.query(User).order_by(User.id).all())
 
+    @testing.resolve_artifact_names
+    def test_pathing(self):
+        mapper(User, users, properties={
+            'orders':relationship(Order, order_by=orders.c.id), # o2m, m2o
+        })
+        mapper(Order, orders, properties={
+            'items':relationship(Item, 
+                        secondary=order_items, order_by=items.c.id),  #m2m
+        })
+        mapper(Item, items, properties={
+            'keywords':relationship(Keyword, 
+                                        secondary=item_keywords,
+                                        order_by=keywords.c.id) #m2m
+        })
+        mapper(Keyword, keywords)
+        
+
+        for opt, count in [
+#            ((
+#                lazyload(User.orders), 
+#                lazyload(User.orders, Order.items), 
+#                lazyload(User.orders, Order.items, Item.keywords)
+#            ), 14),
+#            ((
+#                eagerload(User.orders), 
+#                eagerload(User.orders, Order.items), 
+#                eagerload(User.orders, Order.items, Item.keywords)
+#            ), 1),
+#            ((
+#                subqueryload(User.orders), 
+#            ), 12),
+            ((
+                subqueryload(User.orders), 
+                subqueryload(User.orders, Order.items), 
+            ), 8),
+#            ((
+#                subqueryload(User.orders), 
+#                subqueryload(User.orders, Order.items), 
+#                subqueryload(User.orders, Order.items, Item.keywords), 
+#            ), 4),
+        ]:
+            sess = create_session()
+            def go():
+                eq_(
+                    sess.query(User).options(*opt).order_by(User.id).all(),
+                    self.static.user_item_keyword_result
+                )
+            self.assert_sql_count(testing.db, go, count)
+            
     # TODO: all the tests in test_eager_relations
     
     # TODO: ensure state stuff works out OK, existing objects/collections