]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed an eager loading bug whereby self-referential eager
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 22 Jan 2009 18:28:27 +0000 (18:28 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 22 Jan 2009 18:28:27 +0000 (18:28 +0000)
loading would prevent other eager loads, self referential or not,
from joining to the parent JOIN properly.  Thanks to Alex K
for creating a great test case.

CHANGES
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
test/orm/eager_relations.py

diff --git a/CHANGES b/CHANGES
index 0328d44fc25662158e35b2e935f2151156fe22ba..865e884889d985e4d191e55739b3f427e4d2e505 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -31,6 +31,11 @@ CHANGES
       relations from two different parent classes to the same target 
       class would prematurely expunge the instance.
 
+    - Fixed an eager loading bug whereby self-referential eager 
+      loading would prevent other eager loads, self referential or not,
+      from joining to the parent JOIN properly.  Thanks to Alex K
+      for creating a great test case.
+      
 - sql
     - Further fixes to the "percent signs and spaces in column/table
        names" functionality. [ticket:1284]
index 6edbd73d315260ea04eb8f567a3ee00bfeb8b0e4..2a78c90de92836ae53f8eabbc9b05beb2af4679f 100644 (file)
@@ -656,7 +656,7 @@ class EagerLoader(AbstractRelationLoader):
         # whether or not the Query will wrap the selectable in a subquery,
         # and then attach eager load joins to that (i.e., in the case of LIMIT/OFFSET etc.)
         should_nest_selectable = context.query._should_nest_selectable
-        
+
         if entity in context.eager_joins:
             entity_key, default_towrap = entity, entity.selectable
         elif should_nest_selectable or not context.from_clause or not sql_util.search(context.from_clause, entity.selectable):
@@ -669,22 +669,31 @@ class EagerLoader(AbstractRelationLoader):
             # otherwise, create a single eager join from the from clause.  
             # Query._compile_context will adapt as needed and append to the
             # FROM clause of the select().
-            entity_key, default_towrap = None, context.from_clause
-    
+            entity_key, default_towrap = None, context.from_clause  
+
         towrap = context.eager_joins.setdefault(entity_key, default_towrap)
-    
+
         # create AliasedClauses object to build up the eager query.  
         clauses = mapperutil.ORMAdapter(mapperutil.AliasedClass(self.mapper), 
                     equivalents=self.mapper._equivalent_columns)
 
         if adapter:
+            # TODO: the fallback to self.parent_property here is a hack to account for
+            # an eagerjoin using of_type().  this should be improved such that
+            # when using of_type(), the subtype is the target of the previous eager join.
+            # there shouldn't be a fallback here, since mapperutil.outerjoin() can't
+            # be trusted with a plain MapperProperty.
             if getattr(adapter, 'aliased_class', None):
                 onclause = getattr(adapter.aliased_class, self.key, self.parent_property)
             else:
                 onclause = getattr(mapperutil.AliasedClass(self.parent, adapter.selectable), self.key, self.parent_property)
         else:
-            onclause = self.parent_property
-    
+            # For a plain MapperProperty, wrap the mapped table in an AliasedClass anyway.  
+            # this prevents mapperutil.outerjoin() from aliasing to the left side indiscriminately,
+            # which can break things if the left side contains multiple aliases of the parent
+            # mapper already. In the case of eager loading, we know exactly what left side we want to join to.
+            onclause = getattr(mapperutil.AliasedClass(self.parent, self.parent.mapped_table), self.key)
+            
         context.eager_joins[entity_key] = eagerjoin = mapperutil.outerjoin(towrap, clauses.aliased_class, onclause)
         
         # send a hint to the Query as to where it may "splice" this join
index f4ba49ae1ec0443695ef794bd3c3c219ecef2a69..522f0a156c0acd89be7001fd32faa36c662d998e 100644 (file)
@@ -370,7 +370,7 @@ class _ORMJoin(expression.Join):
                 adapt_from = left
             else:
                 adapt_from = None
-        
+            
         right_mapper, right, right_is_aliased = _entity_info(right)
         if right_is_aliased:
             adapt_to = right
@@ -421,6 +421,15 @@ def join(left, right, onclause=None, isouter=False):
     string name of a relation(), or a class-bound descriptor 
     representing a relation.
     
+    When passed a string or plain mapped descriptor for the
+    onclause, ``join()`` goes into "automatic" mode and
+    will attempt to join the right side to the left
+    in whatever way it sees fit, which may include aliasing
+    the ON clause to match the left side.  Alternatively,
+    when passed a clause-based onclause, or an attribute
+    mapped to an :func:`~sqlalchemy.orm.aliased` construct, 
+    no left-side guesswork is performed.
+    
     """
     return _ORMJoin(left, right, onclause, isouter)
 
index 2752aae3ec2e66978d4fed939bbf6884b7490938..9dff0ffd191b8586b21dbde99c5e570b25bbbcb2 100644 (file)
@@ -1064,6 +1064,76 @@ class SelfReferentialEagerTest(_base.MappedTest):
             ]) == d
         self.assert_sql_count(testing.db, go, 3)
 
+class MixedSelfReferentialEagerTest(_base.MappedTest):
+    def define_tables(self, metadata):
+        Table('a_table', metadata,
+                       Column('id', Integer, primary_key=True)
+                       )
+
+        Table('b_table', metadata,
+                       Column('id', Integer, primary_key=True),
+                       Column('parent_b1_id', Integer, ForeignKey('b_table.id')),
+                       Column('parent_a_id', Integer, ForeignKey('a_table.id')),
+                       Column('parent_b2_id', Integer, ForeignKey('b_table.id')))
+
+
+    @testing.resolve_artifact_names
+    def setup_mappers(self):
+        class A(_base.ComparableEntity):
+            pass
+        class B(_base.ComparableEntity):
+            pass
+            
+        mapper(A,a_table)
+        mapper(B,b_table,properties = {
+           'parent_b1': relation(B,
+                            remote_side = [b_table.c.id],
+                            primaryjoin = (b_table.c.parent_b1_id ==b_table.c.id),
+                            order_by = b_table.c.id
+                            ),
+           'parent_z': relation(A,lazy = True),
+           'parent_b2': relation(B,
+                            remote_side = [b_table.c.id],
+                            primaryjoin = (b_table.c.parent_b2_id ==b_table.c.id),
+                            order_by = b_table.c.id
+                            )
+        });
+    
+    @testing.resolve_artifact_names
+    def insert_data(self):
+        a_table.insert().execute(dict(id=1), dict(id=2), dict(id=3))
+        b_table.insert().execute(
+            dict(id=1, parent_a_id=2, parent_b1_id=None, parent_b2_id=None),
+            dict(id=2, parent_a_id=1, parent_b1_id=1, parent_b2_id=None),
+            dict(id=3, parent_a_id=1, parent_b1_id=1, parent_b2_id=2),
+            dict(id=4, parent_a_id=3, parent_b1_id=1, parent_b2_id=None),
+            dict(id=5, parent_a_id=3, parent_b1_id=None, parent_b2_id=2),
+            dict(id=6, parent_a_id=1, parent_b1_id=1, parent_b2_id=3),
+            dict(id=7, parent_a_id=2, parent_b1_id=None, parent_b2_id=3),
+            dict(id=8, parent_a_id=2, parent_b1_id=1, parent_b2_id=2),
+            dict(id=9, parent_a_id=None, parent_b1_id=1, parent_b2_id=None),
+            dict(id=10, parent_a_id=3, parent_b1_id=7, parent_b2_id=2),
+            dict(id=11, parent_a_id=3, parent_b1_id=1, parent_b2_id=8),
+            dict(id=12, parent_a_id=2, parent_b1_id=5, parent_b2_id=2),
+            dict(id=13, parent_a_id=3, parent_b1_id=4, parent_b2_id=4),
+            dict(id=14, parent_a_id=3, parent_b1_id=7, parent_b2_id=2),
+        )
+        
+    @testing.resolve_artifact_names
+    def test_eager_load(self):
+        session = create_session()
+        def go():
+            eq_(
+                session.query(B).options(eagerload('parent_b1'),eagerload('parent_b2'),eagerload('parent_z')).
+                            filter(B.id.in_([2, 8, 11])).order_by(B.id).all(),
+                [
+                    B(id=2, parent_z=A(id=1), parent_b1=B(id=1), parent_b2=None),
+                    B(id=8, parent_z=A(id=2), parent_b1=B(id=1), parent_b2=B(id=2)),
+                    B(id=11, parent_z=A(id=3), parent_b1=B(id=1), parent_b2=B(id=8))
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+        
 class SelfReferentialM2MEagerTest(_base.MappedTest):
     def define_tables(self, metadata):
         Table('widget', metadata,