]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Improved the determination of the FROM clause
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Aug 2008 20:58:48 +0000 (20:58 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Aug 2008 20:58:48 +0000 (20:58 +0000)
when placing SQL expressions in the query()
list of entities.  In particular scalar subqueries
should not "leak" their inner FROM objects out
into the enclosing query.

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/inheritance/query.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 1a911122b707c680829c9862cf55da432fe1142c..a5258ceacffd04cea161dd3db5bb52d1c35a465b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -16,6 +16,12 @@ CHANGES
       Column objects such as Query(table.c.col) will
       return the "key" attribute of the Column.
 
+    - Improved the determination of the FROM clause
+      when placing SQL expressions in the query()
+      list of entities.  In particular scalar subqueries
+      should not "leak" their inner FROM objects out
+      into the enclosing query.
+      
 - sql
     - Temporarily rolled back the "ORDER BY" enhancement
       from [ticket:1068].  This feature is on hold 
index 43dde9fc053bdd3841e0408bea0faa3251679640..b5250638649bc13650978ab441e470c2146311c8 100644 (file)
@@ -1625,11 +1625,22 @@ class _ColumnEntity(_QueryEntity):
 
         self.column = column
         self.froms = set()
+        
+        # look for ORM entities represented within the
+        # given expression.  Try to count only entities 
+        # for columns whos FROM object is in the actual list
+        # of FROMs for the overall expression - this helps
+        # subqueries which were built from ORM constructs from
+        # leaking out their entities into the main select construct
+        actual_froms = set(column._get_from_objects())
+
         self.entities = util.OrderedSet(
             elem._annotations['parententity']
             for elem in visitors.iterate(column, {})
-            if 'parententity' in elem._annotations)
-        
+            if 'parententity' in elem._annotations
+            and actual_froms.intersection(elem._get_from_objects())
+            )
+            
         if self.entities:
             self.entity_zero = list(self.entities)[0]
         else:
index 0ce3b4fb78a305e441a7216de2004303134ec9aa..28659ba38a3e4c71c8fc4ae459e653a38424f2f3 100644 (file)
@@ -510,7 +510,6 @@ def make_test(select_type):
             
             self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
             
-        
         def test_mixed_entities(self):
             sess = create_session()
 
@@ -537,7 +536,6 @@ def make_test(select_type):
                 [('pointy haired boss foo', ), ('dogbert foo',)]
             )
 
-
             row = sess.query(Engineer.name, Engineer.primary_language).filter(Engineer.name=='dilbert').first()
             assert row.name == 'dilbert'
             assert row.primary_language == 'java'
index 30cb03ab86c1fc97f0cb7bb8d56ad7f3586be6cf..d3e69573ddf0fe74d83925f386243a84c69c0143 100644 (file)
@@ -1207,6 +1207,36 @@ class MixedEntitiesTest(QueryTest):
         q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name)
         self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')])
     
+    def test_scalar_subquery(self):
+        """test that a subquery constructed from ORM attributes doesn't leak out 
+        those entities to the outermost query.
+        
+        """
+        sess = create_session()
+        
+        subq = select([func.count()]).\
+            where(User.id==Address.user_id).\
+            correlate(users).\
+            label('count')
+
+        # we don't want Address to be outside of the subquery here
+        self.assertEquals(
+            list(sess.query(User, subq)[0:3]),
+            [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)]
+            )
+
+        # same thing without the correlate, as it should
+        # not be needed
+        subq = select([func.count()]).\
+            where(User.id==Address.user_id).\
+            label('count')
+
+        # we don't want Address to be outside of the subquery here
+        self.assertEquals(
+            list(sess.query(User, subq)[0:3]),
+            [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)]
+            )
+    
     def test_tuple_labeling(self):
         sess = create_session()
         for row in sess.query(User, Address).join(User.addresses).all():