]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix to cascades on polymorphic relations, such that cascades
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Dec 2007 17:29:08 +0000 (17:29 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Dec 2007 17:29:08 +0000 (17:29 +0000)
from an object to a polymorphic collection continue cascading
along the set of attributes specific to each element in the collection.

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
test/orm/inheritance/basic.py

diff --git a/CHANGES b/CHANGES
index 6154bc3d18e631c3760b4a21b3ce44cc4b42e0ef..5a9893368cff6933a77c71cacc188fb731670dd3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -87,6 +87,10 @@ CHANGES
    - also with dynamic, implemented correct count() behavior as well
      as other helper methods.
      
+   - fix to cascades on polymorphic relations, such that cascades
+     from an object to a polymorphic collection continue cascading 
+     along the set of attributes specific to each element in the collection.
+     
    - query.get() and query.load() do not take existing filter or other
      criterion into account; these methods *always* look up the given id
      in the database or return the current instance from the identity map, 
index 0f5dbaaf5468293f7ca7dd8901873fe24ad1e2c2..e9fe41fdc7ae11a37f119b2f4a559939021cb3fd 100644 (file)
@@ -1538,8 +1538,8 @@ def has_mapper(object):
 
     return hasattr(object, '_entity_name')
 
-def _state_mapper(state):
-    return state.class_._class_state.mappers[state.dict.get('_entity_name', None)]
+def _state_mapper(state, entity_name=None):
+    return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)]
 
 def object_mapper(object, entity_name=None, raiseerror=True):
     """Given an object, return the primary Mapper associated with the object instance.
index 9394e9aeadd5b91613e90362c04870404c1ac4b4..4d41556a07527de4174017c62818a7b35a055386 100644 (file)
@@ -13,7 +13,7 @@ to handle flush-time dependency sorting and processing.
 
 from sqlalchemy import sql, schema, util, exceptions, logging
 from sqlalchemy.sql import util as sql_util, visitors, operators, ColumnElement
-from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
+from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
@@ -365,8 +365,11 @@ class PropertyLoader(StrategizedProperty):
                     if not isinstance(c, self.mapper.class_):
                         raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
                     recursive.add(c)
-                    yield (c, mapper)
-                    for (c2, m) in mapper.cascade_iterator(type, c._state, recursive):
+
+                    # cascade using the mapper local to this object, so that its individual properties are located
+                    instance_mapper = object_mapper(c, entity_name=mapper.entity_name)  
+                    yield (c, instance_mapper)
+                    for (c2, m) in instance_mapper.cascade_iterator(type, c._state, recursive):
                         yield (c2, m)
 
     def _get_target_class(self):
index 05603ac864cb7a24c1bf5f59b6c0fca3d5ba23fe..2ef76b6d8da762c69c017b538801963d35a5b7b4 100644 (file)
@@ -9,7 +9,6 @@ class O2MTest(ORMTest):
     """deals with inheritance and one-to-many relationships"""
     def define_tables(self, metadata):
         global foo, bar, blub
-        # the 'data' columns are to appease SQLite which cant handle a blank INSERT
         foo = Table('foo', metadata,
             Column('id', Integer, Sequence('foo_seq', optional=True),
                    primary_key=True),
@@ -65,7 +64,76 @@ class O2MTest(ORMTest):
         self.assert_(compare == result)
         self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
 
+class CascadeTest(ORMTest):
+    """that cascades on polymorphic relations continue
+    cascading along the path of the instance's mapper, not
+    the base mapper."""
+    
+    def define_tables(self, metadata):
+        global t1, t2, t3, t4
+        t1= Table('t1', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30))
+            )
+            
+        t2 = Table('t2', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('t1id', Integer, ForeignKey('t1.id')),
+            Column('type', String(30)),
+            Column('data', String(30))
+        )
+        t3 = Table('t3', metadata, 
+            Column('id', Integer, ForeignKey('t2.id'), primary_key=True),
+            Column('moredata', String(30)))
+            
+        t4 = Table('t4', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('t3id', Integer, ForeignKey('t3.id')),
+            Column('data', String(30)))
+            
+    def test_cascade(self):
+        class T1(fixtures.Base):
+            pass
+        class T2(fixtures.Base):
+            pass
+        class T3(T2):
+            pass
+        class T4(fixtures.Base):
+            pass
+        
+        mapper(T1, t1, properties={
+            't2s':relation(T2, cascade="all")
+        })
+        mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') 
+        mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={
+            't4s':relation(T4, cascade="all")
+        })
+        mapper(T4, t4)
+        
+        sess = create_session()
+        t1_1 = T1(data='t1')
+
+        t3_1 = T3(data ='t3', moredata='t3')
+        t2_1 = T2(data='t2')
+
+        t1_1.t2s.append(t2_1)
+        t1_1.t2s.append(t3_1)
+        
+        t4_1 = T4(data='t4')
+        t3_1.t4s.append(t4_1)
+
+        sess.save(t1_1)
 
+        
+        assert t4_1 in sess.new
+        sess.flush()
+        
+        sess.delete(t1_1)
+        assert t4_1 in sess.deleted
+        sess.flush()
+        
+    
+    
 class GetTest(ORMTest):
     def define_tables(self, metadata):
         global foo, bar, blub