]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed recursion bug which could occur when moving
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Sep 2010 06:09:38 +0000 (02:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Sep 2010 06:09:38 +0000 (02:09 -0400)
an object from one reference to another, with
backrefs involved, where the initiating parent
was a subclass (with its own mapper) of the
previous parent.

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dynamic.py
test/orm/test_attributes.py

diff --git a/CHANGES b/CHANGES
index c91596a092de2eb7aada9b4692f1a65773e3a020..a39fac9f0b0830f3f8191e29dd6602dc7646dfc3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,12 @@ CHANGES
 0.6.5
 =====
 - orm
+  - Fixed recursion bug which could occur when moving
+    an object from one reference to another, with 
+    backrefs involved, where the initiating parent
+    was a subclass (with its own mapper) of the 
+    previous parent.
+    
   - Added an assertion during flush which ensures
     that no NULL-holding identity keys were generated
     on "newly persistent" objects.
index 33069332d96605eab5dade988809e95610d707cc..b56de5f05dc93a07a0d8a106fd6450b697e27a43 100644 (file)
@@ -455,7 +455,7 @@ class ScalarAttributeImpl(AttributeImpl):
             self, state, dict_.get(self.key, NO_VALUE))
 
     def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        if initiator is self:
+        if initiator and initiator.parent_token is self.parent_token:
             return
 
         if self.active_history:
@@ -534,7 +534,7 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
         state.mutable_dict.pop(self.key)
 
     def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        if initiator is self:
+        if initiator and initiator.parent_token is self.parent_token:
             return
 
         if self.extensions:
@@ -596,7 +596,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         setter operation.
 
         """
-        if initiator is self:
+        if initiator and initiator.parent_token is self.parent_token:
             return
 
         if self.active_history:
@@ -622,7 +622,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
                 previous is not None and
                 previous is not PASSIVE_NO_RESULT):
                 self.sethasparent(instance_state(previous), False)
-
+        
         for ext in self.extensions:
             value = ext.set(state, value, previous, initiator or self)
 
@@ -726,7 +726,7 @@ class CollectionAttributeImpl(AttributeImpl):
             self.key, state, self.collection_factory)
 
     def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        if initiator is self:
+        if initiator and initiator.parent_token is self.parent_token:
             return
 
         collection = self.get_collection(state, dict_, passive=passive)
@@ -739,7 +739,7 @@ class CollectionAttributeImpl(AttributeImpl):
             collection.append_with_event(value, initiator)
 
     def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        if initiator is self:
+        if initiator and initiator.parent_token is self.parent_token:
             return
 
         collection = self.get_collection(state, state.dict, passive=passive)
@@ -759,7 +759,7 @@ class CollectionAttributeImpl(AttributeImpl):
         setter operation.
         """
 
