From 609c4f05a74c839c68d71e1fa90abb7b5fc1e897 Mon Sep 17 00:00:00 2001 From: Ants Aasma Date: Sat, 14 Jul 2007 00:36:05 +0000 Subject: [PATCH] add support for two phase commits, nested subtransactions and savepoints. refactors Transaction class into a hierarchy. --- CHANGES | 3 + lib/sqlalchemy/ansisql.py | 15 ++ lib/sqlalchemy/databases/mysql.py | 23 +++- lib/sqlalchemy/databases/postgres.py | 28 ++++ lib/sqlalchemy/engine/base.py | 199 +++++++++++++++++++++++---- lib/sqlalchemy/engine/default.py | 9 ++ lib/sqlalchemy/engine/threadlocal.py | 7 +- lib/sqlalchemy/sql.py | 15 ++ test/engine/transaction.py | 146 ++++++++++++++++++++ 9 files changed, 414 insertions(+), 31 deletions(-) diff --git a/CHANGES b/CHANGES index b1d68dad64..f975a3c20e 100644 --- a/CHANGES +++ b/CHANGES @@ -73,6 +73,9 @@ columns joined by a "group" to load as "undeferred". - sql + - added support for two phase commit, works with mysql and postgres so far. + - added a subtransaction implementation that uses savepoints. + - added support for savepoints. - DynamicMetaData has been renamed to ThreadLocalMetaData - BoundMetaData has been removed- regular MetaData is equivalent - significant architectural overhaul to SQL elements (ClauseElement). diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 188063a82d..b1e58dac22 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -793,7 +793,19 @@ class ANSICompiler(engine.Compiled): text += " WHERE " + self.get_str(delete_stmt._whereclause) self.strings[delete_stmt] = text + + def visit_savepoint(self, savepoint_stmt): + text = "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + self.strings[savepoint_stmt] = text + def visit_rollback_to_savepoint(self, savepoint_stmt): + text = "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + self.strings[savepoint_stmt] = text + + def visit_release_savepoint(self, savepoint_stmt): + text = "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident) + self.strings[savepoint_stmt] = text + def __str__(self): return self.get_str(self.statement) @@ -1080,6 +1092,9 @@ class ANSIIdentifierPreparer(object): def format_alias(self, alias): return self.__generic_obj_format(alias, alias.name) + def format_savepoint(self, savepoint): + return self.__generic_obj_format(savepoint, savepoint) + def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 27d87847e2..7e59564445 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -949,7 +949,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext): self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] def is_select(self): - return re.match(r'SELECT|SHOW|DESCRIBE', self.statement.lstrip(), re.I) is not None + return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None class MySQLDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): @@ -1038,6 +1038,27 @@ class MySQLDialect(ansisql.ANSIDialect): except: pass + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if not is_prepared: + connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)])) + connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)])) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("XA RECOVER")) + return [row['data'][0:row['gtrid_length']] for row in resultset] + def is_disconnect(self, e): return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 20088045e9..3c37ac0071 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -237,6 +237,34 @@ class PGDialect(ansisql.ANSIDialect): def schemadropper(self, *args, **kwargs): return PGSchemaDropper(self, *args, **kwargs) + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions + # Must find out a way how to make the dbapi not open a transaction. + connection.execute(sql.text("ROLLBACK")) + connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + if is_prepared: + if recover: + connection.execute(sql.text("ROLLBACK")) + connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)])) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) + return [row[0] for row in resultset] + def defaultrunner(self, connection, **kwargs): return PGDefaultRunner(connection, **kwargs) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 962c78bcfb..40c5e9bb38 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -9,7 +9,7 @@ higher-level statement-construction, connection-management, execution and result contexts.""" from sqlalchemy import exceptions, sql, schema, util, types, logging -import StringIO, sys, re +import StringIO, sys, re, random class Dialect(sql.AbstractDialect): @@ -211,6 +211,46 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() + def do_savepoint(self, connection, name): + """Create a savepoint with the given name on a SQLAlchemy connection.""" + + raise NotImplementedError() + + def do_rollback_to_savepoint(self, connection, name): + """Rollback a SQL Alchemy connection to the named savepoint.""" + + raise NotImplementedError() + + def do_release_savepoint(self, connection, name): + """Release the named savepoint on a SQL Alchemy connection.""" + + raise NotImplementedError() + + def do_begin_twophase(self, connection, xid): + """Begin a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_prepare_twophase(self, connection, xid): + """Prepare a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + """Rollback a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + """Commit a two phase transaction on the given connection.""" + + raise NotImplementedError() + + def do_recover_twophase(self, connection): + """Recover list of uncommited prepared two phase transaction identifiers on the given connection.""" + + raise NotImplementedError() + def do_executemany(self, cursor, statement, parameters): """Provide an implementation of *cursor.executemany(statement, parameters)*.""" @@ -475,6 +515,7 @@ class Connection(Connectable): self.__connection = connection or engine.raw_connection() self.__transaction = None self.__close_with_result = close_with_result + self.__savepoint_seq = 0 def _get_connection(self): try: @@ -487,9 +528,6 @@ class Connection(Connectable): connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.") should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.") - def _create_transaction(self, parent): - return Transaction(self, parent) - def connect(self): """connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly.""" return self @@ -522,12 +560,34 @@ class Connection(Connectable): self.__connection.detach() - def begin(self): + def begin(self, nested=False): if self.__transaction is None: - self.__transaction = self._create_transaction(None) - return self.__transaction + self.__transaction = RootTransaction(self) + elif nested: + self.__transaction = NestedTransaction(self, self.__transaction) else: - return self._create_transaction(self.__transaction) + return Transaction(self, self.__transaction) + return self.__transaction + + def begin_nested(self): + return self.begin(nested=True) + + def begin_twophase(self, xid=None): + if self.__transaction is not None: + raise exceptions.InvalidRequestError("Cannot start a two phase transaction when a transaction is already started.") + if xid is None: + xid = "_sa_%032x" % random.randint(0,2**128) + self.__transaction = TwoPhaseTransaction(self, xid) + return self.__transaction + + def recover_twophase(self): + return self.__engine.dialect.do_recover_twophase(self) + + def rollback_prepared(self, xid, recover=False): + self.__engine.dialect.do_rollback_twophase(self, xid, recover=recover) + + def commit_prepared(self, xid, recover=False): + self.__engine.dialect.do_commit_twophase(self, xid, recover=recover) def in_transaction(self): return self.__transaction is not None @@ -559,6 +619,54 @@ class Connection(Connectable): raise exceptions.SQLError(None, None, e) self.__transaction = None + def _savepoint_impl(self, name=None): + if name is None: + self.__savepoint_seq += 1 + name = '__sa_savepoint_%s' % self.__savepoint_seq + if self.__connection.is_valid: + try: + self.__engine.dialect.do_savepoint(self, name) + return name + except Exception, e: + raise exceptions.SQLError(None, None, e) + + def _rollback_to_savepoint_impl(self, name, context): + if self.__connection.is_valid: + try: + self.__engine.dialect.do_rollback_to_savepoint(self, name) + except Exception, e: + raise exceptions.SQLError(None, None, e) + self.__transaction = context + + def _release_savepoint_impl(self, name, context): + if self.__connection.is_valid: + try: + self.__engine.dialect.do_release_savepoint(self, name) + except Exception, e: + raise exceptions.SQLError(None, None, e) + self.__transaction = context + + def _begin_twophase_impl(self, xid): + if self.__connection.is_valid: + self.__engine.dialect.do_begin_twophase(self, xid) + + def _prepare_twophase_impl(self, xid): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_prepare_twophase(self, xid) + + def _rollback_twophase_impl(self, xid, is_prepared): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_rollback_twophase(self, xid, is_prepared) + self.__transaction = None + + def _commit_twophase_impl(self, xid, is_prepared): + if self.__connection.is_valid: + assert isinstance(self.__transaction, TwoPhaseTransaction) + self.__engine.dialect.do_commit_twophase(self, xid, is_prepared) + self.__transaction = None + def _autocommit(self, statement): """When no Transaction is present, this is called after executions to provide "autocommit" behavior.""" # TODO: have the dialect determine if autocommit can be set on the connection directly without this @@ -722,30 +830,71 @@ class Transaction(object): """ def __init__(self, connection, parent): - self.__connection = connection - self.__parent = parent or self - self.__is_active = True - if self.__parent is self: - self.__connection._begin_impl() + self._connection = connection + self._parent = parent or self + self._is_active = True - connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction") - is_active = property(lambda s:s.__is_active) + connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction") + is_active = property(lambda s:s._is_active) def rollback(self): - if not self.__parent.__is_active: + if not self._parent._is_active: return - if self.__parent is self: - self.__connection._rollback_impl() - self.__is_active = False - else: - self.__parent.rollback() + self._is_active = False + self._do_rollback() + + def _do_rollback(self): + self._parent.rollback() def commit(self): - if not self.__parent.__is_active: + if not self._parent._is_active: raise exceptions.InvalidRequestError("This transaction is inactive") - if self.__parent is self: - self.__connection._commit_impl() - self.__is_active = False + self._is_active = False + self._do_commit() + + def _do_commit(self): + pass + +class RootTransaction(Transaction): + def __init__(self, connection): + super(RootTransaction, self).__init__(connection, None) + self._connection._begin_impl() + + def _do_rollback(self): + self._connection._rollback_impl() + + def _do_commit(self): + self._connection._commit_impl() + +class NestedTransaction(Transaction): + def __init__(self, connection, parent): + super(NestedTransaction, self).__init__(connection, parent) + self._savepoint = self._connection._savepoint_impl() + + def _do_rollback(self): + self._connection._rollback_to_savepoint_impl(self._savepoint, self._parent) + + def _do_commit(self): + self._connection._release_savepoint_impl(self._savepoint, self._parent) + +class TwoPhaseTransaction(Transaction): + def __init__(self, connection, xid): + super(TwoPhaseTransaction, self).__init__(connection, None) + self._is_prepared = False + self.xid = xid + self._connection._begin_twophase_impl(self.xid) + + def prepare(self): + if not self._parent._is_active: + raise exceptions.InvalidRequestError("This transaction is inactive") + self._connection._prepare_twophase_impl(self.xid) + self._is_prepared = True + + def _do_rollback(self): + self._connection._rollback_twophase_impl(self.xid, self._is_prepared) + + def commit(self): + self._connection._commit_twophase_impl(self.xid, self._is_prepared) class Engine(Connectable): """ diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 25cfad11ee..c5e1e76ee3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -95,6 +95,15 @@ class DefaultDialect(base.Dialect): #print "ENGINE COMMIT ON ", connection.connection connection.commit() + + def do_savepoint(self, connection, name): + connection.execute(sql.SavepointClause(name)) + + def do_rollback_to_savepoint(self, connection, name): + connection.execute(sql.RollbackToSavepointClause(name)) + + def do_release_savepoint(self, connection, name): + connection.execute(sql.ReleaseSavepointClause(name)) def do_executemany(self, cursor, statement, parameters, **kwargs): cursor.executemany(statement, parameters) diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 35313271ba..d87bc6f859 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -70,11 +70,8 @@ class TLConnection(base.Connection): self.__opencount += 1 return self - def _create_transaction(self, parent): - return TLTransaction(self, parent) - def _begin(self): - return base.Connection.begin(self) + return TLTransaction(self) def in_transaction(self): return self.session.in_transaction() @@ -91,7 +88,7 @@ class TLConnection(base.Connection): self.__opencount = 0 base.Connection.close(self) -class TLTransaction(base.Transaction): +class TLTransaction(base.RootTransaction): def _commit_impl(self): base.Transaction.commit(self) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 60b2a3d326..32c20bc10f 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -3114,3 +3114,18 @@ class Delete(_UpdateBase): return self._whereclause, else: return () + +class _IdentifiedClause(ClauseElement): + def __init__(self, ident): + self.ident = ident + def supports_execution(self): + return True + +class SavepointClause(_IdentifiedClause): + pass + +class RollbackToSavepointClause(_IdentifiedClause): + pass + +class ReleaseSavepointClause(_IdentifiedClause): + pass diff --git a/test/engine/transaction.py b/test/engine/transaction.py index 96fe7acf4d..246d5cea50 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -115,7 +115,153 @@ class TransactionTest(testbase.PersistTest): result = connection.execute("select * from query_users") assert len(result.fetchall()) == 0 connection.close() + + @testbase.unsupported('sqlite') + def testnestedsubtransactionrollback(self): + connection = testbase.db.connect() + transaction = connection.begin() + connection.execute(users.insert(), user_id=1, user_name='user1') + trans2 = connection.begin_nested() + connection.execute(users.insert(), user_id=2, user_name='user2') + trans2.rollback() + connection.execute(users.insert(), user_id=3, user_name='user3') + transaction.commit() + + self.assertEquals( + connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(3,)] + ) + connection.close() + + @testbase.unsupported('sqlite') + def testnestedsubtransactioncommit(self): + connection = testbase.db.connect() + transaction = connection.begin() + connection.execute(users.insert(), user_id=1, user_name='user1') + trans2 = connection.begin_nested() + connection.execute(users.insert(), user_id=2, user_name='user2') + trans2.commit() + connection.execute(users.insert(), user_id=3, user_name='user3') + transaction.commit() + + self.assertEquals( + connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(2,),(3,)] + ) + connection.close() + + @testbase.unsupported('sqlite') + def testrollbacktosubtransaction(self): + connection = testbase.db.connect() + transaction = connection.begin() + connection.execute(users.insert(), user_id=1, user_name='user1') + trans2 = connection.begin_nested() + connection.execute(users.insert(), user_id=2, user_name='user2') + trans3 = connection.begin() + connection.execute(users.insert(), user_id=3, user_name='user3') + trans3.rollback() + connection.execute(users.insert(), user_id=4, user_name='user4') + transaction.commit() + + self.assertEquals( + connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(4,)] + ) + connection.close() + + @testbase.supported('postgres', 'mysql') + def testtwophasetransaction(self): + connection = testbase.db.connect() + + transaction = connection.begin_twophase() + connection.execute(users.insert(), user_id=1, user_name='user1') + transaction.prepare() + transaction.commit() + + transaction = connection.begin_twophase() + connection.execute(users.insert(), user_id=2, user_name='user2') + transaction.commit() + transaction = connection.begin_twophase() + connection.execute(users.insert(), user_id=3, user_name='user3') + transaction.rollback() + + transaction = connection.begin_twophase() + connection.execute(users.insert(), user_id=4, user_name='user4') + transaction.prepare() + transaction.rollback() + + self.assertEquals( + connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(2,)] + ) + connection.close() + + @testbase.supported('postgres', 'mysql') + def testmixedtransaction(self): + connection = testbase.db.connect() + + transaction = connection.begin_twophase() + connection.execute(users.insert(), user_id=1, user_name='user1') + + transaction2 = connection.begin() + connection.execute(users.insert(), user_id=2, user_name='user2') + + transaction3 = connection.begin_nested() + connection.execute(users.insert(), user_id=3, user_name='user3') + + transaction4 = connection.begin() + connection.execute(users.insert(), user_id=4, user_name='user4') + transaction4.commit() + + transaction3.rollback() + + connection.execute(users.insert(), user_id=5, user_name='user5') + + transaction2.commit() + + transaction.prepare() + + transaction.commit() + + self.assertEquals( + connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(2,),(5,)] + ) + connection.close() + + @testbase.supported('postgres') + def testtwophaserecover(self): + # MySQL recovery doesn't currently seem to work correctly + # Prepared transactions disappear when connections are closed and even + # when they aren't it doesn't seem possible to use the recovery id. + connection = testbase.db.connect() + + transaction = connection.begin_twophase() + connection.execute(users.insert(), user_id=1, user_name='user1') + transaction.prepare() + + connection.close() + connection2 = testbase.db.connect() + + self.assertEquals( + connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [] + ) + + recoverables = connection2.recover_twophase() + self.assertTrue( + transaction.xid in recoverables + ) + + connection2.commit_prepared(transaction.xid, recover=True) + + self.assertEquals( + connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,)] + ) + connection2.close() + class AutoRollbackTest(testbase.PersistTest): def setUpAll(self): global metadata -- 2.47.3