]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed an obscure issue whereby a joined-table subclass
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 Aug 2009 20:29:08 +0000 (20:29 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 Aug 2009 20:29:08 +0000 (20:29 +0000)
with a self-referential eager load on the base class
would populate the related object's "subclass" table with
data from the "subclass" table of the parent.
[ticket:1485]

CHANGES
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/test_basic.py

diff --git a/CHANGES b/CHANGES
index c81af33be465a497a4e9b474feb5f06d71b2b3d6..fad467e8acba66c265b71aa93fe3f45009c4953c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -358,6 +358,12 @@ CHANGES
       composite primary key would fail on updates.
       Continuation of [ticket:1300].
     
+    - Fixed an obscure issue whereby a joined-table subclass
+      with a self-referential eager load on the base class
+      would populate the related object's "subclass" table with
+      data from the "subclass" table of the parent.
+      [ticket:1485]
+      
     - relations() now have greater ability to be "overridden",
       meaning a subclass that explicitly specifies a relation()
       overriding that of the parent class will be honored
index a76eae0e0cc6ba7ddd23e93323948be27b9cf6c4..b3290a2d6348ce3c7588b22f31ef6e6607f4becd 100644 (file)
@@ -116,7 +116,7 @@ class ColumnLoader(LoaderStrategy):
         key, col = self.key, self.columns[0]
         if adapter:
             col = adapter.columns[col]
-        if col in row:
+        if col is not None and col in row:
             def new_execute(state, dict_, row, **flags):
                 dict_[key] = row[col]
                 
@@ -700,7 +700,7 @@ class EagerLoader(AbstractRelationLoader):
 
         # create AliasedClauses object to build up the eager query.  
         clauses = mapperutil.ORMAdapter(mapperutil.AliasedClass(self.mapper), 
-                    equivalents=self.mapper._equivalent_columns)
+                    equivalents=self.mapper._equivalent_columns, adapt_required=True)
 
         join_to_left = False
         if adapter:
index c858ca10265f8ae6657f9219dd1e3b795b3780ee..bc23d8c6d8bc14efc8e46f916c1fa5fe8520208c 100644 (file)
@@ -257,13 +257,13 @@ class ORMAdapter(sql_util.ColumnAdapter):
     and the AliasedClass if any is referenced.
 
     """
-    def __init__(self, entity, equivalents=None, chain_to=None):
+    def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False):
         self.mapper, selectable, is_aliased_class = _entity_info(entity)
         if is_aliased_class:
             self.aliased_class = entity
         else:
             self.aliased_class = None
-        sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to)
+        sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to, adapt_required=adapt_required)
 
     def replace(self, elem):
         entity = elem._annotations.get('parentmapper', None)
index 27ae3e624750ed4b7ce19b8d9935fc430ed78630..9be405e2192b47c0abb44e0c09b5eada3773b201 100644 (file)
@@ -498,11 +498,12 @@ class ColumnAdapter(ClauseAdapter):
     adapted_row() factory.
     
     """
-    def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None):
+    def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None, adapt_required=False):
         ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
         if chain_to:
             self.chain(chain_to)
         self.columns = util.populate_column_dict(self._locate_col)
+        self.adapt_required = adapt_required
 
     def wrap(self, adapter):
         ac = self.__class__.__new__(self.__class__)
@@ -530,6 +531,16 @@ class ColumnAdapter(ClauseAdapter):
             # anonymize labels in case they have a hardcoded name
             if isinstance(c, expression._Label):
                 c = c.label(None)
+                
+        # adapt_required indicates that if we got the same column
+        # back which we put in (i.e. it passed through), 
+        # it's not correct.  this is used by eagerloading which
+        # knows that all columns and expressions need to be adapted
+        # to a result row, and a "passthrough" is definitely targeting
+        # the wrong column.
+        if self.adapt_required and c is col:
+            return None
+            
         return c    
 
     def adapted_row(self, row):
index 778b08a272a07ab1808f414b3fdebc572cbee2ca..713ae3b5fecfcf8edc35f824a88e6938be87f858 100644 (file)
@@ -392,7 +392,62 @@ class EagerLazyTest(_base.MappedTest):
         self.assert_(len(q.first().lazy) == 1)
         self.assert_(len(q.first().eager) == 1)
 
+class EagerTargetingTest(_base.MappedTest):
+    """test a scenario where joined table inheritance might be confused as an eagerly loaded joined table."""
+    
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('a_table', metadata,
+           Column('id', Integer, primary_key=True),
+           Column('name', String(50)),
+           Column('type', String(30), nullable=False),
+           Column('parent_id', Integer, ForeignKey('a_table.id'))
+        )
+
+        Table('b_table', metadata,
+           Column('id', Integer, ForeignKey('a_table.id'), primary_key=True),
+           Column('b_data', String(50)),
+        )
+    
+    @testing.resolve_artifact_names
+    def test_adapt_stringency(self):
+        class A(_base.ComparableEntity):
+            pass
+        class B(A):
+            pass
+        
+        mapper(A, a_table, polymorphic_on=a_table.c.type, polymorphic_identity='A', 
+                properties={
+                    'children': relation(A, order_by=a_table.c.name)
+            })
+
+        mapper(B, b_table, inherits=A, polymorphic_identity='B', properties={
+                'b_derived':column_property(b_table.c.b_data + "DATA")
+                })
+        
+        sess=create_session()
 
+        b1=B(id=1, name='b1',b_data='i')
+        sess.add(b1)
+        sess.flush()
+
+        b2=B(id=2, name='b2', b_data='l', parent_id=1)
+        sess.add(b2)
+        sess.flush()
+
+        bid=b1.id
+
+        sess.expunge_all()
+        node = sess.query(B).filter(B.id==bid).all()[0]
+        eq_(node, B(id=1, name='b1',b_data='i'))
+        eq_(node.children[0], B(id=2, name='b2',b_data='l'))
+        
+        sess.expunge_all()
+        node = sess.query(B).options(eagerload(B.children)).filter(B.id==bid).all()[0]
+        eq_(node, B(id=1, name='b1',b_data='i'))
+        eq_(node.children[0], B(id=2, name='b2',b_data='l'))
+        
+        
 class FlushTest(_base.MappedTest):
     """test dependency sorting among inheriting mappers"""
     @classmethod