]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Query.join() can now construct multiple FROM clauses, if
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Mar 2009 03:02:42 +0000 (03:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Mar 2009 03:02:42 +0000 (03:02 +0000)
needed.  Such as, query(A, B).join(A.x).join(B.y)
might say SELECT A.*, B.* FROM A JOIN X, B JOIN Y.
Eager loading can also tack its joins onto those
multiple FROM clauses.  [ticket:1337]

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/util.py
test/orm/eager_relations.py
test/orm/inheritance/query.py

diff --git a/CHANGES b/CHANGES
index 35af1e41b1ea25798b4785ad9675bb2408b2d7c0..503031d4118f0050470ac5e602679a0a7ab7bf1f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -19,6 +19,12 @@ CHANGES
       union(query1, query2), select([foo]).select_from(query), 
       etc.
 
+    - Query.join() can now construct multiple FROM clauses, if 
+      needed.  Such as, query(A, B).join(A.x).join(B.y)
+      might say SELECT A.*, B.* FROM A JOIN X, B JOIN Y.  
+      Eager loading can also tack its joins onto those 
+      multiple FROM clauses.  [ticket:1337]
+      
     - Fixed bug where column_prefix wasn't being checked before
       not mapping an attribute that already had class-level 
       name present.
index a8be43f5a1ccc57f85fff132b936260bc2f2c329..9dba9edad45ef173f9860b6311b1872793e98355 100644 (file)
@@ -83,7 +83,7 @@ class Query(object):
         self._current_path = ()
         self._only_load_props = None
         self._refresh_state = None
-        self._from_obj = None
+        self._from_obj = ()
         self._polymorphic_adapters = {}
         self._filter_aliases = None
         self._from_obj_alias = None
@@ -135,11 +135,11 @@ class Query(object):
         if isinstance(from_obj, expression._SelectBaseMixin):
             from_obj = from_obj.alias()
 
-        self._from_obj = from_obj
+        self._from_obj = (from_obj,)
         equivs = self.__all_equivs()
 
         if isinstance(from_obj, expression.Alias):
-            self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj, equivs)
+            self._from_obj_alias = sql_util.ColumnAdapter(from_obj, equivs)
 
     def _get_polymorphic_adapter(self, entity, selectable):
         self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns))
@@ -258,12 +258,6 @@ class Query(object):
         self._entities = [entity] + self._entities[1:]
         return entity
 
-    def __mapper_zero_from_obj(self):
-        if self._from_obj:
-            return self._from_obj
-        else:
-            return self._entity_zero().selectable
-
     def __all_equivs(self):
         equivs = {}
         for ent in self._mapper_entities:
@@ -276,7 +270,8 @@ class Query(object):
                 self._group_by:
             raise sa_exc.InvalidRequestError("Query.%s() being called on a Query with existing criterion. " % meth)
 
-        self._statement = self._criterion = self._from_obj = None
+        self._from_obj = ()
+        self._statement = self._criterion = None
         self._order_by = self._group_by = self._distinct = False
         self.__joined_tables = {}
 
@@ -899,10 +894,8 @@ class Query(object):
         if not from_joinpoint:
             self.__reset_joinpoint()
 
-        # join from our from_obj.  This is
-        # None unless select_from()/from_self() has been called.
-        clause = self._from_obj
-
+        clause = replace_clause_index = None
+        
         # after the method completes,
         # the query's joinpoint will be set to this.
         right_entity = None
@@ -963,10 +956,13 @@ class Query(object):
             elif not left_entity:
                 left_entity = self._joinpoint_zero()
 
-            # if no initial left-hand clause is set, extract
-            # this from the left_entity or as a last
-            # resort from the onclause argument, if it's
-            # a PropComparator.
+            if not clause and self._from_obj:
+                mp, left_selectable, is_aliased_class = _entity_info(left_entity)
+
+                replace_clause_index, clause = sql_util.find_join_source(self._from_obj, left_selectable)
+                if not clause:
+                    clause = left_selectable
+                    
             if not clause:
                 for ent in self._entities:
                     if ent.corresponds_to(left_entity):
@@ -1067,7 +1063,7 @@ class Query(object):
             # construct the onclause.
             join_to_left = not is_aliased_class or \
                             onclause is prop or \
