]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed Query being able to join() from individual columns of
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Jun 2009 21:23:11 +0000 (21:23 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Jun 2009 21:23:11 +0000 (21:23 +0000)
a joined-table subclass entity, i.e.
query(SubClass.foo, SubcClass.bar).join(<anything>).
In most cases, an error "Could not find a FROM clause to join
from" would be raised. In a few others, the result would be
returned in terms of the base class rather than the subclass -
so applications which relied on this erroneous result need to be
adjusted. [ticket:1431]

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/inheritance/query.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 63cf4ff3b62af7a3b5b6393975f63a40868aaf64..0653bd68dae9cb27d2f9afa746391e92ee49a03c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -14,6 +14,15 @@ CHANGES
       wouldn't be deserialized correctly when the whole object
       was serialized.  [ticket:1426]
 
+    - Fixed Query being able to join() from individual columns of
+      a joined-table subclass entity, i.e.
+      query(SubClass.foo, SubcClass.bar).join(<anything>).
+      In most cases, an error "Could not find a FROM clause to join 
+      from" would be raised. In a few others, the result would be 
+      returned in terms of the base class rather than the subclass - 
+      so applications which relied on this erroneous result need to be 
+      adjusted. [ticket:1431]
+      
 - sql
     - Removed an obscure feature of execute() (including connection,
       engine, Session) whereby a bindparam() construct can be sent as 
index e3cc3c75697ef84a140ccc326d2c7a47913d1ee5..dee57bbc102e8d9b8693178c71cf8fcb59330dea 100644 (file)
@@ -899,7 +899,7 @@ class Query(object):
         # after the method completes,
         # the query's joinpoint will be set to this.
         right_entity = None
-
+        
         for arg1 in util.to_list(keys):
             aliased_entity = False
             alias_criterion = False
@@ -969,6 +969,10 @@ class Query(object):
                         clause = ent.selectable
                         break
 
+            # TODO:
+            # this provides one kind of "backwards join"
+            # tested in test/orm/query.py.
+            # remove this in 0.6
             if not clause:
                 if isinstance(onclause, interfaces.PropComparator):
                     clause = onclause.__clause_element__()
@@ -2057,11 +2061,8 @@ class _ColumnEntity(_QueryEntity):
         if _is_aliased_class(entity):
             return entity is self.entity_zero
         else:
-            # TODO: this will fail with inheritance, entity_zero
-            # is not a base mapper.  MapperEntity has path_entity
-            # which serves this purpose (when saying: query(FooBar.somecol).join(SomeClass, FooBar.id==SomeClass.foo_id))
-            return entity.base_mapper is self.entity_zero
-
+            return not _is_aliased_class(self.entity_zero) and entity.base_mapper.common_parent(self.entity_zero)
+            
     def _resolve_expr_against_query_aliases(self, query, expr, context):
         return query._adapt_clause(expr, False, True)
 
index 011576b043075895cca619aa0a6797428e2a0399..58d2054558619e13d4e433ae2e0141209c79310a 100644 (file)
@@ -325,7 +325,55 @@ def make_test(select_type):
                 c2
                 )
                 
-        
+        def test_join_from_columns_or_subclass(self):
+            sess = create_session()
+
+            self.assertEquals(
+                sess.query(Manager.name).order_by(Manager.name).all(),
+                [(u'dogbert',), (u'pointy haired boss',)]
+            )
+            
+            self.assertEquals(
+                sess.query(Manager.name).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(),
+                [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)]
+            )
+
+            self.assertEquals(
+                sess.query(Person.name).join((Paperwork, Person.paperwork)).order_by(Person.name).all(),
+                [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)]
+            )
+            
+            self.assertEquals(
+                sess.query(Person.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Person.name).all(),
+                [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)]
+            )
+            
+            self.assertEquals(
+                sess.query(Manager).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(),
+                [m1, b1]
+            )
+
+            self.assertEquals(
+                sess.query(Manager.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(),
+                [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)]
+            )
+
+            self.assertEquals(
+                sess.query(Manager.person_id).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(),
+                [(4,), (4,), (3,)]
+            )
+            
+            self.assertEquals(
+                sess.query(Manager.name, Paperwork.description).join((Paperwork, Manager.person_id==Paperwork.person_id)).all(),
+                [(u'pointy haired boss', u'review #1'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3')]
+            )
+            
+            malias = aliased(Manager)
+            self.assertEquals(
+                sess.query(malias.name).join((paperwork, malias.person_id==paperwork.c.person_id)).all(),
+                [(u'pointy haired boss',), (u'dogbert',), (u'dogbert',)]
+            )
+            
         def test_expire(self):
             """test that individual column refresh doesn't get tripped up by the select_table mapper"""
             
index 29e2ad0b7b9fe6135da000cdab3bb4022c7a817c..33c3e39d7128f4ecc2a345e23d22ba6b257bccca 100644 (file)
@@ -1397,6 +1397,16 @@ class JoinTest(QueryTest):
             ]
         )
 
+    def test_plain_table(self):
+        
+        sess = create_session()
+        
+        self.assertEquals(
+            sess.query(User.name).join((addresses, User.id==addresses.c.user_id)).order_by(User.id).all(),
+            [(u'jack',), (u'ed',), (u'ed',), (u'ed',), (u'fred',)]
+        )
+        
+        
 class MultiplePathTest(_base.MappedTest):
     def define_tables(self, metadata):
         global t1, t2, t1t2_1, t1t2_2