]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- using contains_eager() against an alias combined with an overall query alias repair...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Oct 2008 22:39:19 +0000 (22:39 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Oct 2008 22:39:19 +0000 (22:39 +0000)
contains_eager adapter wraps the query adapter, not vice versa.  Test coverage added.
- contains_eager() will now add columns into the "primary" column collection within Query._compile_context(), instead
of the "secondary" collection.  This allows those columns to get wrapped within the subquery generated
by limit/offset in conjunction with an ORM-generated eager join.
Eager strategy also picks up on context.adapter in this case to deliver the columns during result load.
contains_eager() is now compatible with the subquery generated by a regular eager load
with limit/offset. [ticket:1180]

CHANGES
lib/sqlalchemy/orm/strategies.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index c5ddb123014af49013097ee915ff719b5180db3b..3ad0f12c06ff04c461b9e2c22509352e2fd40114 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -33,6 +33,15 @@ CHANGES
 
     - Added an example illustrating Celko's "nested sets" as a 
       SQLA mapping.
+    
+    - contains_eager() with an alias argument works even when 
+      the alias is embedded in a SELECT, as when sent to the
+      Query via query.select_from().
+      
+    - contains_eager() usage is now compatible with a Query that
+      also contains a regular eager load and limit/offset, in that
+      the columns are added to the Query-generated subquery.
+      [ticket:1180]
       
     - session.execute() will execute a Sequence object passed to
       it (regression from 0.4).
index 308266cd8319978169b5a630914c1e7187674f00..7439ab68b774b1ed661898aabf62c2b74d9edf86 100644 (file)
@@ -595,24 +595,28 @@ class EagerLoader(AbstractRelationLoader):
         path = path + (self.key,)
 
         # check for user-defined eager alias
-        if ("eager_row_processor", path) in context.attributes:
-            clauses = context.attributes[("eager_row_processor", path)]
+        if ("user_defined_eager_row_processor", path) in context.attributes:
+            clauses = context.attributes[("user_defined_eager_row_processor", path)]
             
             adapter = entity._get_entity_clauses(context.query, context)
             if adapter and clauses:
-                context.attributes[("eager_row_processor", path)] = clauses = adapter.wrap(clauses)
+                context.attributes[("user_defined_eager_row_processor", path)] = clauses = clauses.wrap(adapter)
             elif adapter:
-                context.attributes[("eager_row_processor", path)] = clauses = adapter
-                
+                context.attributes[("user_defined_eager_row_processor", path)] = clauses = adapter
+            
+            add_to_collection = context.primary_columns
+            
         else:
             clauses = self._create_eager_join(context, entity, path, adapter, parentmapper)
             if not clauses:
                 return
 
             context.attributes[("eager_row_processor", path)] = clauses
+
+            add_to_collection = context.secondary_columns
             
         for value in self.mapper._iterate_polymorphic_properties():
-            value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=context.secondary_columns)
+            value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=add_to_collection)
     
     def _create_eager_join(self, context, entity, path, adapter, parentmapper):
         # check for join_depth or basic recursion,
@@ -691,7 +695,15 @@ class EagerLoader(AbstractRelationLoader):
         return clauses
         
     def _create_eager_adapter(self, context, row, adapter, path):
-        if ("eager_row_processor", path) in context.attributes:
+        if ("user_defined_eager_row_processor", path) in context.attributes:
+            decorator = context.attributes[("user_defined_eager_row_processor", path)]
+            # user defined eagerloads are part of the "primary" portion of the load.
+            # the adapters applied to the Query should be honored.
+            if context.adapter and decorator:
+                decorator = decorator.wrap(context.adapter)
+            elif context.adapter:
+                decorator = context.adapter
+        elif ("eager_row_processor", path) in context.attributes:
             decorator = context.attributes[("eager_row_processor", path)]
         else:
             if self._should_log_debug:
@@ -789,8 +801,8 @@ class LoadEagerFromAliasOption(PropertyOption):
 
                 prop = mapper.get_property(propname, resolve_synonyms=True)
                 self.alias = prop.target.alias(self.alias)
