From 998183be6bb4575b1088b495b9078a6f3f91293f Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 16 Sep 2009 19:48:22 +0000 Subject: [PATCH] - contains_eager() now works with the automatically generated subquery that results when you say "query(Parent).join(Parent.somejoinedsubclass)", i.e. when Parent joins to a joined-table-inheritance subclass. Previously contains_eager() would erroneously add the subclass table to the query separately producing a cartesian product. An example is in the ticket description. [ticket:1543] --- CHANGES | 13 +++- lib/sqlalchemy/orm/strategies.py | 5 +- test/orm/inheritance/test_basic.py | 82 +++++++++++++++++++++++++ test/orm/inheritance/test_polymorph2.py | 4 +- test/orm/inheritance/test_query.py | 2 +- 5 files changed, 101 insertions(+), 5 deletions(-) diff --git a/CHANGES b/CHANGES index efa34454c9..19f7b3788c 100644 --- a/CHANGES +++ b/CHANGES @@ -379,7 +379,18 @@ CHANGES - AbstractType.get_search_list() is removed - the games that was used for are no longer necessary. - +0.5.7 +===== +- orm + - contains_eager() now works with the automatically + generated subquery that results when you say + "query(Parent).join(Parent.somejoinedsubclass)", i.e. + when Parent joins to a joined-table-inheritance subclass. + Previously contains_eager() would erroneously add the + subclass table to the query separately producing a + cartesian product. An example is in the ticket + description. [ticket:1543] + 0.5.6 ===== - orm diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index b3290a2d63..ed742a2bfa 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -858,7 +858,10 @@ class LoadEagerFromAliasOption(PropertyOption): self.alias = prop.target.alias(self.alias) query._attributes[("user_defined_eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias) else: - query._attributes[("user_defined_eager_row_processor", paths[-1])] = None + (mapper, propname) = paths[-1][-2:] + prop = mapper.get_property(propname, resolve_synonyms=True) + adapter = query._polymorphic_adapters.get(prop.mapper, None) + query._attributes[("user_defined_eager_row_processor", paths[-1])] = adapter class _SingleParentValidator(interfaces.AttributeExtension): def __init__(self, prop): diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 713ae3b5fe..5ed6d1735f 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -447,6 +447,88 @@ class EagerTargetingTest(_base.MappedTest): eq_(node, B(id=1, name='b1',b_data='i')) eq_(node.children[0], B(id=2, name='b2',b_data='l')) +class EagerToSubclassTest(_base.MappedTest): + """Test eagerloads to subclass mappers""" + + run_setup_classes = 'once' + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table('parent', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('data', String(10)), + ) + + Table('base', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('type', String(10)), + ) + + Table('sub', metadata, + Column('id', Integer, ForeignKey('base.id'), primary_key=True), + Column('data', String(10)), + Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False) + ) + + @classmethod + @testing.resolve_artifact_names + def setup_classes(cls): + class Parent(_base.ComparableEntity): + pass + + class Base(_base.ComparableEntity): + pass + + class Sub(Base): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Parent, parent, properties={ + 'children':relation(Sub) + }) + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b') + mapper(Sub, sub, inherits=Base, polymorphic_identity='s') + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + sess = create_session() + p1 = Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]) + p2 = Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + sess.add(p1) + sess.add(p2) + sess.flush() + + @testing.resolve_artifact_names + def test_eagerload(self): + sess = create_session() + def go(): + eq_( + sess.query(Parent).options(eagerload(Parent.children)).all(), + [ + Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), + Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_contains_eager(self): + sess = create_session() + def go(): + eq_( + sess.query(Parent).join(Parent.children).options(contains_eager(Parent.children)).all(), + [ + Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), + Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + ] + ) + self.assert_sql_count(testing.db, go, 1) class FlushTest(_base.MappedTest): """test dependency sorting among inheriting mappers""" diff --git a/test/orm/inheritance/test_polymorph2.py b/test/orm/inheritance/test_polymorph2.py index 80c14413a0..7aedc8245b 100644 --- a/test/orm/inheritance/test_polymorph2.py +++ b/test/orm/inheritance/test_polymorph2.py @@ -437,8 +437,8 @@ class RelationTest5(_base.MappedTest): Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner', Integer, ForeignKey('people.person_id'))) - def testeagerempty(self): - """an easy one...test parent object with child relation to an inheriting mapper, using eager loads, + def test_eager_empty(self): + """test parent object with child relation to an inheriting mapper, using eager loads, works when there are no child objects present""" class Person(object): def __init__(self, **kwargs): diff --git a/test/orm/inheritance/test_query.py b/test/orm/inheritance/test_query.py index ee2d7f0e5f..243ed4a7ba 100644 --- a/test/orm/inheritance/test_query.py +++ b/test/orm/inheritance/test_query.py @@ -476,7 +476,7 @@ def _produce_test(select_type): [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])] ) self.assert_sql_count(testing.db, go, 1) - + def test_join_to_subclass(self): sess = create_session() eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) -- 2.47.2