]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- query.join() can also accept tuples of attribute
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 Jan 2008 19:20:49 +0000 (19:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 Jan 2008 19:20:49 +0000 (19:20 +0000)
name/some selectable as arguments.  This allows
construction of joins *from* subclasses of a
polymorphic relation, i.e.:

query(Company).\
join(
  [('employees', people.join(engineer)), Engineer.name]
)

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/expression.py
test/orm/inheritance/query.py

diff --git a/CHANGES b/CHANGES
index 937bae302c86768084297a8ecccd2b09fd3c2548..947ee360354b2d6f23be0bf3d13b228a5bb98652 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -54,8 +54,18 @@ CHANGES
       relation, i.e.:
 
         query(Company).join(['employees', Engineer.name])
-
-    - General improvements to the behavior of join() in
+        
+    - query.join() can also accept tuples of attribute
+      name/some selectable as arguments.  This allows
+      construction of joins *from* subclasses of a 
+      polymorphic relation, i.e.:
+      
+        query(Company).\
+        join(
+          [('employees', people.join(engineer)), Engineer.name]
+        )
+      
+    - General improvements to the behavior of join() in 
       conjunction with polymorphic mappers, i.e. joining
       from/to polymorphic mappers and properly applying
       aliases.
index 201d0e2e308264b74960461ae60279e6002110f1..bcf0f3dd336e5c758635a644a3df8e60c6883fe0 100644 (file)
@@ -422,18 +422,39 @@ class Query(object):
 
         mapper = start
         alias = self._aliases
-        for key in util.to_list(keys):
+        if not isinstance(keys, list):
+            keys = [keys]
+        for key in keys:
+            use_selectable = None
+            if isinstance(key, tuple):
+                key, use_selectable = key
+
             if isinstance(key, interfaces.PropComparator):
                 prop = key.property
             else:
                 prop = mapper.get_property(key, resolve_synonyms=True)
+
+            if use_selectable:
+                if not use_selectable.is_derived_from(prop.mapper.mapped_table):
+                    raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description))
+                if not isinstance(use_selectable, expression.Alias):
+                    use_selectable = use_selectable.alias()
                 
-            if prop._is_self_referential() and not create_aliases:
+            if prop._is_self_referential() and not create_aliases and not use_selectable:
                 raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires create_aliases=True argument." % str(prop))
 
-            if prop.select_table not in currenttables or create_aliases:
+            if prop.select_table not in currenttables or create_aliases or use_selectable:
                 if prop.secondary:
-                    if create_aliases:
+                    if use_selectable:
+                        alias = mapperutil.PropertyAliasedClauses(prop,
+                            prop.primary_join_against(mapper, adapt_against),
+                            prop.secondary_join_against(mapper),
+                            alias,
+                            alias=use_selectable
+                        )
+                        crit = alias.primaryjoin
+                        clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
+                    elif create_aliases:
                         alias = mapperutil.PropertyAliasedClauses(prop,
                             prop.primary_join_against(mapper, adapt_against),
                             prop.secondary_join_against(mapper),
@@ -446,7 +467,16 @@ class Query(object):
                         clause = clause.join(prop.secondary, crit, isouter=outerjoin)
                         clause = clause.join(prop.select_table, prop.secondary_join_against(mapper), isouter=outerjoin)
                 else:
-                    if create_aliases:
+                    if use_selectable:
+                        alias = mapperutil.PropertyAliasedClauses(prop,
+                            prop.primary_join_against(mapper, adapt_against), 
+                            None,
+                            alias,
+                            alias=use_selectable
+                        )
+                        crit = alias.primaryjoin
+                        clause = clause.join(alias.alias, crit, isouter=outerjoin)
+                    elif create_aliases:
                         alias = mapperutil.PropertyAliasedClauses(prop,
                             prop.primary_join_against(mapper, adapt_against), 
                             None,
@@ -464,13 +494,12 @@ class Query(object):
 
             mapper = prop.mapper
 
-            if mapper.select_table is not mapper.mapped_table:
+            if use_selectable:
+                adapt_against = use_selectable
+            elif mapper.select_table is not mapper.mapped_table:
                 adapt_against = mapper.select_table
 
-        if create_aliases:
-            return (clause, mapper, alias)
-        else:
-            return (clause, mapper, None)
+        return (clause, mapper, alias)
 
     def _generative_col_aggregate(self, col, func):
         """apply the given aggregate function to the query and return the newly
@@ -594,14 +623,16 @@ class Query(object):
         'prop' may be one of:
           * a string property name, i.e. "rooms"
           * a class-mapped attribute, i.e. Houses.rooms
-          * a list containing a combination of any of the above.
+          * a 2-tuple containing one of the above, combined with a selectable
+          which derives from the properties' mapped table
+          * a list (not a tuple) containing a combination of any of the above.
           
         e.g.::
         
             session.query(Company).join('employees')
             session.query(Company).join(['employees', 'tasks'])
             session.query(Houses).join([Colonials.rooms, Room.closets])
-        
+            session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
         """
 
         return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
@@ -613,13 +644,16 @@ class Query(object):
         'prop' may be one of:
           * a string property name, i.e. "rooms"
           * a class-mapped attribute, i.e. Houses.rooms
-          * a list containing a combination of any of the above.
+          * a 2-tuple containing one of the above, combined with a selectable
+          which derives from the properties' mapped table
+          * a list (not a tuple) containing a combination of any of the above.
           
         e.g.::
         
             session.query(Company).outerjoin('employees')
             session.query(Company).outerjoin(['employees', 'tasks'])
             session.query(Houses).outerjoin([Colonials.rooms, Room.closets])
+            session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
         
         """
 
index d2f1d2491a626ddc511e5e273492349a6da3d83a..a801210f94fb057856d6c02f3310916f0a37ed6a 100644 (file)
@@ -193,8 +193,8 @@ class AliasedClauses(object):
 class PropertyAliasedClauses(AliasedClauses):
     """extends AliasedClauses to add support for primary/secondary joins on a relation()."""
     
-    def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None):
-        super(PropertyAliasedClauses, self).__init__(prop.select_table)
+    def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None, alias=None):
+        super(PropertyAliasedClauses, self).__init__(prop.select_table, alias=alias)
             
         self.parentclauses = parentclauses
 
index c603418028f443d46302b07d5a78002efe9eeea0..aff8654f256e987b81907e968c24808dcf0450b8 100644 (file)
@@ -2202,6 +2202,9 @@ class Join(FromClause):
         return "Join object on %s(%d) and %s(%d)" % (self.left.description, id(self.left), self.right.description, id(self.right))
     description = property(description)
 
+    def is_derived_from(self, fromclause):
+        return fromclause is self or self.left.is_derived_from(fromclause) or self.right.is_derived_from(fromclause)
+
     def self_group(self, against=None):
         return _FromGrouping(self)
 
index b9f11faa7c31fd351f7b6fea3f3cf38eb75935c3..503364787264560ab4a26b15d542c2ec15ce144c 100644 (file)
@@ -202,22 +202,22 @@ def make_test(select_type):
             self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
     
         def test_join_to_subclass(self):
-            if select_type == '':
-                return
-
             sess = create_session()
 
-            self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
-        
-            self.assertEquals(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1])
-        
-            self.assertEquals(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3])
-
-            self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
-        
-            self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2])
-
-            self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+            if select_type == '':
+                self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
+                self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
+                self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3])
+                self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+                self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2])
+                self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+            else:
+                self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
+                self.assertEquals(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1])
+                self.assertEquals(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3])
+                self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+                self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2])
+                self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
         
         def test_join_through_polymorphic(self):