-            query._attributes[("eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias)
+            query._attributes[("user_defined_eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias)
         else:
-            query._attributes[("eager_row_processor", paths[-1])] = None
+            query._attributes[("user_defined_eager_row_processor", paths[-1])] = None
 
         
index 5d77f2d8155a43cdd5e285e9f82a955c300aedc1..12c75d94f144e58d35a60d6e08e8a2d3cde5e855 100644 (file)
@@ -1234,16 +1234,27 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
             assert fixtures.user_address_result == l
         self.assert_sql_count(testing.db, go, 1)
 
+        # same thing, but alias addresses, so that the adapter generated by select_from() is wrapped within
+        # the adapter created by contains_eager()
+        adalias = addresses.alias()
+        query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(adalias).select(use_labels=True,order_by=['ulist.id', adalias.c.id])
+        def go():
+            l = sess.query(User).select_from(query).options(contains_eager('addresses', alias=adalias)).all()
+            assert fixtures.user_address_result == l
+        self.assert_sql_count(testing.db, go, 1)
+
     def test_contains_eager(self):
         sess = create_session()
 
         # test that contains_eager suppresses the normal outer join rendering
         q = sess.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses)).order_by(User.id)
-        self.assert_compile(q.with_labels().statement, "SELECT users.id AS users_id, users.name AS users_name, "\
-                "addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\
-                "addresses.email_address AS addresses_email_address FROM users LEFT OUTER JOIN addresses "\
-                "ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect())
-
+        self.assert_compile(q.with_labels().statement, 
+            "SELECT addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\
+            "addresses.email_address AS addresses_email_address, users.id AS users_id, "\
+            "users.name AS users_name FROM users LEFT OUTER JOIN addresses "\
+            "ON users.id = addresses.user_id ORDER BY users.id"
+            , dialect=default.DefaultDialect())
+                    
         def go():
             assert fixtures.user_address_result == q.all()
         self.assert_sql_count(testing.db, go, 1)
@@ -1266,7 +1277,6 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
 
         sess.clear()
 
-
         def go():
             l = list(q.options(contains_eager(User.addresses)).instances(selectquery.execute()))
             assert fixtures.user_address_result[0:3] == l
@@ -1306,7 +1316,6 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
             assert fixtures.user_address_result == l.all()
         self.assert_sql_count(testing.db, go, 1)
         sess.clear()
-        
 
         oalias = orders.alias('o1')
         ialias = items.alias('i1')
@@ -1337,7 +1346,42 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         self.assert_sql_count(testing.db, go, 1)
         sess.clear()
 
+    def test_mixed_eager_contains_with_limit(self):
+        sess = create_session()
+        
+        q = sess.query(User)
+        def go():
+            # outerjoin to User.orders, offset 1/limit 2 so we get user 7 + second two orders.
+            # then eagerload the addresses.  User + Order columns go into the subquery, address
+            # left outer joins to the subquery, eagerloader for User.orders applies context.adapter 
+            # to result rows.  This was [ticket:1180].
+            l = q.outerjoin(User.orders).options(eagerload(User.addresses), contains_eager(User.orders)).offset(1).limit(2).all()
+            eq_(l, [User(id=7,
+            addresses=[Address(email_address=u'jack@bean.com',user_id=7,id=1)],
+            name=u'jack',
+            orders=[
+                Order(address_id=1,user_id=7,description=u'order 3',isopen=1,id=3), 
+                Order(address_id=None,user_id=7,description=u'order 5',isopen=0,id=5)
+            ])])
+        self.assert_sql_count(testing.db, go, 1)
 
+        sess.clear()
+        
+        def go():
+            # same as above, except Order is aliased, so two adapters are applied by the
+            # eager loader
+            oalias = aliased(Order)
+            l = q.outerjoin(User.orders, oalias).options(eagerload(User.addresses), contains_eager(User.orders, alias=oalias)).offset(1).limit(2).all()
+            eq_(l, [User(id=7,
+            addresses=[Address(email_address=u'jack@bean.com',user_id=7,id=1)],
+            name=u'jack',
+            orders=[
+                Order(address_id=1,user_id=7,description=u'order 3',isopen=1,id=3), 
+                Order(address_id=None,user_id=7,description=u'order 5',isopen=0,id=5)
+            ])])
+        self.assert_sql_count(testing.db, go, 1)
+        
+        
 class MixedEntitiesTest(QueryTest):
 
     def test_values(self):