From: Mike Bayer Date: Thu, 22 Jan 2009 18:28:27 +0000 (+0000) Subject: - Fixed an eager loading bug whereby self-referential eager X-Git-Tag: rel_0_5_2~5 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fc7de2aafd8caf1354b81a1a34817f1622302957;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Fixed an eager loading bug whereby self-referential eager loading would prevent other eager loads, self referential or not, from joining to the parent JOIN properly. Thanks to Alex K for creating a great test case. --- diff --git a/CHANGES b/CHANGES index 0328d44fc2..865e884889 100644 --- a/CHANGES +++ b/CHANGES @@ -31,6 +31,11 @@ CHANGES relations from two different parent classes to the same target class would prematurely expunge the instance. + - Fixed an eager loading bug whereby self-referential eager + loading would prevent other eager loads, self referential or not, + from joining to the parent JOIN properly. Thanks to Alex K + for creating a great test case. + - sql - Further fixes to the "percent signs and spaces in column/table names" functionality. [ticket:1284] diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 6edbd73d31..2a78c90de9 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -656,7 +656,7 @@ class EagerLoader(AbstractRelationLoader): # whether or not the Query will wrap the selectable in a subquery, # and then attach eager load joins to that (i.e., in the case of LIMIT/OFFSET etc.) should_nest_selectable = context.query._should_nest_selectable - + if entity in context.eager_joins: entity_key, default_towrap = entity, entity.selectable elif should_nest_selectable or not context.from_clause or not sql_util.search(context.from_clause, entity.selectable): @@ -669,22 +669,31 @@ class EagerLoader(AbstractRelationLoader): # otherwise, create a single eager join from the from clause. # Query._compile_context will adapt as needed and append to the # FROM clause of the select(). - entity_key, default_towrap = None, context.from_clause - + entity_key, default_towrap = None, context.from_clause + towrap = context.eager_joins.setdefault(entity_key, default_towrap) - + # create AliasedClauses object to build up the eager query. clauses = mapperutil.ORMAdapter(mapperutil.AliasedClass(self.mapper), equivalents=self.mapper._equivalent_columns) if adapter: + # TODO: the fallback to self.parent_property here is a hack to account for + # an eagerjoin using of_type(). this should be improved such that + # when using of_type(), the subtype is the target of the previous eager join. + # there shouldn't be a fallback here, since mapperutil.outerjoin() can't + # be trusted with a plain MapperProperty. if getattr(adapter, 'aliased_class', None): onclause = getattr(adapter.aliased_class, self.key, self.parent_property) else: onclause = getattr(mapperutil.AliasedClass(self.parent, adapter.selectable), self.key, self.parent_property) else: - onclause = self.parent_property - + # For a plain MapperProperty, wrap the mapped table in an AliasedClass anyway. + # this prevents mapperutil.outerjoin() from aliasing to the left side indiscriminately, + # which can break things if the left side contains multiple aliases of the parent + # mapper already. In the case of eager loading, we know exactly what left side we want to join to. + onclause = getattr(mapperutil.AliasedClass(self.parent, self.parent.mapped_table), self.key) + context.eager_joins[entity_key] = eagerjoin = mapperutil.outerjoin(towrap, clauses.aliased_class, onclause) # send a hint to the Query as to where it may "splice" this join diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index f4ba49ae1e..522f0a156c 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -370,7 +370,7 @@ class _ORMJoin(expression.Join): adapt_from = left else: adapt_from = None - + right_mapper, right, right_is_aliased = _entity_info(right) if right_is_aliased: adapt_to = right @@ -421,6 +421,15 @@ def join(left, right, onclause=None, isouter=False): string name of a relation(), or a class-bound descriptor representing a relation. + When passed a string or plain mapped descriptor for the + onclause, ``join()`` goes into "automatic" mode and + will attempt to join the right side to the left + in whatever way it sees fit, which may include aliasing + the ON clause to match the left side. Alternatively, + when passed a clause-based onclause, or an attribute + mapped to an :func:`~sqlalchemy.orm.aliased` construct, + no left-side guesswork is performed. + """ return _ORMJoin(left, right, onclause, isouter) diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 2752aae3ec..9dff0ffd19 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -1064,6 +1064,76 @@ class SelfReferentialEagerTest(_base.MappedTest): ]) == d self.assert_sql_count(testing.db, go, 3) +class MixedSelfReferentialEagerTest(_base.MappedTest): + def define_tables(self, metadata): + Table('a_table', metadata, + Column('id', Integer, primary_key=True) + ) + + Table('b_table', metadata, + Column('id', Integer, primary_key=True), + Column('parent_b1_id', Integer, ForeignKey('b_table.id')), + Column('parent_a_id', Integer, ForeignKey('a_table.id')), + Column('parent_b2_id', Integer, ForeignKey('b_table.id'))) + + + @testing.resolve_artifact_names + def setup_mappers(self): + class A(_base.ComparableEntity): + pass + class B(_base.ComparableEntity): + pass + + mapper(A,a_table) + mapper(B,b_table,properties = { + 'parent_b1': relation(B, + remote_side = [b_table.c.id], + primaryjoin = (b_table.c.parent_b1_id ==b_table.c.id), + order_by = b_table.c.id + ), + 'parent_z': relation(A,lazy = True), + 'parent_b2': relation(B, + remote_side = [b_table.c.id], + primaryjoin = (b_table.c.parent_b2_id ==b_table.c.id), + order_by = b_table.c.id + ) + }); + + @testing.resolve_artifact_names + def insert_data(self): + a_table.insert().execute(dict(id=1), dict(id=2), dict(id=3)) + b_table.insert().execute( + dict(id=1, parent_a_id=2, parent_b1_id=None, parent_b2_id=None), + dict(id=2, parent_a_id=1, parent_b1_id=1, parent_b2_id=None), + dict(id=3, parent_a_id=1, parent_b1_id=1, parent_b2_id=2), + dict(id=4, parent_a_id=3, parent_b1_id=1, parent_b2_id=None), + dict(id=5, parent_a_id=3, parent_b1_id=None, parent_b2_id=2), + dict(id=6, parent_a_id=1, parent_b1_id=1, parent_b2_id=3), + dict(id=7, parent_a_id=2, parent_b1_id=None, parent_b2_id=3), + dict(id=8, parent_a_id=2, parent_b1_id=1, parent_b2_id=2), + dict(id=9, parent_a_id=None, parent_b1_id=1, parent_b2_id=None), + dict(id=10, parent_a_id=3, parent_b1_id=7, parent_b2_id=2), + dict(id=11, parent_a_id=3, parent_b1_id=1, parent_b2_id=8), + dict(id=12, parent_a_id=2, parent_b1_id=5, parent_b2_id=2), + dict(id=13, parent_a_id=3, parent_b1_id=4, parent_b2_id=4), + dict(id=14, parent_a_id=3, parent_b1_id=7, parent_b2_id=2), + ) + + @testing.resolve_artifact_names + def test_eager_load(self): + session = create_session() + def go(): + eq_( + session.query(B).options(eagerload('parent_b1'),eagerload('parent_b2'),eagerload('parent_z')). + filter(B.id.in_([2, 8, 11])).order_by(B.id).all(), + [ + B(id=2, parent_z=A(id=1), parent_b1=B(id=1), parent_b2=None), + B(id=8, parent_z=A(id=2), parent_b1=B(id=1), parent_b2=B(id=2)), + B(id=11, parent_z=A(id=3), parent_b1=B(id=1), parent_b2=B(id=8)) + ] + ) + self.assert_sql_count(testing.db, go, 1) + class SelfReferentialM2MEagerTest(_base.MappedTest): def define_tables(self, metadata): Table('widget', metadata,