]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added an assertion that prevents a @validates function
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 12 Sep 2009 19:59:39 +0000 (19:59 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 12 Sep 2009 19:59:39 +0000 (19:59 +0000)
or other AttributeExtension from loading an unloaded
collection such that internal state may be corrupted.
[ticket:1526]

CHANGES
lib/sqlalchemy/orm/attributes.py
test/orm/test_attributes.py

diff --git a/CHANGES b/CHANGES
index 68dd7cbe02e86936d4db7591afddff25340d538a..d3a5fb2e17a73787b8ec20a9837eb258e4093f15 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -390,7 +390,12 @@ CHANGES
     - Fixed bug which disallowed one side of a many-to-many 
       bidirectional reference to declare itself as "viewonly"
       [ticket:1507]
-
+    
+    - Added an assertion that prevents a @validates function
+      or other AttributeExtension from loading an unloaded
+      collection such that internal state may be corrupted.
+      [ticket:1526]
+      
     - Fixed bug which prevented two entities from mutually
       replacing each other's primary key values within a single
       flush() for some orderings of operations.  [ticket:1519]
index 3ca7d831198c93ffa03eead901dd3dc0cb01075d..6fa8d54c4e997241a9b60119ceda70e911c9417c 100644 (file)
@@ -700,6 +700,7 @@ class CollectionAttributeImpl(AttributeImpl):
         collection = self.get_collection(state, dict_, passive=passive)
         if collection is PASSIVE_NO_RESULT:
             value = self.fire_append_event(state, dict_, value, initiator)
+            assert self.key not in dict_, "Collection was loaded during event handling."
             state.get_pending(self.key).append(value)
         else:
             collection.append_with_event(value, initiator)
@@ -711,6 +712,7 @@ class CollectionAttributeImpl(AttributeImpl):
         collection = self.get_collection(state, state.dict, passive=passive)
         if collection is PASSIVE_NO_RESULT:
             self.fire_remove_event(state, dict_, value, initiator)
+            assert self.key not in dict_, "Collection was loaded during event handling."
             state.get_pending(self.key).remove(value)
         else:
             collection.remove_with_event(value, initiator)
index b481b06791e99f7674e640255f24bc1757902c0f..99f0a49d39c470809554d7c3ce44a241f9936077 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy.orm.collections import collection
 from sqlalchemy.orm.interfaces import AttributeExtension
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.test import *
-from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.testing import eq_, assert_raises
 from test.orm import _base
 from sqlalchemy.test.util import gc_collect
 from sqlalchemy.util import cmp, jython
@@ -220,7 +220,48 @@ class AttributesTest(_base.ORMTest):
         a.email_address = 'foo@bar.com'
         u.addresses.append(a)
         self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
+    
+    def test_extension_lazyload_assertion(self):
+        class Foo(_base.BasicEntity):
+            pass
+        class Bar(_base.BasicEntity):
+            pass
+
+        class ReceiveEvents(AttributeExtension):
+            def append(self, state, child, initiator):
+                state.obj().bars
+                return child
+
+            def remove(self, state, child, initiator):
+                state.obj().bars
+                return child
+
+            def set(self, state, child, oldchild, initiator):
+                return child
 
+        attributes.register_class(Foo)
+        attributes.register_class(Bar)
+
+        bar1, bar2, bar3 = [Bar(id=1), Bar(id=2), Bar(id=3)]
+        def func1(**kw):
+            if kw.get('passive') is attributes.PASSIVE_NO_FETCH:
+                return attributes.PASSIVE_NO_RESULT
+            
+            return [bar1, bar2, bar3]
+
+        attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lambda o:func1, useobject=True, extension=[ReceiveEvents()])
+        attributes.register_attribute(Bar, 'foos', uselist=True, useobject=True, extension=[attributes.GenericBackrefExtension('bars')])
+
+        x = Foo()
+        assert_raises(AssertionError, Bar(id=4).foos.append, x)
+        
+        x.bars
+        b = Bar(id=4)
+        b.foos.append(x)
+        attributes.instance_state(x).expire_attributes(['bars'])
+        assert_raises(AssertionError, b.foos.remove, x)
+        
+        
     def test_scalar_listener(self):
         # listeners on ScalarAttributeImpl and MutableScalarAttributeImpl aren't used normally.
         # test that they work for the benefit of user extensions