-        if initiator is self:
+        if initiator and initiator.parent_token is self.parent_token:
             return
 
         self._set_iterable(
index 0789d9626c9f6b1f4f773f2d88ca486871ed2233..b523295232b6dca56e1321c17dc79006c2059a8e 100644 (file)
@@ -587,7 +587,7 @@ class CollectionAdapter(object):
     def fire_append_event(self, item, initiator=None):
         """Notify that a entity has entered the collection.
 
-        Initiator is the InstrumentedAttribute that initiated the membership
+        Initiator is a token owned by the InstrumentedAttribute that initiated the membership
         mutation, and should be left as None unless you are passing along
         an initiator value from a chained operation.
 
index d558380114fd4cee7230c1c7b7ccfe0a55084974..c5ddaca40be7644f5def5f03d7477b9781470763 100644 (file)
@@ -112,7 +112,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
 
     def set(self, state, dict_, value, initiator,
                         passive=attributes.PASSIVE_OFF):
-        if initiator is self:
+        if initiator and initiator.parent_token is self.parent_token:
             return
 
         self._set_iterable(state, dict_, value)
index 3a8a320e352ad55871bd8664979dccfb1ef03ed6..742e9d87475612c4a13e75cde2ea4ba86815f3bc 100644 (file)
@@ -693,14 +693,18 @@ class UtilTest(_base.ORMTest):
 
 class BackrefTest(_base.ORMTest):
 
-    def test_manytomany(self):
+    def test_m2m(self):
         class Student(object):pass
         class Course(object):pass
 
         attributes.register_class(Student)
         attributes.register_class(Course)
-        attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True)
-        attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True)
+        attributes.register_attribute(Student, 'courses', uselist=True,
+                extension=attributes.GenericBackrefExtension('students'
+                ), useobject=True)
+        attributes.register_attribute(Course, 'students', uselist=True,
+                extension=attributes.GenericBackrefExtension('courses'
+                ), useobject=True)
 
         s = Student()
         c = Course()
@@ -717,14 +721,18 @@ class BackrefTest(_base.ORMTest):
         s1.courses.remove(c)
         self.assert_(c.students == [s2,s3])
 
-    def test_onetomany(self):
+    def test_o2m(self):
         class Post(object):pass
         class Blog(object):pass
 
         attributes.register_class(Post)
         attributes.register_class(Blog)
-        attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
-        attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
+        attributes.register_attribute(Post, 'blog', uselist=False,
+                extension=attributes.GenericBackrefExtension('posts'),
+                trackparent=True, useobject=True)
+        attributes.register_attribute(Blog, 'posts', uselist=True,
+                extension=attributes.GenericBackrefExtension('blog'),
+                trackparent=True, useobject=True)
         b = Blog()
         (p1, p2, p3) = (Post(), Post(), Post())
         b.posts.append(p1)
@@ -748,13 +756,17 @@ class BackrefTest(_base.ORMTest):
         p5.blog = None
         del p5.blog
 
-    def test_onetoone(self):
+    def test_o2o(self):
         class Port(object):pass
         class Jack(object):pass
         attributes.register_class(Port)
         attributes.register_class(Jack)
-        attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True)
-        attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True)
+        attributes.register_attribute(Port, 'jack', uselist=False,
+                extension=attributes.GenericBackrefExtension('port'),
+                useobject=True)
+        attributes.register_attribute(Jack, 'port', uselist=False,
+                extension=attributes.GenericBackrefExtension('jack'),
+                useobject=True)
         p = Port()
         j = Jack()
         p.jack = j
@@ -764,6 +776,96 @@ class BackrefTest(_base.ORMTest):
         j.port = None
         self.assert_(p.jack is None)
 
+    def test_symmetric_o2o_inheritance(self):
+        """Test that backref 'initiator' catching goes against
+        a token that is global to all InstrumentedAttribute objects
+        within a particular class, not just the indvidual IA object
+        since we use distinct objects in an inheritance scenario.
+        
+        """
+        class Parent(object):
+            pass
+        class Child(object):
+            pass
+        class SubChild(Child):
+            pass
+
+        p_token = object()
+        c_token = object()
+        
+        attributes.register_class(Parent)
+        attributes.register_class(Child)
+        attributes.register_class(SubChild)
+        attributes.register_attribute(Parent, 'child', uselist=False,
+                extension=attributes.GenericBackrefExtension('parent'),
+                parent_token = p_token,
+                useobject=True)
+        attributes.register_attribute(Child, 'parent', uselist=False,
+                extension=attributes.GenericBackrefExtension('child'),
+                parent_token = c_token,
+                useobject=True)
+        attributes.register_attribute(SubChild, 'parent',
+                uselist=False,
+                extension=attributes.GenericBackrefExtension('child'),
+                parent_token = c_token,
+                useobject=True)
+        
+        p1 = Parent()
+        c1 = Child()
+        p1.child = c1
+        
+        c2 = SubChild()
+        c2.parent = p1
+
+    def test_symmetric_o2m_inheritance(self):
+        class Parent(object):
+            pass
+        class SubParent(Parent):
+            pass
+        class Child(object):
+            pass
+
+        p_token = object()
+        c_token = object()
+        
+        attributes.register_class(Parent)
+        attributes.register_class(SubParent)
+        attributes.register_class(Child)
+        attributes.register_attribute(Parent, 'children', uselist=True,
+                extension=attributes.GenericBackrefExtension('parent'),
+                parent_token = p_token,
+                useobject=True)
+        attributes.register_attribute(SubParent, 'children', uselist=True,
+                extension=attributes.GenericBackrefExtension('parent'),
+                parent_token = p_token,
+                useobject=True)
+        attributes.register_attribute(Child, 'parent', uselist=False,
+                extension=attributes.GenericBackrefExtension('children'),
+                parent_token = c_token,
+                useobject=True)
+        
+        p1 = Parent()
+        p2 = SubParent()
+        c1 = Child()
+        
+        p1.children.append(c1)
+
+        assert c1.parent is p1
+        assert c1 in p1.children
+        
+        p2.children.append(c1)
+        assert c1.parent is p2
+        
+        # note its still in p1.children -
+        # the event model currently allows only
+        # one level deep.  without the parent_token,
+        # it keeps going until a ValueError is raised
+        # and this condition changes.
+        assert c1 in p1.children
+        
+        
+        
+        
 class PendingBackrefTest(_base.ORMTest):
     def setup(self):
         global Post, Blog, called, lazy_load