]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merged r6357 of rel_0_5 branch
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Sep 2009 20:38:29 +0000 (20:38 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Sep 2009 20:38:29 +0000 (20:38 +0000)
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/strategies.py
test/orm/inheritance/test_basic.py
test/orm/inheritance/test_query.py

index eaafe5761a43d8baed85e86f2c433ca583e36fb0..dace1978e48ec14df226a0cf70f9ca958b8af243 100644 (file)
@@ -668,11 +668,11 @@ class PropertyOption(MapperOption):
         self._process(query, False)
 
     def _process(self, query, raiseerr):
-        paths = self.__get_paths(query, raiseerr)
+        paths, mappers = self.__get_paths(query, raiseerr)
         if paths:
-            self.process_query_property(query, paths)
+            self.process_query_property(query, paths, mappers)
 
-    def process_query_property(self, query, paths):
+    def process_query_property(self, query, paths, mappers):
         pass
 
     def __find_entity(self, query, mapper, raiseerr):
@@ -718,7 +718,8 @@ class PropertyOption(MapperOption):
         path = None
         entity = None
         l = []
-
+        mappers = []
+        
         # _current_path implies we're in a secondary load
         # with an existing path
         current_path = list(query._current_path)
@@ -739,6 +740,7 @@ class PropertyOption(MapperOption):
                         entity = query._entity_zero()
                         path_element = entity.path_entity
                         mapper = entity.mapper
+                    mappers.append(mapper)
                     prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
                     key = token
                 elif isinstance(token, PropComparator):
@@ -746,8 +748,9 @@ class PropertyOption(MapperOption):
                     if not entity:
                         entity = self.__find_entity(query, token.parententity, raiseerr)
                         if not entity:
-                            return []
+                            return [], []
                         path_element = entity.path_entity
+                    mappers.append(prop.parent)
                     key = prop.key
                 else:
                     raise sa_exc.ArgumentError("mapper option expects string key or list of attributes")
@@ -757,7 +760,7 @@ class PropertyOption(MapperOption):
                     continue
                     
                 if prop is None:
-                    return []
+                    return [], []
 
                 path = build_path(path_element, prop.key, path)
                 l.append(path)
@@ -765,15 +768,17 @@ class PropertyOption(MapperOption):
                     path_element = mapper = token._of_type
                 else:
                     path_element = mapper = getattr(prop, 'mapper', None)
+
                 if path_element:
                     path_element = path_element.base_mapper
-        
+                    
+                
         # if current_path tokens remain, then
         # we didn't have an exact path match.
         if current_path:
-            return []
+            return [], []
             
-        return l
+        return l, mappers
 
 class AttributeExtension(object):
     """An event handler for individual attribute change events.
@@ -823,7 +828,7 @@ class StrategizedOption(PropertyOption):
     def is_chained(self):
         return False
 
-    def process_query_property(self, query, paths):
+    def process_query_property(self, query, paths, mappers):
         if self.is_chained():
             for path in paths:
                 query._attributes[("loaderstrategy", path)] = self.get_strategy_class()
index ed742a2bfab795b6e833510ef9cf5169078c083d..4ab1a49486949ee7eba43f14bbad2202844ca450 100644 (file)
@@ -849,16 +849,17 @@ class LoadEagerFromAliasOption(PropertyOption):
         # dont run this option on a secondary load
         pass
         
-    def process_query_property(self, query, paths):
+    def process_query_property(self, query, paths, mappers):
         if self.alias:
             if isinstance(self.alias, basestring):
-                (mapper, propname) = paths[-1][-2:]
-
+                mapper = mappers[-1]
+                (root_mapper, propname) = paths[-1][-2:]
                 prop = mapper.get_property(propname, resolve_synonyms=True)
                 self.alias = prop.target.alias(self.alias)
             query._attributes[("user_defined_eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias)
         else:
-            (mapper, propname) = paths[-1][-2:]
+            (root_mapper, propname) = paths[-1][-2:]
+            mapper = mappers[-1]
             prop = mapper.get_property(propname, resolve_synonyms=True)
             adapter = query._polymorphic_adapters.get(prop.mapper, None)
             query._attributes[("user_defined_eager_row_processor", paths[-1])] = adapter
index 5ed6d1735f557f3394444f8777a5ea09d7806019..4f329a91df1b0312a540d943e06195c6e3ebce3f 100644 (file)
@@ -447,89 +447,6 @@ class EagerTargetingTest(_base.MappedTest):
         eq_(node, B(id=1, name='b1',b_data='i'))
         eq_(node.children[0], B(id=2, name='b2',b_data='l'))
         
-class EagerToSubclassTest(_base.MappedTest):
-    """Test eagerloads to subclass mappers"""
-    
-    run_setup_classes = 'once'
-    run_setup_mappers = 'once'
-    run_inserts = 'once'
-    run_deletes = None
-    
-    @classmethod
-    def define_tables(cls, metadata):
-        Table('parent', metadata,
-            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
-            Column('data', String(10)),
-        )
-        
-        Table('base', metadata,
-            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
-            Column('type', String(10)),
-        )
-
-        Table('sub', metadata,
-            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
-            Column('data', String(10)),
-            Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False)
-        )
-    
-    @classmethod
-    @testing.resolve_artifact_names
-    def setup_classes(cls):
-        class Parent(_base.ComparableEntity):
-            pass
-        
-        class Base(_base.ComparableEntity):
-            pass
-        
-        class Sub(Base):
-            pass
-    
-    @classmethod
-    @testing.resolve_artifact_names
-    def setup_mappers(cls):
-        mapper(Parent, parent, properties={
-            'children':relation(Sub)
-        })
-        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b')
-        mapper(Sub, sub, inherits=Base, polymorphic_identity='s')
-    
-    @classmethod
-    @testing.resolve_artifact_names
-    def insert_data(cls):
-        sess = create_session()
-        p1 = Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')])
-        p2 = Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
-        sess.add(p1)
-        sess.add(p2)
-        sess.flush()
-        
-    @testing.resolve_artifact_names
-    def test_eagerload(self):
-        sess = create_session()
-        def go():
-            eq_(
-                sess.query(Parent).options(eagerload(Parent.children)).all(), 
-                [
-                    Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
-                    Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
-                ]
-            )
-        self.assert_sql_count(testing.db, go, 1)
-
-    @testing.resolve_artifact_names
-    def test_contains_eager(self):
-        sess = create_session()
-        def go():
-            eq_(
-                sess.query(Parent).join(Parent.children).options(contains_eager(Parent.children)).all(), 
-                [
-                    Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
-                    Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
-                ]
-            )
-        self.assert_sql_count(testing.db, go, 1)
-        
 class FlushTest(_base.MappedTest):
     """test dependency sorting among inheriting mappers"""
     @classmethod
index 243ed4a7ba860eb349a039c58097900249c73b9d..c74ddcad6f48d1fe952fb96924bb8ef112f3c55f 100644 (file)
@@ -1115,3 +1115,201 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL):
         
         assert q.first() is c1
 
+class EagerToSubclassTest(_base.MappedTest):
+    """Test eagerloads to subclass mappers"""
+
+    run_setup_classes = 'once'
+    run_setup_mappers = 'once'
+    run_inserts = 'once'
+    run_deletes = None
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('parent', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('data', String(10)),
+        )
+
+        Table('base', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('type', String(10)),
+        )
+
+        Table('sub', metadata,
+            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+            Column('data', String(10)),
+            Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False)
+        )
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_classes(cls):
+        class Parent(_base.ComparableEntity):
+            pass
+
+        class Base(_base.ComparableEntity):
+            pass
+
+        class Sub(Base):
+            pass
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Parent, parent, properties={
+            'children':relation(Sub)
+        })
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='s')
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def insert_data(cls):
+        sess = create_session()
+        p1 = Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')])
+        p2 = Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+        sess.add(p1)
+        sess.add(p2)
+        sess.flush()
+
+    @testing.resolve_artifact_names
+    def test_eagerload(self):
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(Parent).options(eagerload(Parent.children)).all(), 
+                [
+                    Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
+                    Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+    @testing.resolve_artifact_names
+    def test_contains_eager(self):
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(Parent).join(Parent.children).options(contains_eager(Parent.children)).all(), 
+                [
+                    Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
+                    Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+class SubClassEagerToSubclassTest(_base.MappedTest):
+    """Test eagerloads from subclass to subclass mappers"""
+
+    run_setup_classes = 'once'
+    run_setup_mappers = 'once'
+    run_inserts = 'once'
+    run_deletes = None
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('parent', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('type', String(10)),
+        )
+
+        Table('subparent', metadata,
+            Column('id', Integer, ForeignKey('parent.id'), primary_key=True),
+            Column('data', String(10)),
+        )
+
+        Table('base', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('type', String(10)),
+        )
+
+        Table('sub', metadata,
+            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+            Column('data', String(10)),
+            Column('subparent_id', Integer, ForeignKey('subparent.id'), nullable=False)
+        )
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_classes(cls):
+        class Parent(_base.ComparableEntity):
+            pass
+
+        class Subparent(Parent):
+            pass
+
+        class Base(_base.ComparableEntity):
+            pass
+
+        class Sub(Base):
+            pass
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Parent, parent, polymorphic_on=parent.c.type, polymorphic_identity='b')
+        mapper(Subparent, subparent, inherits=Parent, polymorphic_identity='s', properties={
+            'children':relation(Sub)
+        })
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='s')
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def insert_data(cls):
+        sess = create_session()
+        p1 = Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')])
+        p2 = Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+        sess.add(p1)
+        sess.add(p2)
+        sess.flush()
+
+    @testing.resolve_artifact_names
+    def test_eagerload(self):
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(Subparent).options(eagerload(Subparent.children)).all(), 
+                [
+                    Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
+                    Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+        sess.expunge_all()
+        def go():
+            eq_(
+                sess.query(Subparent).options(eagerload("children")).all(), 
+                [
+                    Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
+                    Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+    @testing.resolve_artifact_names
+    def test_contains_eager(self):
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(Subparent).join(Subparent.children).options(contains_eager(Subparent.children)).all(), 
+                [
+                    Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
+                    Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+        sess.expunge_all()
+
+        def go():
+            eq_(
+                sess.query(Subparent).join(Subparent.children).options(contains_eager("children")).all(), 
+                [
+                    Subparent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
+                    Subparent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+