]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
recheck the dirty list if extensions are present
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 30 Aug 2008 18:30:53 +0000 (18:30 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 30 Aug 2008 18:30:53 +0000 (18:30 +0000)
lib/sqlalchemy/orm/session.py
test/orm/session.py

index 7195d5b1b027f773f238117550f559f95d7e576b..f2d4e150a81ddfa756a9b292d730174a1bdd3ef0 100644 (file)
@@ -1370,11 +1370,13 @@ class Session(object):
             self.identity_map.modified = False
             return
 
-        flush_context = UOWTransaction(self)
-
-        for ext in self.extensions:
-            ext.before_flush(self, flush_context, objects)
+        flush_context   = UOWTransaction(self)
 
+        if self.extensions:
+            for ext in self.extensions:
+                ext.before_flush(self, flush_context, objects)
+            dirty = self._dirty_states
+            
         deleted = set(self._deleted)
         new = set(self._new)
 
index f06b20515bc99bbd8c9baf91379babb73221ca09..c2e6e9d15c8969f0988b477e749b91179076cedc 100644 (file)
@@ -929,7 +929,7 @@ class SessionTest(_fixtures.FixtureTest):
                     if isinstance(obj, User):
                         x = session.query(User).filter(User.name=='another %s' % obj.name).one()
                         session.delete(x)
-                        
+                    
         sess = create_session(extension = MyExt(), autoflush=True)
         u = User(name='u1')
         sess.add(u)
@@ -967,6 +967,35 @@ class SessionTest(_fixtures.FixtureTest):
             ]
         )
 
+    @testing.resolve_artifact_names
+    def test_before_flush_affects_dirty(self):
+        mapper(User, users)
+        
+        class MyExt(sa.orm.session.SessionExtension):
+            def before_flush(self, session, flush_context, objects):
+                for obj in list(session.identity_map.values()):
+                    obj.name += " modified"
+                    
+        sess = create_session(extension = MyExt(), autoflush=True)
+        u = User(name='u1')
+        sess.add(u)
+        sess.flush()
+        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+            [
+                User(name='u1')
+            ]
+        )
+        
+        sess.add(User(name='u2'))
+        sess.flush()
+        sess.clear()
+        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+            [
+                User(name='u1 modified'),
+                User(name='u2')
+            ]
+        )
+
     @testing.resolve_artifact_names
     def test_reentrant_flush(self):