]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added "after_begin()" hook to Session
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 May 2008 00:47:36 +0000 (00:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 May 2008 00:47:36 +0000 (00:47 +0000)
- Session.rollback() will rollback on a prepared session

lib/sqlalchemy/orm/session.py
test/orm/session.py

index 076727486e026c1e3f488d69d2aac4fab6576c94..6c27f082ee363ccc4911fd1fc598a959811602b1 100644 (file)
@@ -127,6 +127,13 @@ class SessionExtension(object):
         state.  An actual commit() may or may not have occured, depending on whether or not
         the flush started its own transaction or participated in a larger transaction.
         """
+    
+    def after_begin(self, session, transaction, connection):
+        """Execute after a transaction is begun on a connection
+        
+        `transaction` is the SessionTransaction. This method is called after an
+        engine level transaction is begun on a connection.
+        """
 
 class SessionTransaction(object):
     """Represents a Session-level Transaction.
@@ -214,6 +221,8 @@ class SessionTransaction(object):
             transaction = conn.begin()
         
         self._connections[conn] = self._connections[conn.engine] = (conn, transaction, conn is not bind)
+        if self.session.extension is not None:
+            self.session.extension.after_begin(self.session, self, conn)
         return conn
 
     def prepare(self):
@@ -266,7 +275,7 @@ class SessionTransaction(object):
             for subtransaction in self.session.transaction._iterate_parents(upto=self):
                 subtransaction.close()
         
-        if self.is_active:
+        if self.is_active or self._prepared:
             for transaction in self._iterate_parents():
                 if transaction._parent is None or transaction.nested:
                     transaction._rollback_impl()
@@ -274,6 +283,7 @@ class SessionTransaction(object):
                     break
                 else:
                     transaction._deactivate()
+
         self.close()
         return self._parent
     
index c429add40bdc63da54ebeed54754912d7d77f750..49932f8d9d78718012a47b1feeb8e09f9631530f 100644 (file)
@@ -881,19 +881,20 @@ class SessionTest(TestBase, AssertsExecutionResults):
                 log.append('after_flush')
             def after_flush_postexec(self, session, flush_context):
                 log.append('after_flush_postexec')
+            def after_begin(self, session, transaction, connection):
+                log.append('after_begin')
         sess = create_session(extension = MyExt())
         u = User()
         sess.save(u)
         sess.flush()
-
-        assert log == ['before_flush', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
+        assert log == ['before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
 
         log = []
         sess = create_session(transactional=True, extension=MyExt())
         u = User()
         sess.save(u)
         sess.flush()
-        assert log == ['before_flush', 'after_flush', 'after_flush_postexec']
+        assert log == ['before_flush', 'after_begin', 'after_flush', 'after_flush_postexec']
 
         log = []
         u.user_name = 'ed'
@@ -903,6 +904,11 @@ class SessionTest(TestBase, AssertsExecutionResults):
         log = []
         sess.commit()
         assert log == ['before_commit', 'after_commit']
+        
+        log = []
+        sess = create_session(transactional=True, extension=MyExt(), bind=testing.db)
+        conn = sess.connection()
+        assert log == ['after_begin']
 
     def test_pickled_update(self):
         mapper(User, users)