]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The before_flush() hook on SessionExtension takes place
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Aug 2008 22:21:23 +0000 (22:21 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Aug 2008 22:21:23 +0000 (22:21 +0000)
before the list of new/dirty/deleted is calculated for the
final time, allowing routines within before_flush() to
further change the state of the Session before the flush
proceeds.   [ticket:1128]

- Reentrant calls to flush() raise an error.  This also
serves as a rudimentary, but not foolproof, check against
concurrent calls to Session.flush().

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
test/orm/session.py

diff --git a/CHANGES b/CHANGES
index afa8cc55a4573c28d20f25def1ffa14ad172d749..50a091daacd039b4eb96184081c1fe7e8aef8c0d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -33,9 +33,19 @@ CHANGES
     - class.someprop.in_() raises NotImplementedError pending
       the implementation of "in_" for relation [ticket:1140]
 
-    - fixed primary key update for many-to-many collections
+    - Fixed primary key update for many-to-many collections
       where the collection had not been loaded yet
       [ticket:1127]
+    
+    - The before_flush() hook on SessionExtension takes place
+      before the list of new/dirty/deleted is calculated for the
+      final time, allowing routines within before_flush() to
+      further change the state of the Session before the flush 
+      proceeds.   [ticket:1128]
+      
+    - Reentrant calls to flush() raise an error.  This also
+      serves as a rudimentary, but not foolproof, check against 
+      concurrent calls to Session.flush().
       
     - Improved the behavior of query.join() when joining to
       joined-table inheritance subclasses, using explicit join
index ae356126e0badd9254afbcce8577d39ed551560f..cea61f28e9f5c848b48382757d72a26f023fed9b 100644 (file)
@@ -326,7 +326,6 @@ class Mapper(object):
     def dispose(self):
         # Disable any attribute-based compilation.
         self.compiled = True
-
         manager = self.class_manager
 
         if not self.non_primary and manager.mapper is self:
index 3e3b65664aeaba4201eb044327a79d7142dc6cdd..ca9be7ffb1d3db90cc2ac8fe24e3ce3e33fbc2f3 100644 (file)
@@ -218,11 +218,10 @@ class SessionTransaction(object):
 
     """
 
-    def __init__(self, session, parent=None, nested=False, reentrant_flush=False):
+    def __init__(self, session, parent=None, nested=False):
         self.session = session
         self._connections = {}
         self._parent = parent
-        self.reentrant_flush = reentrant_flush
         self.nested = nested
         self._active = True
         self._prepared = False
@@ -258,10 +257,10 @@ class SessionTransaction(object):
         engine = self.session.get_bind(bindkey, **kwargs)
         return self._connection_for_bind(engine)
 
-    def _begin(self, reentrant_flush=False, nested=False):
+    def _begin(self, nested=False):
         self._assert_is_active()
         return SessionTransaction(
-            self.session, self, reentrant_flush=reentrant_flush, nested=nested)
+            self.session, self, nested=nested)
 
     def _iterate_parents(self, upto=None):
         if self._parent is upto:
@@ -279,7 +278,7 @@ class SessionTransaction(object):
             self._deleted = self._parent._deleted
             return
 
-        if not self.reentrant_flush:
+        if not self.session._flushing:
             self.session.flush()
 
         self._new = weakref.WeakKeyDictionary()
@@ -356,7 +355,7 @@ class SessionTransaction(object):
             for subtransaction in stx._iterate_parents(upto=self):
                 subtransaction.commit()
 
-        if not self.reentrant_flush:
+        if not self.session._flushing:
             self.session.flush()
 
         if self._parent is None and self.session.twophase:
@@ -546,6 +545,7 @@ class Session(object):
         self._deleted = {}  # same
         self.bind = bind
         self.__binds = {}
+        self._flushing = False
         self.transaction = None
         self.hash_key = id(self)
         self.autoflush = autoflush
@@ -570,7 +570,7 @@ class Session(object):
             self.begin()
         _sessions[self.hash_key] = self
 
-    def begin(self, subtransactions=False, nested=False, _reentrant_flush=False):
+    def begin(self, subtransactions=False, nested=False):
         """Begin a transaction on this Session.
 
         If this Session is already within a transaction, either a plain
@@ -596,14 +596,14 @@ class Session(object):
         if self.transaction is not None:
             if subtransactions or nested:
                 self.transaction = self.transaction._begin(
-                    nested=nested, reentrant_flush=_reentrant_flush)
+                    nested=nested)
             else:
                 raise sa_exc.InvalidRequestError(
                     "A transaction is already begun.  Use subtransactions=True "
                     "to allow subtransactions.")
         else:
             self.transaction = SessionTransaction(
-                self, nested=nested, reentrant_flush=_reentrant_flush)
+                self, nested=nested)
         return self.transaction  # needed for __enter__/__exit__ hook
 
     def begin_nested(self):
