From: Mike Bayer Date: Sat, 30 Aug 2008 18:30:53 +0000 (+0000) Subject: recheck the dirty list if extensions are present X-Git-Tag: rel_0_5rc1~29 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=91d8a876040f3d29493aef1b19f215f417848942;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git recheck the dirty list if extensions are present --- diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 7195d5b1b0..f2d4e150a8 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1370,11 +1370,13 @@ class Session(object): self.identity_map.modified = False return - flush_context = UOWTransaction(self) - - for ext in self.extensions: - ext.before_flush(self, flush_context, objects) + flush_context = UOWTransaction(self) + if self.extensions: + for ext in self.extensions: + ext.before_flush(self, flush_context, objects) + dirty = self._dirty_states + deleted = set(self._deleted) new = set(self._new) diff --git a/test/orm/session.py b/test/orm/session.py index f06b20515b..c2e6e9d15c 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -929,7 +929,7 @@ class SessionTest(_fixtures.FixtureTest): if isinstance(obj, User): x = session.query(User).filter(User.name=='another %s' % obj.name).one() session.delete(x) - + sess = create_session(extension = MyExt(), autoflush=True) u = User(name='u1') sess.add(u) @@ -967,6 +967,35 @@ class SessionTest(_fixtures.FixtureTest): ] ) + @testing.resolve_artifact_names + def test_before_flush_affects_dirty(self): + mapper(User, users) + + class MyExt(sa.orm.session.SessionExtension): + def before_flush(self, session, flush_context, objects): + for obj in list(session.identity_map.values()): + obj.name += " modified" + + sess = create_session(extension = MyExt(), autoflush=True) + u = User(name='u1') + sess.add(u) + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='u1') + ] + ) + + sess.add(User(name='u2')) + sess.flush() + sess.clear() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='u1 modified'), + User(name='u2') + ] + ) + @testing.resolve_artifact_names def test_reentrant_flush(self):