]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refined and clarified query.__join() for readability rel_0_5_2
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Jan 2009 17:29:56 +0000 (17:29 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Jan 2009 17:29:56 +0000 (17:29 +0000)
- _ORMJoin() gets a new flag join_to_left to specify if
we really want to alias from the existing left side or not.  eager loading
wants this flag off in almost all cases, query.join() usually wants it on.
- query.join()/outerjoin() will now properly join an aliased()
construct to the existing left side, even if query.from_self()
or query.select_from(someselectable) has been called.
[ticket:1293]

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 9a95d3b2b5647202aa2db69bca4a6037adb7d2f4..36c8398c0ac7b08c1801a681de998e3230d503cd 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -39,6 +39,11 @@ CHANGES
     - session.expire() and related methods will not expire() unloaded
       deferred attributes.  This prevents them from being needlessly
       loaded when the instance is refreshed.
+
+    - query.join()/outerjoin() will now properly join an aliased()
+      construct to the existing left side, even if query.from_self()
+      or query.select_from(someselectable) has been called.
+      [ticket:1293]
       
 - sql
     - Further fixes to the "percent signs and spaces in column/table
index 6690eee128aa710c12df33c3eff836c3f2e40b53..6a26d30b4460b7d5cae0412b3eae679e2bfdb6ac 100644 (file)
@@ -887,26 +887,40 @@ class Query(object):
 
     @_generative(__no_statement_condition, __no_limit_offset)
     def __join(self, keys, outerjoin, create_aliases, from_joinpoint):
+        
+        # copy collections that may mutate so they do not affect
+        # the copied-from query.
         self.__currenttables = set(self.__currenttables)
         self._polymorphic_adapters = self._polymorphic_adapters.copy()
 
+        # start from the beginning unless from_joinpoint is set.
         if not from_joinpoint:
             self.__reset_joinpoint()
 
+        # join from our from_obj.  This is
+        # None unless select_from()/from_self() has been called.
         clause = self._from_obj
-        right_entity = None
 
+        # after the method completes,
+        # the query's joinpoint will be set to this.
+        right_entity = None
+        
         for arg1 in util.to_list(keys):
             aliased_entity = False
             alias_criterion = False
             left_entity = right_entity
             prop = of_type = right_entity = right_mapper = None
 
+            # distinguish between tuples, scalar args
             if isinstance(arg1, tuple):
                 arg1, arg2 = arg1
             else:
                 arg2 = None
 
+            # determine onclause/right_entity.  there
+            # is a little bit of legacy behavior still at work here
+            # which means they might be in either order.  may possibly
+            # lock this down to (right_entity, onclause) in 0.6.
             if isinstance(arg2, (interfaces.PropComparator, basestring)):
                 onclause = arg2
                 right_entity = arg1
@@ -917,6 +931,8 @@ class Query(object):
                 onclause = arg2
                 right_entity = arg1
 
+            # extract info from the onclause argument, determine
+            # left_entity and right_entity.
             if isinstance(onclause, interfaces.PropComparator):
                 of_type = getattr(onclause, '_of_type', None)
                 prop = onclause.property
@@ -942,25 +958,34 @@ class Query(object):
 
                 if not right_entity:
                     right_entity = right_mapper
-            elif onclause is None:
-                if not left_entity:
-                    left_entity = self._joinpoint_zero()
-            else:
-                if not left_entity:
-                    left_entity = self._joinpoint_zero()
+            elif not left_entity:
+                left_entity = self._joinpoint_zero()
 
+            # if no initial left-hand clause is set, extract
+            # this from the left_entity or as a last
+            # resort from the onclause argument, if it's
+            # a PropComparator.
             if not clause:
-                if isinstance(onclause, interfaces.PropComparator):
-                    clause = onclause.__clause_element__()
-
                 for ent in self._entities:
                     if ent.corresponds_to(left_entity):
                         clause = ent.selectable
                         break
+                    
+            if not clause:
+                if isinstance(onclause, interfaces.PropComparator):
+                    clause = onclause.__clause_element__()
 
             if not clause:
                 raise sa_exc.InvalidRequestError("Could not find a FROM clause to join from")
 
+            # if we have a MapperProperty and the onclause is not already
+            # an instrumented descriptor.  this catches of_type()
+            # PropComparators and string-based on clauses.
+            if prop and not isinstance(onclause, attributes.QueryableAttribute):
+                onclause = prop
+
+            # start looking at the right side of the join
+            
             mp, right_selectable, is_aliased_class = _entity_info(right_entity)
             
             if mp is not None and right_mapper is not None and not mp.common_parent(right_mapper):
@@ -971,11 +996,16 @@ class Query(object):
             if not right_mapper and mp:
                 right_mapper = mp
 
+            # determine if we need to wrap the right hand side in an alias.
+            # this occurs based on the create_aliases flag, or if the target
+            # is a selectable, Join, or polymorphically-loading mapper
             if right_mapper and not is_aliased_class:
                 if right_entity is right_selectable:
 
                     if not right_selectable.is_derived_from(right_mapper.mapped_table):
-                        raise sa_exc.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (right_selectable.description, right_mapper.mapped_table.description))
+                        raise sa_exc.InvalidRequestError(
+                            "Selectable '%s' is not derived from '%s'" % 
+                            (right_selectable.description, right_mapper.mapped_table.description))
 
                     if not isinstance(right_selectable, expression.Alias):
                         right_selectable = right_selectable.alias()
