]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
TLEngine needed a partial rewrite....
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 May 2006 17:46:45 +0000 (17:46 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 May 2006 17:46:45 +0000 (17:46 +0000)
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/threadlocal.py
test/transaction.py

index 83dfad04fca3d485f9a2c62133205656137a9120..16ca5299ecd28a106cf8fba4834e56663635969c 100644 (file)
@@ -194,6 +194,8 @@ class Connection(Connectable):
             return self.__transaction
         else:
             return self._create_transaction(self.__transaction)
+    def in_transaction(self):
+        return self.__transaction is not None
     def _begin_impl(self):
         if self.__engine.echo:
             self.__engine.log("BEGIN")
@@ -210,13 +212,13 @@ class Connection(Connectable):
         """when no Transaction is present, this is called after executions to provide "autocommit" behavior."""
         # TODO: have the dialect determine if autocommit can be set on the connection directly without this 
         # extra step
-        if self.__transaction is None and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP', statement.lstrip().upper()):
+        if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP', statement.lstrip().upper()):
             self._commit_impl()
     def close(self):
         if self.__connection is not None:
             self.__connection.close()
             self.__connection = None
-    def scalar(self, object, parameters, **kwargs):
+    def scalar(self, object, parameters=None, **kwargs):
         row = self.execute(object, parameters, **kwargs).fetchone()
         if row is not None:
             return row[0]
@@ -406,6 +408,10 @@ class ComposedSQLEngine(sql.Engine, Connectable):
                 conn.close()
     
     def transaction(self, callable_, connection=None, *args, **kwargs):
+        """executes the given function within a transaction boundary.  this is a shortcut for
+        explicitly calling begin() and commit() and optionally rollback() when execptions are raised.
+        The given *args and **kwargs will be passed to the function, as well as the Connection used 
+        in the transaction."""
         if connection is None:
             conn = self.contextual_connect()
         else:
index fbd9b8bab141b47c0796518c490b1e9537d457dc..51496a67df4b70913c74895f0f24aca669bcdba5 100644 (file)
@@ -54,7 +54,7 @@ class ThreadLocalEngineStrategy(EngineStrategy):
         dialect = module.dialect(**kwargs)
 
         poolargs = {}
-        for key in (('echo', 'echo_pool'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout')):
+        for key in (('echo', 'echo_pool'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool')):
             if kwargs.has_key(key[0]):
                 poolargs[key[1]] = kwargs[key[0]]
         poolclass = getattr(module, 'poolclass', None)
index 85628c208d1e2bfa1be93fdd5ab76fae1d3696ff..610eedeaaa0fb4cab64aa3fc22d0d460e89a23cd 100644 (file)
@@ -6,39 +6,42 @@ import base, default
 will return the same connection for the same thread. also provides begin/commit methods on the engine itself
 which correspond to a thread-local transaction."""
 
-class TLTransaction(base.Transaction):
-    def rollback(self):
+class TLSession(object):
+    def __init__(self, engine):
+        self.engine = engine
+        self.__tcount = 0
+    def get_connection(self, close_with_result=False):
         try:
-            base.Transaction.rollback(self)
-        finally:
+            return self.__transaction
+        except AttributeError:
+            return base.Connection(self.engine, close_with_result=close_with_result)
+    def begin(self):
+        if self.__tcount == 0:
+            self.__transaction = self.get_connection()
+            self.__trans = self.__transaction.begin()
+        self.__tcount += 1
+    def rollback(self):
+        if self.__tcount > 0:
             try:
-                del self.connection.engine.context.transaction
-            except AttributeError:
-                pass
+                self.__trans.rollback()
+            finally:
+                del self.__transaction
+                del self.__trans
+                self.__tcount = 0
     def commit(self):
-        try:
-            base.Transaction.commit(self)
-            stack = self.connection.engine.context.transaction
-            stack.pop()
-            if len(stack) == 0:
-                del self.connection.engine.context.transaction
-        except:
+        if self.__tcount == 1:
             try:
-                del self.connection.engine.context.transaction
-            except AttributeError:
-                pass
-            raise
-            
-class TLConnection(base.Connection):
-    def _create_transaction(self, parent):
-        return TLTransaction(self, parent)
-    def begin(self):
-        t = base.Connection.begin(self)
-        if not hasattr(self.engine.context, 'transaction'):
-            self.engine.context.transaction = []
-        self.engine.context.transaction.append(t)
-        return t
-        
+                self._trans.commit()
+            finally:
+                del self.__transaction
+                del self._trans
+                self.__tcount = 0
+        elif self.__tcount > 1:
+            self.__tcount -= 1
+    def is_begun(self):
+        return self.__tcount > 0
+
+    
 class TLEngine(base.ComposedSQLEngine):
     """a ComposedSQLEngine that includes support for thread-local managed transactions.  This engine
     is better suited to be used with threadlocal Pool object."""
@@ -55,29 +58,23 @@ class TLEngine(base.ComposedSQLEngine):
         """returns a Connection that is not thread-locally scoped.  this is the equilvalent to calling
         "connect()" on a ComposedSQLEngine."""
         return base.Connection(self, self.connection_provider.unique_connection())
+
+    def _session(self):
+        if not hasattr(self.context, 'session'):
+            self.context.session = TLSession(self)
+        return self.context.session
+    session = property(_session, doc="returns the current thread's TLSession")
+
     def contextual_connect(self, **kwargs):
         """returns a TLConnection which is thread-locally scoped."""
-        return TLConnection(self, **kwargs)
+        return self.session.get_connection(**kwargs)
+        
     def begin(self):
-        return self.connect().begin()
+        return self.session.begin()
     def commit(self):
-        if hasattr(self.context, 'transaction'):
-            self.context.transaction[-1].commit()
+        self.session.commit()
     def rollback(self):
-        if hasattr(self.context, 'transaction'):
-            self.context.transaction[-1].rollback()
-    def transaction(self, func, *args, **kwargs):
-           """executes the given function within a transaction boundary.  this is a shortcut for
-           explicitly calling begin() and commit() and optionally rollback() when execptions are raised.
-           The given *args and **kwargs will be passed to the function as well, which could be handy
-           in constructing decorators."""
-           trans = self.begin()
-           try:
-               func(*args, **kwargs)
-           except:
-               trans.rollback()
-               raise
-           trans.commit()
+        self.session.rollback()
 
 class TLocalConnectionProvider(default.PoolConnectionProvider):
     def unique_connection(self):
index d76d7e0f407ff7e1c1248bf1ed557b745da5fe4a..244da164965e1bea0456dd92db76410f8f9180be 100644 (file)
@@ -5,6 +5,7 @@ import tables
 db = testbase.db
 from sqlalchemy import *
 
+
 class TransactionTest(testbase.PersistTest):
     def setUpAll(self):
         global users, metadata
@@ -34,6 +35,23 @@ class TransactionTest(testbase.PersistTest):
         assert len(result.fetchall()) == 0
         connection.close()
 
+    def testnesting(self):
+        connection = testbase.db.connect()
+        transaction = connection.begin()
+        connection.execute(users.insert(), user_id=1, user_name='user1')
+        connection.execute(users.insert(), user_id=2, user_name='user2')
+        connection.execute(users.insert(), user_id=3, user_name='user3')
+        trans2 = connection.begin()
+        connection.execute(users.insert(), user_id=4, user_name='user4')
+        connection.execute(users.insert(), user_id=5, user_name='user5')
+        trans2.commit()
+        transaction.rollback()
+        self.assert_(connection.scalar("select count(1) from query_users") == 0)
+
+        result = connection.execute("select * from query_users")
+        assert len(result.fetchall()) == 0
+        connection.close()
+        
 class AutoRollbackTest(testbase.PersistTest):
     def setUpAll(self):
         global metadata
@@ -58,6 +76,51 @@ class AutoRollbackTest(testbase.PersistTest):
         # comment out the rollback in pool/ConnectionFairy._close() to see !
         users.drop(conn2)
         conn2.close()
+
+class TLTransactionTest(testbase.PersistTest):
+    def setUpAll(self):
+        global users, metadata, tlengine
+        tlengine = create_engine(testbase.db_uri, strategy='threadlocal', echo=True)
+        metadata = MetaData()
+        users = Table('query_users', metadata,
+            Column('user_id', INT, primary_key = True),
+            Column('user_name', VARCHAR(20)),
+        )
+        users.create(tlengine)
+    def tearDown(self):
+        tlengine.execute(users.delete())
+    def tearDownAll(self):
+        users.drop(tlengine)
+        tlengine.dispose()
+        
+    @testbase.unsupported('mysql')
+    def testrollback(self):
+        """test a basic rollback"""
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=1, user_name='user1')
+        tlengine.execute(users.insert(), user_id=2, user_name='user2')
+        tlengine.execute(users.insert(), user_id=3, user_name='user3')
+        tlengine.rollback()
+
+        result = tlengine.execute("select * from query_users")
+        assert len(result.fetchall()) == 0
+
+    @testbase.unsupported('mysql', 'sqlite')
+    def testnesting(self):
+        """test a basic rollback"""
+        external_connection = tlengine.connect()
+        self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=1, user_name='user1')
+        tlengine.execute(users.insert(), user_id=2, user_name='user2')
+        tlengine.execute(users.insert(), user_id=3, user_name='user3')
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=4, user_name='user4')
+        tlengine.execute(users.insert(), user_id=5, user_name='user5')
+        tlengine.commit()
+        tlengine.rollback()
+        self.assert_(external_connection.scalar("select count(1) from query_users") == 0)
+        external_connection.close()
         
 if __name__ == "__main__":
     testbase.main()