]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed three- and multi-level select and deferred inheritance
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Sep 2007 20:17:40 +0000 (20:17 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Sep 2007 20:17:40 +0000 (20:17 +0000)
  loading (i.e. abc inheritance with no select_table), [ticket:795]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/strategies.py
test/orm/inheritance/abc_polymorphic.py [new file with mode: 0644]
test/orm/inheritance/alltests.py
test/orm/inheritance/basic.py
test/testlib/fixtures.py

diff --git a/CHANGES b/CHANGES
index 095e582b49b509739bc0a0ace1c33c6f8ee6fe6e..d39c17e9a0d39909b11c3d377e1d9a291dd1bc4f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -16,6 +16,9 @@ CHANGES
 - firebird has supports_sane_rowcount and supports_sane_multi_rowcount set 
   to False due to ticket #370 (right way).
 
+- fixed three- and multi-level select and deferred inheritance
+  loading (i.e. abc inheritance with no select_table), [ticket:795]
+
 0.4.0beta6
 ----------
 
index 9764a0ae63d1a1ad1c9c35a57315ba6d2de991d9..b2bffd6ea523cf64135e2dacd6393cedf7fd782b 100644 (file)
@@ -1442,10 +1442,7 @@ class Mapper(object):
         
         return instance
 
-    def _deferred_inheritance_condition(self, needs_tables):
-        cond = self.inherit_condition
-
-        param_names = []
+    def _deferred_inheritance_condition(self, base_mapper, needs_tables):
         def visit_binary(binary):
             leftcol = binary.left
             rightcol = binary.right
@@ -1457,8 +1454,17 @@ class Mapper(object):
             elif rightcol not in needs_tables:
                 binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True)
                 param_names.append(rightcol)
