From: Mike Bayer Date: Sun, 17 Aug 2008 22:21:23 +0000 (+0000) Subject: - The before_flush() hook on SessionExtension takes place X-Git-Tag: rel_0_5rc1~55 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1b65c7eed5166d07cc145063f828beeb1d14cf02;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - The before_flush() hook on SessionExtension takes place before the list of new/dirty/deleted is calculated for the final time, allowing routines within before_flush() to further change the state of the Session before the flush proceeds. [ticket:1128] - Reentrant calls to flush() raise an error. This also serves as a rudimentary, but not foolproof, check against concurrent calls to Session.flush(). --- diff --git a/CHANGES b/CHANGES index afa8cc55a4..50a091daac 100644 --- a/CHANGES +++ b/CHANGES @@ -33,9 +33,19 @@ CHANGES - class.someprop.in_() raises NotImplementedError pending the implementation of "in_" for relation [ticket:1140] - - fixed primary key update for many-to-many collections + - Fixed primary key update for many-to-many collections where the collection had not been loaded yet [ticket:1127] + + - The before_flush() hook on SessionExtension takes place + before the list of new/dirty/deleted is calculated for the + final time, allowing routines within before_flush() to + further change the state of the Session before the flush + proceeds. [ticket:1128] + + - Reentrant calls to flush() raise an error. This also + serves as a rudimentary, but not foolproof, check against + concurrent calls to Session.flush(). - Improved the behavior of query.join() when joining to joined-table inheritance subclasses, using explicit join diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index ae356126e0..cea61f28e9 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -326,7 +326,6 @@ class Mapper(object): def dispose(self): # Disable any attribute-based compilation. self.compiled = True - manager = self.class_manager if not self.non_primary and manager.mapper is self: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 3e3b65664a..ca9be7ffb1 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -218,11 +218,10 @@ class SessionTransaction(object): """ - def __init__(self, session, parent=None, nested=False, reentrant_flush=False): + def __init__(self, session, parent=None, nested=False): self.session = session self._connections = {} self._parent = parent - self.reentrant_flush = reentrant_flush self.nested = nested self._active = True self._prepared = False @@ -258,10 +257,10 @@ class SessionTransaction(object): engine = self.session.get_bind(bindkey, **kwargs) return self._connection_for_bind(engine) - def _begin(self, reentrant_flush=False, nested=False): + def _begin(self, nested=False): self._assert_is_active() return SessionTransaction( - self.session, self, reentrant_flush=reentrant_flush, nested=nested) + self.session, self, nested=nested) def _iterate_parents(self, upto=None): if self._parent is upto: @@ -279,7 +278,7 @@ class SessionTransaction(object): self._deleted = self._parent._deleted return - if not self.reentrant_flush: + if not self.session._flushing: self.session.flush() self._new = weakref.WeakKeyDictionary() @@ -356,7 +355,7 @@ class SessionTransaction(object): for subtransaction in stx._iterate_parents(upto=self): subtransaction.commit() - if not self.reentrant_flush: + if not self.session._flushing: self.session.flush() if self._parent is None and self.session.twophase: @@ -546,6 +545,7 @@ class Session(object): self._deleted = {} # same self.bind = bind self.__binds = {} + self._flushing = False self.transaction = None self.hash_key = id(self) self.autoflush = autoflush @@ -570,7 +570,7 @@ class Session(object): self.begin() _sessions[self.hash_key] = self - def begin(self, subtransactions=False, nested=False, _reentrant_flush=False): + def begin(self, subtransactions=False, nested=False): """Begin a transaction on this Session. If this Session is already within a transaction, either a plain @@ -596,14 +596,14 @@ class Session(object): if self.transaction is not None: if subtransactions or nested: self.transaction = self.transaction._begin( - nested=nested, reentrant_flush=_reentrant_flush) + nested=nested) else: raise sa_exc.InvalidRequestError( "A transaction is already begun. Use subtransactions=True " "to allow subtransactions.") else: self.transaction = SessionTransaction( - self, nested=nested, reentrant_flush=_reentrant_flush) + self, nested=nested) return self.transaction # needed for __enter__/__exit__ hook def begin_nested(self): @@ -912,7 +912,7 @@ class Session(object): return self._query_cls(entities, self, **kwargs) def _autoflush(self): - if self.autoflush and (self.transaction is None or not self.transaction.reentrant_flush): + if self.autoflush and not self._flushing: self.flush() def _finalize_loaded(self, states): @@ -1320,7 +1320,6 @@ class Session(object): def _contains_state(self, state): return state in self._new or self.identity_map.contains_state(state) - def flush(self, objects=None): """Flush all the object changes to the database. @@ -1343,6 +1342,17 @@ class Session(object): to only these objects, rather than all pending changes. """ + + if self._flushing: + raise sa_exc.InvalidRequestError("Session is already flushing") + + try: + self._flushing = True + self._flush(objects) + finally: + self._flushing = False + + def _flush(self, objects=None): if (not self.identity_map.check_modified() and not self._deleted and not self._new): return @@ -1352,16 +1362,16 @@ class Session(object): self.identity_map.modified = False return - deleted = set(self._deleted) - new = set(self._new) - - dirty = set(dirty).difference(deleted) - flush_context = UOWTransaction(self) if self.extension is not None: self.extension.before_flush(self, flush_context, objects) + deleted = set(self._deleted) + new = set(self._new) + + dirty = set(dirty).difference(deleted) + # create the set of all objects we want to operate upon if objects: # specific list passed in @@ -1404,7 +1414,7 @@ class Session(object): return flush_context.transaction = transaction = self.begin( - subtransactions=True, _reentrant_flush=True) + subtransactions=True) try: flush_context.execute() diff --git a/test/orm/session.py b/test/orm/session.py index 09b9df05d8..0282d28fd5 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -910,6 +910,72 @@ class SessionTest(_fixtures.FixtureTest): conn = sess.connection() assert log == ['after_begin'] + @testing.resolve_artifact_names + def test_before_flush(self): + """test that the flush plan can be affected during before_flush()""" + + mapper(User, users) + + class MyExt(sa.orm.session.SessionExtension): + def before_flush(self, session, flush_context, objects): + for obj in list(session.new) + list(session.dirty): + if isinstance(obj, User): + session.add(User(name='another %s' % obj.name)) + for obj in list(session.deleted): + if isinstance(obj, User): + x = session.query(User).filter(User.name=='another %s' % obj.name).one() + session.delete(x) + + sess = create_session(extension = MyExt(), autoflush=True) + u = User(name='u1') + sess.add(u) + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='u1') + ] + ) + + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='u1') + ] + ) + + u.name='u2' + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + User(name='another u2'), + User(name='u2') + ] + ) + + sess.delete(u) + sess.flush() + self.assertEquals(sess.query(User).order_by(User.name).all(), + [ + User(name='another u1'), + ] + ) + + @testing.resolve_artifact_names + def test_reentrant_flush(self): + + mapper(User, users) + + class MyExt(sa.orm.session.SessionExtension): + def before_flush(s, session, flush_context, objects): + session.flush() + + sess = create_session(extension=MyExt()) + sess.add(User(name='foo')) + self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush) + @testing.resolve_artifact_names def test_pickled_update(self): mapper(User, users)