@@ -993,12 +1023,17 @@ class Query(object):
                     aliased_entity = True
 
                 elif prop:
+                    # for joins across plain relation()s, try not to specify the
+                    # same joins twice.  the __currenttables collection tracks
+                    # what plain mapped tables we've joined to already.
+                    
                     if prop.table in self.__currenttables:
                         if prop.secondary is not None and prop.secondary not in self.__currenttables:
                             # TODO: this check is not strong enough for different paths to the same endpoint which
                             # does not use secondary tables
-                            raise sa_exc.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists.  Use the `alias=True` argument to `join()`." % descriptor)
-
+                            raise sa_exc.InvalidRequestError("Can't join to property '%s'; a path to this "
+                                "table along a different secondary table already "
+                                "exists.  Use the `alias=True` argument to `join()`." % descriptor)
                         continue
 
                     if prop.secondary:
@@ -1010,30 +1045,50 @@ class Query(object):
                     else:
                         right_entity = prop.mapper
 
+            # create adapters to the right side, if we've created aliases
             if alias_criterion:
                 right_adapter = ORMAdapter(right_entity,
                     equivalents=right_mapper._equivalent_columns, chain_to=self._filter_aliases)
 
-                if isinstance(onclause, sql.ClauseElement):
+            # if the onclause is a ClauseElement, adapt it with our right
+            # adapter, then with our query-wide adaptation if any.
+            if isinstance(onclause, expression.ClauseElement):
+                if alias_criterion:
                     onclause = right_adapter.traverse(onclause)
-
-            # TODO: is this a little hacky ?
-            if not isinstance(onclause, attributes.QueryableAttribute) or not isinstance(onclause.parententity, AliasedClass):
-                if prop:
-                    # MapperProperty based onclause
-                    onclause = prop
-                else:
-                    # ClauseElement based onclause
-                    onclause = self._adapt_clause(onclause, False, True)
-                
-            clause = orm_join(clause, right_entity, onclause, isouter=outerjoin)
+                onclause = self._adapt_clause(onclause, False, True)
+
+            # determine if we want _ORMJoin to alias the onclause 
+            # to the given left side.  This is used if we're joining against a 
+            # select_from() selectable, from_self() call, or the onclause
+            # has been resolved into a MapperProperty.  Otherwise we assume
+            # the onclause itself contains more specific information on how to
+            # construct the onclause.
+            join_to_left = not is_aliased_class or \
+                            onclause is prop or \
+                            clause is self._from_obj and self._from_obj_alias
+            
+            # create the join                
+            clause = orm_join(clause, right_entity, onclause, isouter=outerjoin, join_to_left=join_to_left)
+            
+            # set up state for the query as a whole
             if alias_criterion:
+                # adapt filter() calls based on our right side adaptation
                 self._filter_aliases = right_adapter
 
+                # if a polymorphic entity was aliased, establish that
+                # so that MapperEntity/ColumnEntity can pick up on it
+                # and adapt when it renders columns and fetches them from results
                 if aliased_entity:
