]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
test cases were not fully testing contains_eager() with regards to [ticket:777],...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Nov 2007 16:18:54 +0000 (16:18 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Nov 2007 16:18:54 +0000 (16:18 +0000)
lib/sqlalchemy/orm/strategies.py
test/orm/query.py

index 32d148e3cda65e94179f94a3d9058ed78b837efe..2caee2dd4a1b5b1f38ab83828790af7028739595 100644 (file)
@@ -690,16 +690,13 @@ class RowDecorateOption(PropertyOption):
 
     def process_query_property(self, query, paths):
         if self.alias is not None and self.decorator is None:
+            (mapper, propname) = paths[-1][-2:]
+
+            prop = mapper.get_property(propname, resolve_synonyms=True)
             if isinstance(self.alias, basestring):
-                (mapper, propname) = paths[-1]
-                prop = mapper.get_property(propname, resolve_synonyms=True)
                 self.alias = prop.target.alias(self.alias)
-            def decorate(row):
-                d = {}
-                for c in prop.target.columns:
-                    d[c] = row[self.alias.corresponding_column(c)]
-                return d
-            self.decorator = decorate
+
+            self.decorator = mapperutil.create_row_adapter(self.alias, prop.target)
         query._attributes[("eager_row_processor", paths[-1])] = self.decorator
 
 RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
index 438bc9634f0b13996256075722c6ebf31831c352..086704cb2796ca4b239750f83a214be48db3fc5f 100644 (file)
@@ -604,14 +604,16 @@ class InstancesTest(QueryTest):
     def test_from_alias(self):
 
         query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.id', addresses.c.id])
-        q = create_session().query(User)
+        sess =create_session()
+        q = sess.query(User)
 
         def go():
             l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testbase.db, go, 1)
 
-
+        sess.clear()
+        
         def go():
             l = q.options(contains_alias('ulist'), contains_eager('addresses')).from_statement(query).all()
             assert fixtures.user_address_result == l
@@ -620,13 +622,16 @@ class InstancesTest(QueryTest):
     def test_contains_eager(self):
 
         selectquery = users.outerjoin(addresses).select(users.c.id<10, use_labels=True, order_by=[users.c.id, addresses.c.id])
-        q = create_session().query(User)
+        sess = create_session()
+        q = sess.query(User)
 
         def go():
             l = q.options(contains_eager('addresses')).instances(selectquery.execute())
             assert fixtures.user_address_result[0:3] == l
         self.assert_sql_count(testbase.db, go, 1)
 
+        sess.clear()
+        
         def go():
             l = q.options(contains_eager('addresses')).from_statement(selectquery).all()
             assert fixtures.user_address_result[0:3] == l
@@ -635,20 +640,24 @@ class InstancesTest(QueryTest):
     def test_contains_eager_alias(self):
         adalias = addresses.alias('adalias')
         selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id])
-        q = create_session().query(User)
+        sess = create_session()
+        q = sess.query(User)
 
         def go():
             # test using a string alias name
             l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testbase.db, go, 1)
-
+        sess.clear()
+        
         def go():
             # test using the Alias object itself
             l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testbase.db, go, 1)
-
+        
+        sess.clear()
+        
         def decorate(row):
             d = {}
             for c in addresses.columns:
@@ -660,12 +669,33 @@ class InstancesTest(QueryTest):
             l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testbase.db, go, 1)
+        sess.clear()
+        
+        oalias = orders.alias('o1')
+        ialias = items.alias('i1')
+        query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True)
+        q = create_session().query(User)
+        # test using string alias with more than one level deep
+        def go():
+            l = q.options(contains_eager('orders', alias='o1'), contains_eager('orders.items', alias='i1')).instances(query.execute())
+            assert fixtures.user_order_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+
+        sess.clear()
+        
+        # test using Alias with more than one level deep
+        def go():
+            l = q.options(contains_eager('orders', alias=oalias), contains_eager('orders.items', alias=ialias)).instances(query.execute())
+            assert fixtures.user_order_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+        sess.clear()
+
 
     def test_multi_mappers(self):
-        sess = create_session()
 
-        (user7, user8, user9, user10) = sess.query(User).all()
-        (address1, address2, address3, address4, address5) = sess.query(Address).all()
+        test_session = create_session()
+        (user7, user8, user9, user10) = test_session.query(User).all()
+        (address1, address2, address3, address4, address5) = test_session.query(Address).all()
 
         # note the result is a cartesian product
         expected = [(user7, address1),
@@ -675,27 +705,36 @@ class InstancesTest(QueryTest):
             (user9, address5),
             (user10, None)]
 
+        sess = create_session()
+
         selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
         q = sess.query(User)
         l = q.instances(selectquery.execute(), Address)
         assert l == expected
-
+        
+        sess.clear()
+        
         for aliased in (False, True):
             q = sess.query(User)
+
             q = q.add_entity(Address).outerjoin('addresses', aliased=aliased)
             l = q.all()
             assert l == expected
+            sess.clear()
 
             q = sess.query(User).add_entity(Address)
             l = q.join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com').all()
             assert l == [(user8, address3)]
+            sess.clear()
 
             q = sess.query(User, Address).join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com')
             assert q.all() == [(user8, address3)]
+            sess.clear()
 
             q = sess.query(User, Address).join('addresses', aliased=aliased).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
             assert q.all() == [(user8, address3)]
-
+            sess.clear()
+            
     def test_aliased_multi_mappers(self):
         sess = create_session()
 
@@ -716,6 +755,8 @@ class InstancesTest(QueryTest):
         l = q.all()
         assert l == expected
 
+        sess.clear()
+        
         q = sess.query(User).add_entity(Address, alias=adalias)
         l = q.select_from(users.outerjoin(adalias)).filter(adalias.c.email_address=='ed@bettyboop.com').all()
         assert l == [(user8, address3)]
@@ -727,7 +768,8 @@ class InstancesTest(QueryTest):
         
         for add_col in (User.name, users.c.name, User.c.name):
             assert sess.query(User).add_column(add_col).all() == expected
-
+            sess.clear()
+            
         try:
             sess.query(User).add_column(object()).all()
             assert False
@@ -751,7 +793,8 @@ class InstancesTest(QueryTest):
             q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(Address.id).label('count'))
             l = q.all()
             assert l == expected
-
+            sess.clear()
+            
         s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
         q = sess.query(User)
         l = q.add_column("count").from_statement(s).all()
@@ -772,14 +815,17 @@ class InstancesTest(QueryTest):
         q = create_session().query(User)
         l = q.add_column("count").add_column("concat").from_statement(s).all()
         assert l == expected
-
+        
+        sess.clear()
+        
         # test with select_from()
         q = create_session().query(User).add_column(func.count(addresses.c.id))\
             .add_column(("Name:" + users.c.name)).select_from(users.outerjoin(addresses))\
             .group_by([c for c in users.c]).order_by(users.c.id)
 
         assert q.all() == expected
-
+        sess.clear()
+        
         # test with outerjoin() both aliased and non
         for aliased in (False, True):
             q = create_session().query(User).add_column(func.count(addresses.c.id))\
@@ -787,7 +833,8 @@ class InstancesTest(QueryTest):
                 .group_by([c for c in users.c]).order_by(users.c.id)
 
             assert q.all() == expected
-
+            sess.clear()
+            
 class CustomJoinTest(QueryTest):
     keep_mappers = False