]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more tlocal trans stuff
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 May 2006 20:27:08 +0000 (20:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 May 2006 20:27:08 +0000 (20:27 +0000)
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/threadlocal.py
test/transaction.py

index 16ca5299ecd28a106cf8fba4834e56663635969c..f3573711b2f85f03657365abe4407618d1c522f7 100644 (file)
@@ -345,6 +345,7 @@ class Transaction(object):
         if self.__parent is self:
             self.__connection._begin_impl()
     connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction")
+    is_active = property(lambda s:s.__is_active)
     def rollback(self):
         if not self.__parent.__is_active:
             raise exceptions.InvalidRequestError("This transaction is inactive")
index 84e5a7dc427b5839527a1940fdf716013011bb10..000a854b23c090afee63dd167af86d9b2add6bb7 100644 (file)
@@ -20,27 +20,32 @@ class TLSession(object):
             self.__transaction = tlconnection
             self.__trans = trans
         self.__tcount += 1
+    def reset(self):
+        try:
+            del self.__transaction
+            del self.__trans
+        except AttributeError:
+            pass
+        self.__tcount = 0
+        
     def begin(self):
         if self.__tcount == 0:
             self.__transaction = self.get_connection()
-            self.__trans = self.__transaction.begin()
+            self.__trans = self.__transaction._begin()
         self.__tcount += 1
+        return self.__trans
     def rollback(self):
         if self.__tcount > 0:
             try:
                 self.__trans.rollback()
             finally:
-                del self.__transaction
-                del self.__trans
-                self.__tcount = 0
+                self.reset()
     def commit(self):
         if self.__tcount == 1:
             try:
                 self.__trans.commit()
             finally:
-                del self.__transaction
-                del self.__trans
-                self.__tcount = 0
+                self.reset()
         elif self.__tcount > 1:
             self.__tcount -= 1
     def is_begun(self):
@@ -50,8 +55,28 @@ class TLConnection(base.Connection):
     def __init__(self, session, close_with_result):
         base.Connection.__init__(self, session.engine, close_with_result=close_with_result)
         self.__session = session
-    # TODO: get begin() to communicate with the Session to maintain the same transactional state
-       
+    session = property(lambda s:s.__session)
+    def _create_transaction(self, parent):
+        return TLTransaction(self, parent)
+    def _begin(self):
+        return base.Connection.begin(self)
+    def begin(self):
+        trans = base.Connection.begin(self)
+        self.__session.set_transaction(self, trans)
+        return trans
+
+class TLTransaction(base.Transaction):
+    def commit(self):
+        print "TL COMMIT"
+        base.Transaction.commit(self)
+        if not self.is_active:
+            print "RESET"
+            self.connection.session.reset()
+    def rollback(self):
+        base.Transaction.rollback(self)
+        if not self.is_active:
+            self.connection.session.reset()
+            
 class TLEngine(base.ComposedSQLEngine):
     """a ComposedSQLEngine that includes support for thread-local managed transactions.  This engine
     is better suited to be used with threadlocal Pool object."""
index 42baff46d6f56790a6f4fd96735ede57d94e3901..32627b1182c44e707658a8b2a2b77d03161315ff 100644 (file)
@@ -143,6 +143,26 @@ class TLTransactionTest(testbase.PersistTest):
             self.assert_(external_connection.scalar("select count(1) from query_users") == 0)
         finally:
             external_connection.close()
+
+    def testexplicitnesting(self):
+        """tests nesting of tranacstions"""
+        external_connection = tlengine.connect()
+        self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
+        conn = tlengine.contextual_connect()
+        trans = conn.begin()
+        tlengine.execute(users.insert(), user_id=1, user_name='user1')
+        tlengine.execute(users.insert(), user_id=2, user_name='user2')
+        tlengine.execute(users.insert(), user_id=3, user_name='user3')
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=4, user_name='user4')
+        tlengine.execute(users.insert(), user_id=5, user_name='user5')
+        tlengine.commit()
+        trans.rollback()
+        conn.close()
+        try:
+            self.assert_(external_connection.scalar("select count(1) from query_users") == 0)
+        finally:
+            external_connection.close()
     
     def testconnections(self):
         """tests that contextual_connect is threadlocal"""