-                    self.__mapper_loads_polymorphically_with(right_mapper, ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns))
-
+                    self.__mapper_loads_polymorphically_with(
+                                        right_mapper, 
+                                        ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns)
+                                    )
+        
+        # loop finished.  we're selecting from 
+        # our final clause now
         self._from_obj = clause
+        
+        # future joins with from_joinpoint=True join from our established right_entity.
         self._joinpoint = right_entity
 
     @_generative(__no_statement_condition)
index 91b2f359a93b80bf537f2f07ad38a34eb99dca0b..b72722e77d74335de8e128dc2f9be315c5832eeb 100644 (file)
@@ -678,24 +678,21 @@ class EagerLoader(AbstractRelationLoader):
         clauses = mapperutil.ORMAdapter(mapperutil.AliasedClass(self.mapper), 
                     equivalents=self.mapper._equivalent_columns)
 
+        join_to_left = False
         if adapter:
-            # TODO: the fallback to self.parent_property here is a hack to account for
-            # an eagerjoin using of_type().  this should be improved such that
-            # when using of_type(), the subtype is the target of the previous eager join.
-            # there shouldn't be a fallback here, since mapperutil.outerjoin() can't
-            # be trusted with a plain MapperProperty.
             if getattr(adapter, 'aliased_class', None):
                 onclause = getattr(adapter.aliased_class, self.key, self.parent_property)
             else:
                 onclause = getattr(mapperutil.AliasedClass(self.parent, adapter.selectable), self.key, self.parent_property)
+                
+            if onclause is self.parent_property:
+                # TODO: this is a temporary hack to account for polymorphic eager loads where
+                # the eagerload is referencing via of_type().
+                join_to_left = True
         else:
-            # For a plain MapperProperty, wrap the mapped table in an AliasedClass anyway.  
-            # this prevents mapperutil.outerjoin() from aliasing to the left side indiscriminately,
-            # which can break things if the left side contains multiple aliases of the parent
-            # mapper already. In the case of eager loading, we know exactly what left side we want to join to.
-            onclause = getattr(mapperutil.AliasedClass(self.parent, self.parent.mapped_table), self.key)
+            onclause = self.parent_property
             
-        context.eager_joins[entity_key] = eagerjoin = mapperutil.outerjoin(towrap, clauses.aliased_class, onclause)
+        context.eager_joins[entity_key] = eagerjoin = mapperutil.outerjoin(towrap, clauses.aliased_class, onclause, join_to_left=join_to_left)
         
         # send a hint to the Query as to where it may "splice" this join
         eagerjoin.stop_on = entity.selectable
index 522f0a156c0acd89be7001fd32faa36c662d998e..c8637290177ef598bc320cf0a5b843f1f4017e76 100644 (file)
@@ -359,17 +359,17 @@ class _ORMJoin(expression.Join):
     
     __visit_name__ = expression.Join.__visit_name__
 
-    def __init__(self, left, right, onclause=None, isouter=False):
+    def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True):
+        adapt_from = None
+        
         if hasattr(left, '_orm_mappers'):
             left_mapper = left._orm_mappers[1]
-            adapt_from = left.right
-
+            if join_to_left:
+                adapt_from = left.right
         else:
             left_mapper, left, left_is_aliased = _entity_info(left)
-            if left_is_aliased or not left_mapper:
+            if join_to_left and (left_is_aliased or not left_mapper):
                 adapt_from = left
-            else:
-                adapt_from = None
             
         right_mapper, right, right_is_aliased = _entity_info(right)
         if right_is_aliased:
@@ -383,11 +383,8 @@ class _ORMJoin(expression.Join):
             if isinstance(onclause, basestring):
                 prop = left_mapper.get_property(onclause)
             elif isinstance(onclause, attributes.QueryableAttribute):
-                # TODO: we might want to honor the current adapt_from,
-                # if already set.  we would need to adjust how we calculate
-                # adapt_from though since it is present in too many cases
-                # at the moment (query tests illustrate that).
-                adapt_from = onclause.__clause_element__()
+                if not adapt_from:
+                    adapt_from = onclause.__clause_element__()
                 prop = onclause.property
             elif isinstance(onclause, MapperProperty):
                 prop = onclause
@@ -395,7 +392,12 @@ class _ORMJoin(expression.Join):
                 prop = None
 
             if prop:
