From: Mike Bayer Date: Sun, 28 May 2006 20:27:08 +0000 (+0000) Subject: more tlocal trans stuff X-Git-Tag: rel_0_2_1~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=123cb1064ce6ecd841d32df08f7969acbc154cd9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git more tlocal trans stuff --- diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 16ca5299ec..f3573711b2 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -345,6 +345,7 @@ class Transaction(object): if self.__parent is self: self.__connection._begin_impl() connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction") + is_active = property(lambda s:s.__is_active) def rollback(self): if not self.__parent.__is_active: raise exceptions.InvalidRequestError("This transaction is inactive") diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 84e5a7dc42..000a854b23 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -20,27 +20,32 @@ class TLSession(object): self.__transaction = tlconnection self.__trans = trans self.__tcount += 1 + def reset(self): + try: + del self.__transaction + del self.__trans + except AttributeError: + pass + self.__tcount = 0 + def begin(self): if self.__tcount == 0: self.__transaction = self.get_connection() - self.__trans = self.__transaction.begin() + self.__trans = self.__transaction._begin() self.__tcount += 1 + return self.__trans def rollback(self): if self.__tcount > 0: try: self.__trans.rollback() finally: - del self.__transaction - del self.__trans - self.__tcount = 0 + self.reset() def commit(self): if self.__tcount == 1: try: self.__trans.commit() finally: - del self.__transaction - del self.__trans - self.__tcount = 0 + self.reset() elif self.__tcount > 1: self.__tcount -= 1 def is_begun(self): @@ -50,8 +55,28 @@ class TLConnection(base.Connection): def __init__(self, session, close_with_result): base.Connection.__init__(self, session.engine, close_with_result=close_with_result) self.__session = session - # TODO: get begin() to communicate with the Session to maintain the same transactional state - + session = property(lambda s:s.__session) + def _create_transaction(self, parent): + return TLTransaction(self, parent) + def _begin(self): + return base.Connection.begin(self) + def begin(self): + trans = base.Connection.begin(self) + self.__session.set_transaction(self, trans) + return trans + +class TLTransaction(base.Transaction): + def commit(self): + print "TL COMMIT" + base.Transaction.commit(self) + if not self.is_active: + print "RESET" + self.connection.session.reset() + def rollback(self): + base.Transaction.rollback(self) + if not self.is_active: + self.connection.session.reset() + class TLEngine(base.ComposedSQLEngine): """a ComposedSQLEngine that includes support for thread-local managed transactions. This engine is better suited to be used with threadlocal Pool object.""" diff --git a/test/transaction.py b/test/transaction.py index 42baff46d6..32627b1182 100644 --- a/test/transaction.py +++ b/test/transaction.py @@ -143,6 +143,26 @@ class TLTransactionTest(testbase.PersistTest): self.assert_(external_connection.scalar("select count(1) from query_users") == 0) finally: external_connection.close() + + def testexplicitnesting(self): + """tests nesting of tranacstions""" + external_connection = tlengine.connect() + self.assert_(external_connection.connection is not tlengine.contextual_connect().connection) + conn = tlengine.contextual_connect() + trans = conn.begin() + tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.begin() + tlengine.execute(users.insert(), user_id=4, user_name='user4') + tlengine.execute(users.insert(), user_id=5, user_name='user5') + tlengine.commit() + trans.rollback() + conn.close() + try: + self.assert_(external_connection.scalar("select count(1) from query_users") == 0) + finally: + external_connection.close() def testconnections(self): """tests that contextual_connect is threadlocal"""