]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more fixes to transaction nesting, interacts better with close() statement
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jun 2006 18:25:35 +0000 (18:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jun 2006 18:25:35 +0000 (18:25 +0000)
CHANGES
lib/sqlalchemy/engine/threadlocal.py
test/engine/transaction.py

diff --git a/CHANGES b/CHANGES
index d22e19fc3a5d4a850959a6f60ad1ccc696fc5fde..68007a4e6545ecfbe34f26b11d52e66d8ec089d1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -23,6 +23,7 @@ instance called "default_metadata".  leaving MetaData arg to Table
 out will use the default metadata.
 - fixes to session cascade behavior, entity_name propigation
 - reorganized unittests into subdirectories
+- more fixes to threadlocal connection nesting patterns
 
 0.2.1
 - "pool" argument to create_engine() properly propigates
index e78098fa3ecad3f1c4beb43db2f1da53f580c460..040c6bd9fa1b1f7cfcc7edb3961dc0ac7d644f2b 100644 (file)
@@ -12,7 +12,7 @@ class TLSession(object):
         self.__tcount = 0
     def get_connection(self, close_with_result=False):
         try:
-            return self.__transaction
+            return self.__transaction._increment_connect()
         except AttributeError:
             return TLConnection(self, close_with_result=close_with_result)
     def set_transaction(self, tlconnection, trans):
@@ -38,13 +38,13 @@ class TLSession(object):
     def rollback(self):
         if self.__tcount > 0:
             try:
-                self.__trans.rollback()
+                self.__trans._rollback_impl()
             finally:
                 self.reset()
     def commit(self):
         if self.__tcount == 1:
             try:
-                self.__trans.commit()
+                self.__trans._commit_impl()
             finally:
                 self.reset()
         elif self.__tcount > 1:
@@ -56,25 +56,31 @@ class TLConnection(base.Connection):
     def __init__(self, session, close_with_result):
         base.Connection.__init__(self, session.engine, close_with_result=close_with_result)
         self.__session = session
+        self.__opencount = 1
     session = property(lambda s:s.__session)
+    def _increment_connect(self):
+        self.__opencount += 1
+        return self
     def _create_transaction(self, parent):
         return TLTransaction(self, parent)
     def _begin(self):
         return base.Connection.begin(self)
     def begin(self):
-        trans = base.Connection.begin(self)
-        self.__session.set_transaction(self, trans)
-        return trans
-
+        return self.session.begin()
+    def close(self):
+        if self.__opencount == 1:
+            base.Connection.close(self)
+        self.__opencount -= 1
+        
 class TLTransaction(base.Transaction):
-    def commit(self):
+    def _commit_impl(self):
         base.Transaction.commit(self)
-        if not self.is_active:
-            self.connection.session.reset()
-    def rollback(self):
+    def _rollback_impl(self):
         base.Transaction.rollback(self)
-        if not self.is_active:
-            self.connection.session.reset()
+    def commit(self):
+        self.connection.session.commit()
+    def rollback(self):
+        self.connection.session.rollback()
             
 class TLEngine(base.ComposedSQLEngine):
     """a ComposedSQLEngine that includes support for thread-local managed transactions.  This engine
index 408c9dc998c56e62476db490828c2a87a0a0fd6d..a06701ff3da5d3b08db704cd2a9feb638a539614 100644 (file)
@@ -198,13 +198,34 @@ class TLTransactionTest(testbase.PersistTest):
             self.assert_(external_connection.scalar("select count(1) from query_users") == 0)
         finally:
             external_connection.close()
-    
+
+    def testsessionnesting(self):
+        class User(object):
+            pass
+        try:
+            mapper(User, users)
+
+            sess = create_session(bind_to=tlengine)
+            print "STEP1"
+            tlengine.begin()
+            print "STEP2"
+            u = User()
+            sess.save(u)
+            print "STEP3"
+            sess.flush()
+            print "STEP4"
+            tlengine.commit()
+            print "STEP5"
+        finally:
+            clear_mappers()
+
     def testconnections(self):
         """tests that contextual_connect is threadlocal"""
         c1 = tlengine.contextual_connect()
         c2 = tlengine.contextual_connect()
         assert c1.connection is c2.connection
-        c1.close()
+        c2.close()
+        assert c1.connection.connection is not None
         
 if __name__ == "__main__":
     testbase.main()