From 53faada96c484b11b7a4632dc061dbce3661dbbe Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 14 Jun 2006 15:50:40 +0000 Subject: [PATCH] fixed nested rollbacks --- lib/sqlalchemy/engine/base.py | 2 +- test/engine/transaction.py | 36 ++++++++++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index dafef729ae..d9e3f4ed83 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -361,7 +361,7 @@ class Transaction(object): is_active = property(lambda s:s.__is_active) def rollback(self): if not self.__parent.__is_active: - raise exceptions.InvalidRequestError("This transaction is inactive") + return if self.__parent is self: self.__connection._rollback_impl() self.__is_active = False diff --git a/test/engine/transaction.py b/test/engine/transaction.py index a06701ff3d..b8f8af96ed 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -35,6 +35,36 @@ class TransactionTest(testbase.PersistTest): assert len(result.fetchall()) == 0 connection.close() + @testbase.unsupported('mysql') + def testnestedrollback(self): + connection = testbase.db.connect() + + try: + transaction = connection.begin() + try: + connection.execute(users.insert(), user_id=1, user_name='user1') + connection.execute(users.insert(), user_id=2, user_name='user2') + connection.execute(users.insert(), user_id=3, user_name='user3') + trans2 = connection.begin() + try: + connection.execute(users.insert(), user_id=4, user_name='user4') + connection.execute(users.insert(), user_id=5, user_name='user5') + raise Exception("uh oh") + trans2.commit() + except: + trans2.rollback() + raise + transaction.rollback() + except Exception, e: + transaction.rollback() + raise + except Exception, e: + try: + assert str(e) == 'uh oh' # and not "This transaction is inactive" + finally: + connection.close() + + @testbase.unsupported('mysql') def testnesting(self): connection = testbase.db.connect() @@ -206,19 +236,15 @@ class TLTransactionTest(testbase.PersistTest): mapper(User, users) sess = create_session(bind_to=tlengine) - print "STEP1" tlengine.begin() - print "STEP2" u = User() sess.save(u) - print "STEP3" sess.flush() - print "STEP4" tlengine.commit() - print "STEP5" finally: clear_mappers() + def testconnections(self): """tests that contextual_connect is threadlocal""" c1 = tlengine.contextual_connect() -- 2.47.2