]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Changes made to new, dirty and deleted
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 Sep 2008 15:02:13 +0000 (15:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 Sep 2008 15:02:13 +0000 (15:02 +0000)
collections in
SessionExtension.before_flush() will take
effect for that flush.

CHANGES
lib/sqlalchemy/orm/unitofwork.py
test/orm/session.py

diff --git a/CHANGES b/CHANGES
index c796f534654e8611abe50e683c0d45ca7eb3dae8..5ba1b1eb899a120af35c9e4e7a67e40818828da7 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -8,6 +8,11 @@ CHANGES
       with "A=B" versus "B=A" leading to errors
       [ticket:1039]
 
+    - Changes made to new, dirty and deleted 
+      collections in
+      SessionExtension.before_flush() will take
+      effect for that flush.
+
 - mysql
     - Added MSMediumInteger type [ticket:1146].
 
index 66b68770d619ab87ce0d08ded2aa363c65d45657..e48c9e4cba91246cf8490fa7dbce59ac910227ff 100644 (file)
@@ -183,15 +183,19 @@ class UnitOfWork(object):
         if not dirty and not self.deleted and not self.new:
             return
         
-        deleted = util.Set(self.deleted)
-        new = util.Set(self.new)
-        
-        dirty = util.Set(dirty).difference(deleted)
-        
         flush_context = UOWTransaction(self, session)
 
         if session.extension is not None:
             session.extension.before_flush(session, flush_context, objects)
+            dirty = [x for x in self.identity_map.all_states()
+                if x.modified
+                or (x.class_._class_state.has_mutable_scalars and x.is_modified())
+            ]
+
+        deleted = util.Set(self.deleted)
+        new = util.Set(self.new)
+        
+        dirty = util.Set(dirty).difference(deleted)
 
         # create the set of all objects we want to operate upon
         if objects:
index 1a73d9b13cc6dd5a6380bc2e48241b3fc5b7b00f..61b81671a598db63aca7f006d4ea8de801a0adc1 100644 (file)
@@ -927,6 +927,36 @@ class SessionTest(TestBase, AssertsExecutionResults):
         conn = sess.connection()
         assert log == ['after_begin']
 
+    def test_before_flush_affects_dirty(self):
+        class User(fixtures.Base):
+            pass
+        mapper(User, users)
+
+        class MyExt(SessionExtension):
+            def before_flush(self, session, flush_context, objects):
+                for obj in list(session.identity_map.values()):
+                    obj.user_name += " modified"
+
+        sess = create_session(extension = MyExt(), autoflush=True)
+        u = User(user_name='u1')
+        sess.add(u)
+        sess.flush()
+        self.assertEquals(sess.query(User).order_by(User.user_name).all(),
+            [
+                User(user_name='u1')
+            ]
+        )
+
+        sess.add(User(user_name='u2'))
+        sess.flush()
+        sess.clear()
+        self.assertEquals(sess.query(User).order_by(User.user_name).all(),
+            [
+                User(user_name='u1 modified'),
+                User(user_name='u2')
+            ]
+        )
+
     def test_pickled_update(self):
         mapper(User, users)
         sess1 = create_session()