From: Mike Bayer Date: Fri, 17 Aug 2007 19:13:51 +0000 (+0000) Subject: - threadlocal TLConnection, when closes for real, forces parent TLSession X-Git-Tag: rel_0_4beta4~51 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=27cf3a232dc334f905ebb4ab605c419a0d0b7219;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - threadlocal TLConnection, when closes for real, forces parent TLSession to rollback/dispose of transaction --- diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 62f402be5c..982be3f052 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -28,6 +28,11 @@ class TLSession(object): pass self.__tcount = 0 + def _conn_closed(self): + if self.__tcount == 1: + self.__trans._trans.rollback() + self.reset() + def in_transaction(self): return self.__tcount > 0 @@ -104,6 +109,7 @@ class TLConnection(base.Connection): def close(self): if self.__opencount == 1: base.Connection.close(self) + self.__session._conn_closed() self.__opencount -= 1 def _force_close(self): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 6263a2e525..5b39be7d0c 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -241,7 +241,6 @@ class SessionTransaction(object): return for t in util.Set(self.__connections.values()): if t[2]: - # closing the connection will also issue a rollback() t[0].close() self.session.transaction = None diff --git a/test/engine/transaction.py b/test/engine/transaction.py index bd912a4df0..7adefbac2b 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -311,6 +311,19 @@ class TLTransactionTest(PersistTest): def tearDownAll(self): users.drop(tlengine) tlengine.dispose() + + def testclose(self): + """test that when connections are closed for real, transactions are rolled back and disposed.""" + + c = tlengine.contextual_connect() + c.begin() + assert tlengine.session.in_transaction() + assert hasattr(tlengine.session, '_TLSession__transaction') + assert hasattr(tlengine.session, '_TLSession__trans') + c.close() + assert not tlengine.session.in_transaction() + assert not hasattr(tlengine.session, '_TLSession__transaction') + assert not hasattr(tlengine.session, '_TLSession__trans') def testrollback(self): """test a basic rollback""" @@ -343,6 +356,8 @@ class TLTransactionTest(PersistTest): external_connection.close() def testcommits(self): + assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0 + connection = tlengine.contextual_connect() transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name='user1') @@ -355,7 +370,8 @@ class TLTransactionTest(PersistTest): transaction = connection.begin() result = connection.execute("select * from query_users") - assert len(result.fetchall()) == 3 + l = result.fetchall() + assert len(l) == 3, "expected 3 got %d" % len(l) transaction.commit() def testrollback_off_conn(self):