@@ -912,7 +912,7 @@ class Session(object):
         return self._query_cls(entities, self, **kwargs)
 
     def _autoflush(self):
-        if self.autoflush and (self.transaction is None or not self.transaction.reentrant_flush):
+        if self.autoflush and not self._flushing:
             self.flush()
 
     def _finalize_loaded(self, states):
@@ -1320,7 +1320,6 @@ class Session(object):
     def _contains_state(self, state):
         return state in self._new or self.identity_map.contains_state(state)
 
-
     def flush(self, objects=None):
         """Flush all the object changes to the database.
 
@@ -1343,6 +1342,17 @@ class Session(object):
           to only these objects, rather than all pending changes.
 
         """
+        
+        if self._flushing:
+            raise sa_exc.InvalidRequestError("Session is already flushing")
+            
+        try:
+            self._flushing = True
+            self._flush(objects)
+        finally:
+            self._flushing = False
+            
+    def _flush(self, objects=None):
         if (not self.identity_map.check_modified() and
             not self._deleted and not self._new):
             return
@@ -1352,16 +1362,16 @@ class Session(object):
             self.identity_map.modified = False
             return
 
-        deleted = set(self._deleted)
-        new = set(self._new)
-
-        dirty = set(dirty).difference(deleted)
-
         flush_context = UOWTransaction(self)
 
         if self.extension is not None:
             self.extension.before_flush(self, flush_context, objects)
 
+        deleted = set(self._deleted)
+        new = set(self._new)
+
+        dirty = set(dirty).difference(deleted)
+
         # create the set of all objects we want to operate upon
         if objects:
             # specific list passed in
@@ -1404,7 +1414,7 @@ class Session(object):
             return
 
         flush_context.transaction = transaction = self.begin(
-            subtransactions=True, _reentrant_flush=True)
+            subtransactions=True)
         try:
             flush_context.execute()
 
index 09b9df05d898e590acb72ef398151d50be5688b3..0282d28fd599e2e1567e844ab230eb32b8f8811b 100644 (file)
@@ -910,6 +910,72 @@ class SessionTest(_fixtures.FixtureTest):
         conn = sess.connection()
         assert log == ['after_begin']
 
+    @testing.resolve_artifact_names
+    def test_before_flush(self):
+        """test that the flush plan can be affected during before_flush()"""
+        
+        mapper(User, users)
+        
+        class MyExt(sa.orm.session.SessionExtension):
+            def before_flush(self, session, flush_context, objects):
+                for obj in list(session.new) + list(session.dirty):
+                    if isinstance(obj, User):
+                        session.add(User(name='another %s' % obj.name))
+                for obj in list(session.deleted):
+                    if isinstance(obj, User):
+                        x = session.query(User).filter(User.name=='another %s' % obj.name).one()
+                        session.delete(x)
+                        
+        sess = create_session(extension = MyExt(), autoflush=True)
+        u = User(name='u1')
+        sess.add(u)
+        sess.flush()
+        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+            [
+                User(name='another u1'),
+                User(name='u1')
+            ]
+        )
+        
+        sess.flush()
+        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+            [
+                User(name='another u1'),
+                User(name='u1')
+            ]
+        )
+
+        u.name='u2'
+        sess.flush()
+        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+            [
+                User(name='another u1'),
+                User(name='another u2'),
+                User(name='u2')
+            ]
+        )
+
+        sess.delete(u)
+        sess.flush()
+        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+            [
+                User(name='another u1'),
+            ]
+        )
+
+    @testing.resolve_artifact_names
+    def test_reentrant_flush(self):
+
+        mapper(User, users)
+
+        class MyExt(sa.orm.session.SessionExtension):
+            def before_flush(s, session, flush_context, objects):
+                session.flush()
+        
+        sess = create_session(extension=MyExt())
+        sess.add(User(name='foo'))
+        self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush)
+
     @testing.resolve_artifact_names
     def test_pickled_update(self):
         mapper(User, users)