]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
moved modified_event() calls below the attribute extension fires. this basically...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Nov 2009 21:54:56 +0000 (21:54 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Nov 2009 21:54:56 +0000 (21:54 +0000)
lib/sqlalchemy/orm/attributes.py
test/orm/test_attributes.py

index dfc415d2a99759c0be4920e599d22d0c4f58958b..9fbcf3d20e36ce015994288eb0f77d1874ea103b 100644 (file)
@@ -431,10 +431,9 @@ class ScalarAttributeImpl(AttributeImpl):
         else:
             old = dict_.get(self.key, NO_VALUE)
 
-        state.modified_event(dict_, self, False, old)
-
         if self.extensions:
             self.fire_remove_event(state, dict_, old, None)
+        state.modified_event(dict_, self, False, old)
         del dict_[self.key]
 
     def get_history(self, state, dict_, passive=PASSIVE_OFF):
@@ -450,10 +449,9 @@ class ScalarAttributeImpl(AttributeImpl):
         else:
             old = dict_.get(self.key, NO_VALUE)
 
-        state.modified_event(dict_, self, False, old)
-
         if self.extensions:
             value = self.fire_replace_event(state, dict_, value, old, initiator)
+        state.modified_event(dict_, self, False, old)
         dict_[self.key] = value
 
     def fire_replace_event(self, state, dict_, value, previous, initiator):
@@ -520,14 +518,12 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
         if initiator is self:
             return
 
-        state.modified_event(dict_, self, True, NEVER_SET)
-        
         if self.extensions:
             old = self.get(state, dict_)
             value = self.fire_replace_event(state, dict_, value, old, initiator)
-            dict_[self.key] = value
-        else:
-            dict_[self.key] = value
+
+        state.modified_event(dict_, self, True, NEVER_SET)
+        dict_[self.key] = value
         state.mutable_dict[self.key] = value
 
 
@@ -584,17 +580,15 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         dict_[self.key] = value
 
     def fire_remove_event(self, state, dict_, value, initiator):
-        state.modified_event(dict_, self, False, value)
-
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), False)
 
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
-    def fire_replace_event(self, state, dict_, value, previous, initiator):
-        state.modified_event(dict_, self, False, previous)
+        state.modified_event(dict_, self, False, value)
 
+    def fire_replace_event(self, state, dict_, value, previous, initiator):
         if self.trackparent:
             if previous is not value and previous not in (None, PASSIVE_NO_RESULT):
                 self.sethasparent(instance_state(previous), False)
@@ -602,6 +596,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         for ext in self.extensions:
             value = ext.set(state, value, previous, initiator or self)
 
+        state.modified_event(dict_, self, False, previous)
+
         if self.trackparent:
             if value is not None:
                 self.sethasparent(instance_state(value), True)
@@ -649,11 +645,11 @@ class CollectionAttributeImpl(AttributeImpl):
             return History.from_attribute(self, state, current)
 
     def fire_append_event(self, state, dict_, value, initiator):
-        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
-
         for ext in self.extensions:
             value = ext.append(state, value, initiator or self)
 
+        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), True)
 
@@ -663,14 +659,14 @@ class CollectionAttributeImpl(AttributeImpl):
         state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
     def fire_remove_event(self, state, dict_, value, initiator):
-        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
-
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), False)
 
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
+        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+
     def delete(self, state, dict_):
         if self.key not in dict_:
             return
index 1aec6a02e03bf0ee3f7ccb16deb8ff1853fb2189..c312b37aa9ebef39e0e54f447bd42a1a955cfa49 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy.orm.collections import collection
 from sqlalchemy.orm.interfaces import AttributeExtension
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.test import *
-from sqlalchemy.test.testing import eq_, assert_raises
+from sqlalchemy.test.testing import eq_, ne_, assert_raises
 from test.orm import _base
 from sqlalchemy.test.util import gc_collect
 from sqlalchemy.util import cmp, jython
@@ -221,6 +221,103 @@ class AttributesTest(_base.ORMTest):
         u.addresses.append(a)
         self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
     
+    def test_extension_commit_attr(self):
+        """test that an extension which commits attribute history
+        maintains the end-result history.
+        
+        This won't work in conjunction with some unitofwork extensions.
+        
+        """
+        
+        class Foo(_base.ComparableEntity):
+            pass
+        class Bar(_base.ComparableEntity):
+            pass
+        
+        class ReceiveEvents(AttributeExtension):
+            def __init__(self, key):
+                self.key = key
+                
+            def append(self, state, child, initiator):
+                if commit:
+                    state.commit_all(state.dict)
+                return child
+
+            def remove(self, state, child, initiator):
+                if commit:
+                    state.commit_all(state.dict)
+                return child
+
+            def set(self, state, child, oldchild, initiator):
+                if commit:
+                    state.commit_all(state.dict)
+                return child
+
+        attributes.register_class(Foo)
+        attributes.register_class(Bar)
+
+        b1, b2, b3, b4 = Bar(id='b1'), Bar(id='b2'), Bar(id='b3'), Bar(id='b4')
+        
+        def loadcollection(**kw):
+            if kw.get('passive') is attributes.PASSIVE_NO_FETCH:
+                return attributes.PASSIVE_NO_RESULT
+            return [b1, b2]
+        
+        def loadscalar(**kw):
+            if kw.get('passive') is attributes.PASSIVE_NO_FETCH:
+                return attributes.PASSIVE_NO_RESULT
+            return b2
+            
+        attributes.register_attribute(Foo, 'bars', 
+                               uselist=True, 
+                               useobject=True, 
+                               callable_=lambda o:loadcollection,
+                               extension=[ReceiveEvents('bars')])
+                               
+        attributes.register_attribute(Foo, 'bar', 
+                              uselist=False, 
+                              useobject=True, 
+                              callable_=lambda o:loadscalar,
+                              extension=[ReceiveEvents('bar')])
+                              
+        attributes.register_attribute(Foo, 'scalar', 
+                            uselist=False, 
+                            useobject=False, extension=[ReceiveEvents('scalar')])
+        
+            
+        def create_hist():
+            def hist(key, shouldmatch, fn, *arg):
+                attributes.instance_state(f1).commit_all(attributes.instance_dict(f1))
+                fn(*arg)
+                histories.append((shouldmatch, attributes.get_history(f1, key)))
+
+            f1 = Foo()
+            hist('bars', True, f1.bars.append, b3)
+            hist('bars', True, f1.bars.append, b4)
+            hist('bars', False, f1.bars.remove, b2)
+            hist('bar', True, setattr, f1, 'bar', b3)
+            hist('bar', True, setattr, f1, 'bar', None)
+            hist('bar', True, setattr, f1, 'bar', b4)
+            hist('scalar', True, setattr, f1, 'scalar', 5)
+            hist('scalar', True, setattr, f1, 'scalar', None)
+            hist('scalar', True, setattr, f1, 'scalar', 4)
+        
+        histories = []
+        commit = False
+        create_hist()
+        without_commit = list(histories)
+        histories[:] = []
+        commit = True
+        create_hist()
+        with_commit = histories
+        for without, with_ in zip(without_commit, with_commit):
+            shouldmatch, woc = without
+            shouldmatch, wic = with_
+            if shouldmatch:
+                eq_(woc, wic)
+            else:
+                ne_(woc, wic)
+        
     def test_extension_lazyload_assertion(self):
         class Foo(_base.BasicEntity):
             pass