From 8b328f694216616e06f05decd728d227ccc1353f Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 16 Sep 2009 20:38:29 +0000 Subject: [PATCH] merged r6357 of rel_0_5 branch --- lib/sqlalchemy/orm/interfaces.py | 25 ++-- lib/sqlalchemy/orm/strategies.py | 9 +- test/orm/inheritance/test_basic.py | 83 ------------ test/orm/inheritance/test_query.py | 198 +++++++++++++++++++++++++++++ 4 files changed, 218 insertions(+), 97 deletions(-) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index eaafe5761a..dace1978e4 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -668,11 +668,11 @@ class PropertyOption(MapperOption): self._process(query, False) def _process(self, query, raiseerr): - paths = self.__get_paths(query, raiseerr) + paths, mappers = self.__get_paths(query, raiseerr) if paths: - self.process_query_property(query, paths) + self.process_query_property(query, paths, mappers) - def process_query_property(self, query, paths): + def process_query_property(self, query, paths, mappers): pass def __find_entity(self, query, mapper, raiseerr): @@ -718,7 +718,8 @@ class PropertyOption(MapperOption): path = None entity = None l = [] - + mappers = [] + # _current_path implies we're in a secondary load # with an existing path current_path = list(query._current_path) @@ -739,6 +740,7 @@ class PropertyOption(MapperOption): entity = query._entity_zero() path_element = entity.path_entity mapper = entity.mapper + mappers.append(mapper) prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr) key = token elif isinstance(token, PropComparator): @@ -746,8 +748,9 @@ class PropertyOption(MapperOption): if not entity: entity = self.__find_entity(query, token.parententity, raiseerr) if not entity: - return [] + return [], [] path_element = entity.path_entity + mappers.append(prop.parent) key = prop.key else: raise sa_exc.ArgumentError("mapper option expects string key or list of attributes") @@ -757,7 +760,7 @@ class PropertyOption(MapperOption): continue if prop is None: - return [] + return [], [] path = build_path(path_element, prop.key, path) l.append(path) @@ -765,15 +768,17 @@ class PropertyOption(MapperOption): path_element = mapper = token._of_type else: path_element = mapper = getattr(prop, 'mapper', None) + if path_element: path_element = path_element.base_mapper - + + # if current_path tokens remain, then # we didn't have an exact path match. if current_path: - return [] + return [], [] - return l + return l, mappers class AttributeExtension(object): """An event handler for individual attribute change events. @@ -823,7 +828,7 @@ class StrategizedOption(PropertyOption): def is_chained(self): return False - def process_query_property(self, query, paths): + def process_query_property(self, query, paths, mappers): if self.is_chained(): for path in paths: query._attributes[("loaderstrategy", path)] = self.get_strategy_class() diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index ed742a2bfa..4ab1a49486 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -849,16 +849,17 @@ class LoadEagerFromAliasOption(PropertyOption): # dont run this option on a secondary load pass - def process_query_property(self, query, paths): + def process_query_property(self, query, paths, mappers): if self.alias: if isinstance(self.alias, basestring): - (mapper, propname) = paths[-1][-2:] - + mapper = mappers[-1] + (root_mapper, propname) = paths[-1][-2:] prop = mapper.get_property(propname, resolve_synonyms=True) self.alias = prop.target.alias(self.alias) query._attributes[("user_defined_eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias) else: - (mapper, propname) = paths[-1][-2:] + (root_mapper, propname) = paths[-1][-2:] + mapper = mappers[-1] prop = mapper.get_property(propname, resolve_synonyms=True) adapter = query._polymorphic_adapters.get(prop.mapper, None) query._attributes[("user_defined_eager_row_processor", paths[-1])] = adapter diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 5ed6d1735f..4f329a91df 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -447,89 +447,6 @@ class EagerTargetingTest(_base.MappedTest): eq_(node, B(id=1, name='b1',b_data='i')) eq_(node.children[0], B(id=2, name='b2',b_data='l')) -class EagerToSubclassTest(_base.MappedTest): - """Test eagerloads to subclass mappers""" - - run_setup_classes = 'once' - run_setup_mappers = 'once' - run_inserts = 'once' - run_deletes = None - - @classmethod - def define_tables(cls, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('data', String(10)), - ) - - Table('base', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('type', String(10)), - ) - - Table('sub', metadata, - Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('data', String(10)), - Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False) - ) - - @classmethod - @testing.resolve_artifact_names - def setup_classes(cls): - class Parent(_base.ComparableEntity): - pass - - class Base(_base.ComparableEntity): - pass - - class Sub(Base): - pass - - @classmethod - @testing.resolve_artifact_names - def setup_mappers(cls): - mapper(Parent, parent, properties={ - 'children':relation(Sub) - }) - mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b') - mapper(Sub, sub, inherits=Base, polymorphic_identity='s') - - @classmethod - @testing.resolve_artifact_names - def insert_data(cls): - sess = create_session() - p1 = Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]) - p2 = Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) - sess.add(p1) - sess.add(p2) - sess.flush() - - @testing.resolve_artifact_names - def test_eagerload(self): - sess = create_session() - def go(): - eq_( - sess.query(Parent).options(eagerload(Parent.children)).all(), - [ - Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), - Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) - ] - ) - self.assert_sql_count(testing.db, go, 1) - - @testing.resolve_artifact_names - def test_contains_eager(self): - sess = create_session() - def go(): - eq_( - sess.query(Parent).join(Parent.children).options(contains_eager(Parent.children)).all(), - [ - Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), - Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) - ] - ) - self.assert_sql_count(testing.db, go, 1) - class FlushTest(_base.MappedTest): """test dependency sorting among inheriting mappers""" @classmethod diff --git a/test/orm/inheritance/test_query.py b/test/orm/inheritance/test_query.py index 243ed4a7ba..c74ddcad6f 100644 --- a/test/orm/inheritance/test_query.py +++ b/test/orm/inheritance/test_query.py @@ -1115,3 +1115,201 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): assert q.first() is c1 +class EagerToSubclassTest(_base.MappedTest): + """Test eagerloads to subclass mappers""" + + run_setup_classes = 'once' + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table('parent', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('data', String(10)), + ) + + Table('base', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('type', String(10)), + ) + + Table('sub', metadata, + Column('id', Integer, ForeignKey('base.id'), primary_key=True), + Column('data', String(10)), + Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False) + ) + + @classmethod + @testing.resolve_artifact_names + def setup_classes(cls): + class Parent(_base.ComparableEntity): + pass + + class Base(_base.ComparableEntity): + pass + + class Sub(Base): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Parent, parent, properties={ + 'children':relation(Sub) + }) + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b') + mapper(Sub, sub, inherits=Base, polymorphic_identity='s') + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + sess = create_session() + p1 = Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]) + p2 = Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + sess.add(p1) + sess.add(p2) + sess.flush() + + @testing.resolve_artifact_names + def test_eagerload(self): + sess = create_session() + def go(): + eq_( + sess.query(Parent).options(eagerload(Parent.children)).all(), + [ + Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), + Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_contains_eager(self): + sess = create_session() + def go(): + eq_( + sess.query(Parent).join(Parent.children).options(contains_eager(Parent.children)).all(), + [ + Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), + Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + +class SubClassEagerToSubclassTest(_base.MappedTest): + """Test eagerloads from subclass to subclass mappers""" + + run_setup_classes = 'once' + run_setup_mappers = 'once' + run_inserts = 'once' + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table('parent', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('type', String(10)), + ) + + Table('subparent', metadata, + Column('id', Integer, ForeignKey('parent.id'), primary_key=True), + Column('data', String(10)), + ) + + Table('base', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('type', String(10)), + ) + + Table('sub', metadata, + Column('id', Integer, ForeignKey('base.id'), primary_key=True), + Column('data', String(10)), + Column('subparent_id', Integer, ForeignKey('subparent.id'), nullable=False) + ) + + @classmethod + @testing.resolve_artifact_names + def setup_classes(cls): + class Parent(_base.ComparableEntity): + pass + + class Subparent(Parent): + pass + + class Base(_base.ComparableEntity): + pass + + class Sub(Base): + pass + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Parent, parent, polymorphic_on=parent.c.type, polymorphic_identity='b') + mapper(Subparent, subparent, inherits=Parent, polymorphic_identity='s', properties={ + 'children':relation(Sub) + }) + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b') + mapper(Sub, sub, inherits=Base, polymorphic_identity='s') + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + sess = create_session() + p1 = Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]) + p2 = Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + sess.add(p1) + sess.add(p2) + sess.flush() + + @testing.resolve_artifact_names + def test_eagerload(self): + sess = create_session() + def go(): + eq_( + sess.query(Subparent).options(eagerload(Subparent.children)).all(), + [ + Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), + Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + + sess.expunge_all() + def go(): + eq_( + sess.query(Subparent).options(eagerload("children")).all(), + [ + Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), + Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + + @testing.resolve_artifact_names + def test_contains_eager(self): + sess = create_session() + def go(): + eq_( + sess.query(Subparent).join(Subparent.children).options(contains_eager(Subparent.children)).all(), + [ + Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), + Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + sess.expunge_all() + + def go(): + eq_( + sess.query(Subparent).join(Subparent.children).options(contains_eager("children")).all(), + [ + Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]), + Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')]) + ] + ) + self.assert_sql_count(testing.db, go, 1) + + -- 2.47.2