-                            clause is self._from_obj and self._from_obj_alias
+                            self._from_obj_alias and clause is self._from_obj[0]
 
             # create the join
             clause = orm_join(clause, right_entity, onclause, isouter=outerjoin, join_to_left=join_to_left)
@@ -1086,9 +1082,12 @@ class Query(object):
                                         ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns)
                                     )
 
-        # loop finished.  we're selecting from
-        # our final clause now
-        self._from_obj = clause
+        if replace_clause_index is not None:
+            l = list(self._from_obj)
+            l[replace_clause_index] = clause
+            self._from_obj = tuple(l)
+        else:
+            self._from_obj = self._from_obj + (clause,)
 
         # future joins with from_joinpoint=True join from our established right_entity.
         self._joinpoint = right_entity
@@ -1115,7 +1114,13 @@ class Query(object):
         `from_obj` is a single table or selectable.
 
         """
+        
         if isinstance(from_obj, (tuple, list)):
+            # from_obj is actually a list again as of 0.5.3.   so this restriction here
+            # is somewhat artificial, but is still in place since select_from() implies aliasing all further
+            # criterion against what's placed here, and its less complex to only
+            # keep track of a single aliased FROM element being selected against.  This could in theory be opened
+            # up again to more complexity.
             util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.")
             from_obj = from_obj[-1]
         if not isinstance(from_obj, expression.FromClause):
@@ -1474,7 +1479,7 @@ class Query(object):
             entity.setup_context(self, context)
 
         if context.from_clause:
-            from_obj = [context.from_clause]
+            from_obj = list(context.from_clause)
         else:
             from_obj = context.froms
 
@@ -1722,7 +1727,7 @@ class Query(object):
         eager_joins = context.eager_joins.values()
 
         if context.from_clause:
-            froms = [context.from_clause]  # "load from a single FROM" mode, i.e. when select_from() or join() is used
+            froms = list(context.from_clause)  # "load from explicit FROMs" mode, i.e. when select_from() or join() is used
         else:
             froms = context.froms   # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM
 
index f46ffc44d8a9dd4c36ac1b4088be70a083d4a48b..dc64b283d6bec599ad3d5372f0a6a5572299132b 100644 (file)
@@ -660,18 +660,26 @@ class EagerLoader(AbstractRelationLoader):
 
         if entity in context.eager_joins:
             entity_key, default_towrap = entity, entity.selectable
-        elif should_nest_selectable or not context.from_clause or not sql_util.search(context.from_clause, entity.selectable):
-            # if no from_clause, or a from_clause we can't join to, or a subquery is going to be generated, 
+
+        elif should_nest_selectable or not context.from_clause:
+            # if no from_clause, or a subquery is going to be generated, 
             # store eager joins per _MappedEntity; Query._compile_context will 
             # add them as separate selectables to the select(), or splice them together
             # after the subquery is generated
             entity_key, default_towrap = entity, entity.selectable
         else:
-            # otherwise, create a single eager join from the from clause.  
-            # Query._compile_context will adapt as needed and append to the
-            # FROM clause of the select().
-            entity_key, default_towrap = None, context.from_clause  
-
+            index, clause = sql_util.find_join_source(context.from_clause, entity.selectable)
+            if clause:
+                # join to an existing FROM clause on the query.
+                # key it to its list index in the eager_joins dict.
+                # Query._compile_context will adapt as needed and append to the
+                # FROM clause of the select().
+                entity_key, default_towrap = index, clause
+            else:
+                # if no from_clause to join to,
+                # store eager joins per _MappedEntity
+                entity_key, default_towrap = entity, entity.selectable
+                
         towrap = context.eager_joins.setdefault(entity_key, default_towrap)
 
         # create AliasedClauses object to build up the eager query.  
index b8ceabb7414ddebfaa66ce2d33e09791e4952bc6..a8de5c6352b696f253e00e06486b025344155b9f 100644 (file)
@@ -21,15 +21,31 @@ def sort_tables(tables):
         visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key})    
     return topological.sort(tuples, tables)
 
-def search(clause, target):
-    if not clause:
-        return False
-    for elem in visitors.iterate(clause, {'column_collections':False}):
-        if elem is target:
-            return True
+def find_join_source(clauses, join_to):
+    """Given a list of FROM clauses and a selectable, 
+    return the first index and element from the list of 
+    clauses which can be joined against the selectable.  returns 
+    None, None if no match is found.
+    
+    e.g.::
+    
+        clause1 = table1.join(table2)
+        clause2 = table4.join(table5)
+        
+        join_to = table2.join(table3)
+        
+        find_join_source([clause1, clause2], join_to) == clause1
+    
+    """
+    
+    selectables = list(expression._from_objects(join_to))
+    for i, f in enumerate(clauses):
+        for s in selectables:
+            if f.is_derived_from(s):
+                return i, f
     else:
-        return False
-
+        return None, None
+    
 def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False):
     """locate Table objects within the given expression."""
     
index be1d0a955d737a5dd4eb20385634760a4ec62326..1876eca25814195c8d628a40f467f322b9b160ab 100644 (file)
@@ -4,7 +4,7 @@ import testenv; testenv.configure_for_tests()
 from testlib import sa, testing
 from sqlalchemy.orm import eagerload, deferred, undefer
 from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func
-from testlib.sa.orm import mapper, relation, create_session, lazyload
+from testlib.sa.orm import mapper, relation, create_session, lazyload, aliased
 from testlib.testing import eq_
 from testlib.assertsql import CompiledSQL
 from orm import _base, _fixtures
@@ -1219,6 +1219,48 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             )
         self.assert_sql_count(testing.db, go, 1)
 
+    @testing.exclude('sqlite', '>', (0, 0, 0), "sqlite flat out blows it on the multiple JOINs")
+    @testing.resolve_artifact_names
+    def test_two_entities_with_joins(self):
+        sess = create_session()
+        
+        # two FROM clauses where there's a join on each one
+        def go():
+            u1 = aliased(User)
+            o1 = aliased(Order)
+            eq_(
+                [
+                    (
+                        User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), 
+                        Order(description=u'order 2', isopen=0, items=[Item(description=u'item 1'), Item(description=u'item 2'), Item(description=u'item 3')]),
+                        User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), 
+                        Order(description=u'order 3', isopen=1, items=[Item(description=u'item 3'), Item(description=u'item 4'), Item(description=u'item 5')])
+                    ), 
+
+                    (
+                        User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), 
+                        Order(description=u'order 2', isopen=0, items=[Item(description=u'item 1'), Item(description=u'item 2'), Item(description=u'item 3')]),
+                        User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), 
+                        Order(address_id=None, description=u'order 5', isopen=0, items=[Item(description=u'item 5')])
+                    ), 
+
+                    (
+                        User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), 
+                        Order(description=u'order 4', isopen=1, items=[Item(description=u'item 1'), Item(description=u'item 5')]),
+                        User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), 
+                        Order(address_id=None, description=u'order 5', isopen=0, items=[Item(description=u'item 5')])
+                    ), 
+                ],
+                sess.query(User, Order, u1, o1).\
+                        join((Order, User.orders)).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).\
+                        join((o1, u1.orders)).options(eagerload(u1.addresses), eagerload(o1.items)).filter(u1.id==7).\
+                        filter(Order.id<o1.id).\
+                        order_by(User.id, Order.id, u1.id, o1.id).all(),
+            )
+        self.assert_sql_count(testing.db, go, 1)
+        
+        
+
     @testing.resolve_artifact_names
     def test_aliased_entity(self):
         sess = create_session()
index 0f5d982f1df100ed649c8d5726d6e4d0dde104d9..c101bfd840c5a975301e3311dbd2c03400dfbcad 100644 (file)
@@ -210,6 +210,25 @@ def make_test(select_type):
             self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
             self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
             self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"))
+        
+        def test_multi_join(self):
+            sess = create_session()
+
+            e = aliased(Person)
+            c = aliased(Company)
+            
+            q = sess.query(Company, Person, c, e).join((Person, Company.employees)).join((e, c.employees)).\
+                    filter(Person.name=='dilbert').filter(e.name=='wally')
+            
+            self.assertEquals(q.count(), 1)
+            self.assertEquals(q.all(), [
+                (
+                    Company(company_id=1,name=u'MegaCorp, Inc.'), 
+                    Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'),
+                    Company(company_id=1,name=u'MegaCorp, Inc.'), 
+                    Engineer(status=u'regular engineer',engineer_name=u'wally',name=u'wally',company_id=1,primary_language=u'c++',person_id=2,type=u'engineer')
+                )
+            ])
             
         def test_filter_on_subclass(self):
             sess = create_session()