-                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True, of_type=right_mapper)
+                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
+                                source_selectable=adapt_from, 
+                                dest_selectable=adapt_to, 
+                                source_polymorphic=True, 
+                                dest_polymorphic=True, 
+                                of_type=right_mapper)
 
                 if sj:
                     left = sql.join(left, secondary, pj, isouter)
@@ -406,13 +408,13 @@ class _ORMJoin(expression.Join):
                 
         expression.Join.__init__(self, left, right, onclause, isouter)
 
-    def join(self, right, onclause=None, isouter=False):
-        return _ORMJoin(self, right, onclause, isouter)
+    def join(self, right, onclause=None, isouter=False, join_to_left=True):
+        return _ORMJoin(self, right, onclause, isouter, join_to_left)
 
-    def outerjoin(self, right, onclause=None):
-        return _ORMJoin(self, right, onclause, True)
+    def outerjoin(self, right, onclause=None, join_to_left=True):
+        return _ORMJoin(self, right, onclause, True, join_to_left)
 
-def join(left, right, onclause=None, isouter=False):
+def join(left, right, onclause=None, isouter=False, join_to_left=True):
     """Produce an inner join between left and right clauses.
     
     In addition to the interface provided by 
@@ -421,19 +423,15 @@ def join(left, right, onclause=None, isouter=False):
     string name of a relation(), or a class-bound descriptor 
     representing a relation.
     
-    When passed a string or plain mapped descriptor for the
-    onclause, ``join()`` goes into "automatic" mode and
-    will attempt to join the right side to the left
-    in whatever way it sees fit, which may include aliasing
-    the ON clause to match the left side.  Alternatively,
-    when passed a clause-based onclause, or an attribute
-    mapped to an :func:`~sqlalchemy.orm.aliased` construct, 
-    no left-side guesswork is performed.
+    join_to_left indicates to attempt aliasing the ON clause,
+    in whatever form it is passed, to the selectable
+    passed as the left side.  If False, the onclause
+    is used as is.
     
     """
-    return _ORMJoin(left, right, onclause, isouter)
+    return _ORMJoin(left, right, onclause, isouter, join_to_left)
 
-def outerjoin(left, right, onclause=None):
+def outerjoin(left, right, onclause=None, join_to_left=True):
     """Produce a left outer join between left and right clauses.
     
     In addition to the interface provided by 
@@ -443,7 +441,7 @@ def outerjoin(left, right, onclause=None):
     representing a relation.
     
     """
-    return _ORMJoin(left, right, onclause, True)
+    return _ORMJoin(left, right, onclause, True, join_to_left)
 
 def with_parent(instance, prop):
     """Return criterion which selects instances with a given parent.
index 1b56dbb269e62872be5bf878a70322b0ba768f60..9d01be8371f05ba41c0292c2d42b367531431c70 100644 (file)
@@ -743,6 +743,42 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL):
             "LEFT OUTER JOIN addresses AS addresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id"
         )
             
+    def test_aliases(self):
+        """test that aliased objects are accessible externally to a from_self() call."""
+        
+        s = create_session()
+        
+        ualias = aliased(User)
+        eq_(
+            s.query(User, ualias).filter(User.id > ualias.id).from_self(User.name, ualias.name).
+                    order_by(User.name, ualias.name).all(),
+            [
+                (u'chuck', u'ed'), 
+                (u'chuck', u'fred'), 
+                (u'chuck', u'jack'), 
+                (u'ed', u'jack'), 
+                (u'fred', u'ed'), 
+                (u'fred', u'jack')
+            ]
+        )
+
+        eq_(
+            s.query(User, ualias).filter(User.id > ualias.id).from_self(User.name, ualias.name).filter(ualias.name=='ed')\
+                .order_by(User.name, ualias.name).all(),
+            [(u'chuck', u'ed'), (u'fred', u'ed')]
+        )
+
+        eq_(
+            s.query(User, ualias).filter(User.id > ualias.id).from_self(ualias.name, Address.email_address).
+                    join(ualias.addresses).order_by(ualias.name, Address.email_address).all(),
+            [
+                (u'ed', u'fred@fred.com'), 
+                (u'jack', u'ed@bettyboop.com'), 
+                (u'jack', u'ed@lala.com'), 
+                (u'jack', u'ed@wood.com'), 
+                (u'jack', u'fred@fred.com')]
+        )
+        
         
     def test_multiple_entities(self):
         sess = create_session()