From: Mike Bayer Date: Tue, 9 Sep 2008 15:02:13 +0000 (+0000) Subject: - Changes made to new, dirty and deleted X-Git-Tag: rel_0_4_8~7 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=f17acc1caeefd5f967499d96bbd4ba36954171a3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Changes made to new, dirty and deleted collections in SessionExtension.before_flush() will take effect for that flush. --- diff --git a/CHANGES b/CHANGES index c796f53465..5ba1b1eb89 100644 --- a/CHANGES +++ b/CHANGES @@ -8,6 +8,11 @@ CHANGES with "A=B" versus "B=A" leading to errors [ticket:1039] + - Changes made to new, dirty and deleted + collections in + SessionExtension.before_flush() will take + effect for that flush. + - mysql - Added MSMediumInteger type [ticket:1146]. diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 66b68770d6..e48c9e4cba 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -183,15 +183,19 @@ class UnitOfWork(object): if not dirty and not self.deleted and not self.new: return - deleted = util.Set(self.deleted) - new = util.Set(self.new) - - dirty = util.Set(dirty).difference(deleted) - flush_context = UOWTransaction(self, session) if session.extension is not None: session.extension.before_flush(session, flush_context, objects) + dirty = [x for x in self.identity_map.all_states() + if x.modified + or (x.class_._class_state.has_mutable_scalars and x.is_modified()) + ] + + deleted = util.Set(self.deleted) + new = util.Set(self.new) + + dirty = util.Set(dirty).difference(deleted) # create the set of all objects we want to operate upon if objects: diff --git a/test/orm/session.py b/test/orm/session.py index 1a73d9b13c..61b81671a5 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -927,6 +927,36 @@ class SessionTest(TestBase, AssertsExecutionResults): conn = sess.connection() assert log == ['after_begin'] + def test_before_flush_affects_dirty(self): + class User(fixtures.Base): + pass + mapper(User, users) + + class MyExt(SessionExtension): + def before_flush(self, session, flush_context, objects): + for obj in list(session.identity_map.values()): + obj.user_name += " modified" + + sess = create_session(extension = MyExt(), autoflush=True) + u = User(user_name='u1') + sess.add(u) + sess.flush() + self.assertEquals(sess.query(User).order_by(User.user_name).all(), + [ + User(user_name='u1') + ] + ) + + sess.add(User(user_name='u2')) + sess.flush() + sess.clear() + self.assertEquals(sess.query(User).order_by(User.user_name).all(), + [ + User(user_name='u1 modified'), + User(user_name='u2') + ] + ) + def test_pickled_update(self): mapper(User, users) sess1 = create_session()