-        cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True)
-        return cond, param_names
+
+        allconds = []
+        param_names = []
+
+        visitor = mapperutil.BinaryVisitor(visit_binary)
+        for mapper in self.iterate_to_root():
+            if mapper is base_mapper:
+                break
+            allconds.append(visitor.traverse(mapper.inherit_condition, clone=True))
+        
+        return sql.and_(*allconds), param_names
 
     def translate_row(self, tomapper, row):
         """Translate the column keys of a row into a new or proxied
@@ -1532,7 +1538,7 @@ class Mapper(object):
         if hosted_mapper is None or len(needs_tables)==0 or hosted_mapper.polymorphic_fetch == 'deferred':
             return
         
-        cond, param_names = self._deferred_inheritance_condition(needs_tables)
+        cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables)
         statement = sql.select(needs_tables, cond, use_labels=True)
         def post_execute(instance, **flags):
             self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance))
index b93993af8dba3640e930c94276ab4394c71c3e32..09b51c203f2bad6e5bececca26bd5c126e8d4d3a 100644 (file)
@@ -89,7 +89,7 @@ class ColumnLoader(LoaderStrategy):
             # 'deferred' polymorphic row fetcher, put a callable on the property.
             def new_execute(instance, row, isnew, **flags):
                 if isnew:
-                    sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_inheritance_loader(instance, mapper, needs_tables))
+                    sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_inheritance_loader(instance, mapper, hosted_mapper, needs_tables))
             if self._should_log_debug:
                 self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
             return (new_execute, None, None)
@@ -99,18 +99,30 @@ class ColumnLoader(LoaderStrategy):
                 self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
             return (None, None, None)
 
-    def _get_deferred_inheritance_loader(self, instance, mapper, needs_tables):
+    def _get_deferred_inheritance_loader(self, instance, mapper, hosted_mapper, needs_tables):
+        # create a deferred column loader which will query the remaining not-yet-loaded tables in an inheritance load.
+        # the mapper for the object creates the WHERE criterion using the mapper who originally 
+        # "hosted" the query and the list of tables which are unloaded between the "hosted" mapper
+        # and this mapper.  (i.e. A->B->C, the query used mapper A.  therefore will need B's and C's tables
+        # in the query).
         def create_statement():
-            cond, param_names = mapper._deferred_inheritance_condition(needs_tables)
+            # TODO: the SELECT statement here should be cached in the selectcontext.  we are somewhat duplicating 
+            # efforts from mapper._get_poly_select_loader as well and should look
+            # for ways to simplify.
+            cond, param_names = mapper._deferred_inheritance_condition(hosted_mapper, needs_tables)
             statement = sql.select(needs_tables, cond, use_labels=True)
             params = {}
             for c in param_names:
                 params[c.name] = mapper.get_attr_by_column(instance, c)
             return (statement, params)
             
+        # install the create_statement() callable using the deferred loading strategy
         strategy = self.parent_property._get_strategy(DeferredColumnLoader)
 
+        # assemble list of all ColumnProperties which will need to be loaded
         props = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
+        
+        # set the deferred loader on the instance attribute
         return strategy.setup_loader(instance, props=props, create_statement=create_statement)
 
 
diff --git a/test/orm/inheritance/abc_polymorphic.py b/test/orm/inheritance/abc_polymorphic.py
new file mode 100644 (file)
index 0000000..da90976
--- /dev/null
@@ -0,0 +1,90 @@
+import testbase
+from sqlalchemy import *
+from sqlalchemy import exceptions, util
+from sqlalchemy.orm import *
+from testlib import *
+from testlib import fixtures
+
+class ABCTest(ORMTest):
+    def define_tables(self, metadata):
+        global a, b, c
+        a = Table('a', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('adata', String(30)),
+            Column('type', String(30)),
+            )
+        b = Table('b', metadata,
+            Column('id', Integer, ForeignKey('a.id'), primary_key=True),
+            Column('bdata', String(30)))
+        c = Table('c', metadata, 
+            Column('id', Integer, ForeignKey('b.id'), primary_key=True),
+            Column('cdata', String(30)))
+    
+    def make_test(fetchtype):
+        def test_roundtrip(self):
+            class A(fixtures.Base):pass
+            class B(A):pass
+            class C(B):pass
+        
+            if fetchtype == 'union':
+                abc = a.outerjoin(b).outerjoin(c)
+                bc = a.join(b).outerjoin(c)
+            else:
+                abc = bc = None
+                    
+            mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a', polymorphic_fetch=fetchtype)
+            mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b', polymorphic_fetch=fetchtype)
+            mapper(C, c, inherits=B, polymorphic_identity='c')
+        
+            a1 = A(adata='a1')
+            b1 = B(bdata='b1', adata='b1')
+            b2 = B(bdata='b2', adata='b2')
+            b3 = B(bdata='b3', adata='b3')
+            c1 = C(cdata='c1', bdata='c1', adata='c1')
+            c2 = C(cdata='c2', bdata='c2', adata='c2')
+            c3 = C(cdata='c2', bdata='c2', adata='c2')
+        
+            sess = create_session()
+            for x in (a1, b1, b2, b3, c1, c2, c3):
+                sess.save(x)
+            sess.flush()
+            sess.clear()
+        
+            #for obj in sess.query(A).all():
+            #    print obj
+            assert [
+                A(adata='a1'),
+                B(bdata='b1', adata='b1'),
+                B(bdata='b2', adata='b2'),
+                B(bdata='b3', adata='b3'),
+                C(cdata='c1', bdata='c1', adata='c1'),
+                C(cdata='c2', bdata='c2', adata='c2'),
+                C(cdata='c2', bdata='c2', adata='c2'),
+            ] == sess.query(A).all()
+
+            assert [
+                B(bdata='b1', adata='b1'),
+                B(bdata='b2', adata='b2'),
+                B(bdata='b3', adata='b3'),
+                C(cdata='c1', bdata='c1', adata='c1'),
+                C(cdata='c2', bdata='c2', adata='c2'),
+                C(cdata='c2', bdata='c2', adata='c2'),
+            ] == sess.query(B).all()
+
+            assert [
+                C(cdata='c1', bdata='c1', adata='c1'),
+                C(cdata='c2', bdata='c2', adata='c2'),
+                C(cdata='c2', bdata='c2', adata='c2'),
+            ] == sess.query(C).all()
+
+        test_roundtrip.__name__ = 'test_%s' % fetchtype
+        return test_roundtrip
+        
+    test_union = make_test('union')
+    test_select = make_test('select')
+    test_deferred = make_test('deferred')
+    
+        
+if __name__ == '__main__':
+    testbase.main()
+        
\ No newline at end of file
index da59dd8fb72b9c83a907a3534a3d5e124cfea922..dc93ed9b38a856b29043fa2e0c6397345de61f6a 100644 (file)
@@ -10,6 +10,7 @@ def suite():
         'orm.inheritance.polymorph',
         'orm.inheritance.polymorph2',
         'orm.inheritance.poly_linked_list',
+        'orm.inheritance.abc_polymorphic',
         'orm.inheritance.abc_inheritance',
         'orm.inheritance.productspec',
         'orm.inheritance.magazine',
index fbdb4019e1fb9cd7834f6a01f409466e6f5145d5..a033d61eacebc123f8017a3937432b9f93fb5a25 100644 (file)
@@ -63,6 +63,7 @@ class O2MTest(ORMTest):
         self.assert_(compare == result)
         self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
 
+
 class GetTest(ORMTest):    
     def define_tables(self, metadata):
         global foo, bar, blub
index 1b05b366daa25d0d36cdde27cff3de82f8671261..ada254c375afb528749e20c057d371f4e1cf090b 100644 (file)
@@ -10,11 +10,11 @@ class Base(object):
             setattr(self, k, kwargs[k])
     
     # TODO: add recursion checks to this
-    #def __repr__(self):
-    #    return "%s(%s)" % (
-    #        (self.__class__.__name__), 
-    #        ','.join(["%s=%s" % (key, repr(getattr(self, key))) for key in self.__dict__ if not key.startswith('_')])
-    #    )
+    def __repr__(self):
+        return "%s(%s)" % (
+            (self.__class__.__name__), 
+            ','.join(["%s=%s" % (key, repr(getattr(self, key))) for key in self.__dict__ if not key.startswith('_')])
+        )
     
     def __ne__(self, other):
         return not self.__eq__(other)