]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add support for two phase commits, nested subtransactions and savepoints. refactors...
authorAnts Aasma <ants.aasma@gmail.com>
Sat, 14 Jul 2007 00:36:05 +0000 (00:36 +0000)
committerAnts Aasma <ants.aasma@gmail.com>
Sat, 14 Jul 2007 00:36:05 +0000 (00:36 +0000)
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/sql.py
test/engine/transaction.py

diff --git a/CHANGES b/CHANGES
index b1d68dad6455de624df9ee6d21e02e7052744cd8..f975a3c20e59b12e38b98779b44f238fdad16347 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -73,6 +73,9 @@
       columns joined by a "group" to load as "undeferred".
 
 - sql
+  - added support for two phase commit, works with mysql and postgres so far.
+  - added a subtransaction implementation that uses savepoints.
+  - added support for savepoints.
   - DynamicMetaData has been renamed to ThreadLocalMetaData
   - BoundMetaData has been removed- regular MetaData is equivalent
   - significant architectural overhaul to SQL elements (ClauseElement).  
index 188063a82da1346fbe324137a14d4d8f57bee14e..b1e58dac22dbb1c650fa0837f8040d87bd93d947 100644 (file)
@@ -793,7 +793,19 @@ class ANSICompiler(engine.Compiled):
             text += " WHERE " + self.get_str(delete_stmt._whereclause)
 
         self.strings[delete_stmt] = text
+        
+    def visit_savepoint(self, savepoint_stmt):
+        text = "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        self.strings[savepoint_stmt] = text
 
+    def visit_rollback_to_savepoint(self, savepoint_stmt):
+        text = "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        self.strings[savepoint_stmt] = text
+    
+    def visit_release_savepoint(self, savepoint_stmt):
+        text = "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        self.strings[savepoint_stmt] = text
+    
     def __str__(self):
         return self.get_str(self.statement)
 
@@ -1080,6 +1092,9 @@ class ANSIIdentifierPreparer(object):
     def format_alias(self, alias):
         return self.__generic_obj_format(alias, alias.name)
 
+    def format_savepoint(self, savepoint):
+        return self.__generic_obj_format(savepoint, savepoint)
+
     def format_table(self, table, use_schema=True, name=None):
         """Prepare a quoted table and schema name."""
 
index 27d87847e2e7d8c17565575129c90692d5e6bf16..7e5956444591d19889556ad18fc8b89db16f9797 100644 (file)
@@ -949,7 +949,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
                 self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
             
     def is_select(self):
-        return re.match(r'SELECT|SHOW|DESCRIBE', self.statement.lstrip(), re.I) is not None
+        return re.match(r'SELECT|SHOW|DESCRIBE|XA RECOVER', self.statement.lstrip(), re.I) is not None
 
 class MySQLDialect(ansisql.ANSIDialect):
     def __init__(self, **kwargs):
@@ -1038,6 +1038,27 @@ class MySQLDialect(ansisql.ANSIDialect):
         except:
             pass
 
+    def do_begin_twophase(self, connection, xid):
+        connection.execute(sql.text("XA BEGIN :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+    def do_prepare_twophase(self, connection, xid):
+        connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+        connection.execute(sql.text("XA PREPARE :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+        if not is_prepared:
+            connection.execute(sql.text("XA END :xid", bindparams=[sql.bindparam('xid',xid)]))
+        connection.execute(sql.text("XA ROLLBACK :xid", bindparams=[sql.bindparam('xid',xid)]))
+
+    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+        if not is_prepared:
+            self.do_prepare_twophase(connection, xid)
+        connection.execute(sql.text("XA COMMIT :xid", bindparams=[sql.bindparam('xid',xid)]))
+    
+    def do_recover_twophase(self, connection):
+        resultset = connection.execute(sql.text("XA RECOVER"))
+        return [row['data'][0:row['gtrid_length']] for row in resultset]
+
     def is_disconnect(self, e):
         return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
 
index 20088045e98bef4c8303fe411e30f6c667afb655..3c37ac0071f1072a41102f28f059593fc94d8f52 100644 (file)
@@ -237,6 +237,34 @@ class PGDialect(ansisql.ANSIDialect):
     def schemadropper(self, *args, **kwargs):
         return PGSchemaDropper(self, *args, **kwargs)
 
+    def do_begin_twophase(self, connection, xid):
+        self.do_begin(connection.connection)
+
+    def do_prepare_twophase(self, connection, xid):
+        connection.execute(sql.text("PREPARE TRANSACTION %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+
+    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+        if is_prepared:
+            if recover:
+                #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions
+                # Must find out a way how to make the dbapi not open a transaction.
+                connection.execute(sql.text("ROLLBACK"))
+            connection.execute(sql.text("ROLLBACK PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+        else:
+            self.do_rollback(connection.connection)
+
+    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+        if is_prepared:
+            if recover:
+                connection.execute(sql.text("ROLLBACK"))
+            connection.execute(sql.text("COMMIT PREPARED %(tid)s", bindparams=[sql.bindparam('tid', xid)]))
+        else:
+            self.do_commit(connection.connection)
+
+    def do_recover_twophase(self, connection):
+        resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
+        return [row[0] for row in resultset]
+
     def defaultrunner(self, connection, **kwargs):
         return PGDefaultRunner(connection, **kwargs)
 
index 962c78bcfbe33482fa8ab8e5c0f14749d54d8808..40c5e9bb38ae9dce65adf1551561dac4f2f6f596 100644 (file)
@@ -9,7 +9,7 @@ higher-level statement-construction, connection-management,
 execution and result contexts."""
 
 from sqlalchemy import exceptions, sql, schema, util, types, logging
-import StringIO, sys, re
+import StringIO, sys, re, random
 
 
 class Dialect(sql.AbstractDialect):
@@ -211,6 +211,46 @@ class Dialect(sql.AbstractDialect):
 
         raise NotImplementedError()
 
+    def do_savepoint(self, connection, name):
+        """Create a savepoint with the given name on a SQLAlchemy connection."""
+
+        raise NotImplementedError()
+
+    def do_rollback_to_savepoint(self, connection, name):
+        """Rollback a SQL Alchemy connection to the named savepoint."""
+
+        raise NotImplementedError()
+
+    def do_release_savepoint(self, connection, name):
+        """Release the named savepoint on a SQL Alchemy connection."""
+
+        raise NotImplementedError()
+
+    def do_begin_twophase(self, connection, xid):
+        """Begin a two phase transaction on the given connection."""
+
+        raise NotImplementedError()
+
+    def do_prepare_twophase(self, connection, xid):
+        """Prepare a two phase transaction on the given connection."""
+
+        raise NotImplementedError()
+
+    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+        """Rollback a two phase transaction on the given connection."""
+
+        raise NotImplementedError()
+
+    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+        """Commit a two phase transaction on the given connection."""
+
+        raise NotImplementedError()
+
+    def do_recover_twophase(self, connection):
+        """Recover list of uncommited prepared two phase transaction identifiers on the given connection."""
+
+        raise NotImplementedError()
+
     def do_executemany(self, cursor, statement, parameters):
         """Provide an implementation of *cursor.executemany(statement, parameters)*."""
 
@@ -475,6 +515,7 @@ class Connection(Connectable):
         self.__connection = connection or engine.raw_connection()
         self.__transaction = None
         self.__close_with_result = close_with_result
+        self.__savepoint_seq = 0
 
     def _get_connection(self):
         try:
@@ -487,9 +528,6 @@ class Connection(Connectable):
     connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
     should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.")
 
-    def _create_transaction(self, parent):
-        return Transaction(self, parent)
-
     def connect(self):
         """connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly."""
         return self
@@ -522,12 +560,34 @@ class Connection(Connectable):
         
         self.__connection.detach()
         
-    def begin(self):
+    def begin(self, nested=False):
         if self.__transaction is None:
-            self.__transaction = self._create_transaction(None)
-            return self.__transaction
+            self.__transaction = RootTransaction(self)
+        elif nested:
+            self.__transaction = NestedTransaction(self, self.__transaction)
         else:
-            return self._create_transaction(self.__transaction)
+            return Transaction(self, self.__transaction)
+        return self.__transaction
+
+    def begin_nested(self):
+        return self.begin(nested=True)
+    
+    def begin_twophase(self, xid=None):
+        if self.__transaction is not None:
+            raise exceptions.InvalidRequestError("Cannot start a two phase transaction when a transaction is already started.")
+        if xid is None:
+            xid = "_sa_%032x" % random.randint(0,2**128)
+        self.__transaction = TwoPhaseTransaction(self, xid)
+        return self.__transaction
+        
+    def recover_twophase(self):
+        return self.__engine.dialect.do_recover_twophase(self)
+    
+    def rollback_prepared(self, xid, recover=False):
+        self.__engine.dialect.do_rollback_twophase(self, xid, recover=recover)
+    
+    def commit_prepared(self, xid, recover=False):
+        self.__engine.dialect.do_commit_twophase(self, xid, recover=recover)
 
     def in_transaction(self):
         return self.__transaction is not None
@@ -559,6 +619,54 @@ class Connection(Connectable):
                 raise exceptions.SQLError(None, None, e)
         self.__transaction = None
 
+    def _savepoint_impl(self, name=None):
+        if name is None:
+            self.__savepoint_seq += 1
+            name = '__sa_savepoint_%s' % self.__savepoint_seq
+        if self.__connection.is_valid:
+            try:
+                self.__engine.dialect.do_savepoint(self, name)
+                return name
+            except Exception, e:
+                raise exceptions.SQLError(None, None, e)
+    
+    def _rollback_to_savepoint_impl(self, name, context):
+        if self.__connection.is_valid:
+            try:
+                self.__engine.dialect.do_rollback_to_savepoint(self, name)
+            except Exception, e:
+                raise exceptions.SQLError(None, None, e)
+        self.__transaction = context
+    
+    def _release_savepoint_impl(self, name, context):
+        if self.__connection.is_valid:
+            try:
+                self.__engine.dialect.do_release_savepoint(self, name)
+            except Exception, e:
+                raise exceptions.SQLError(None, None, e)
+        self.__transaction = context
+    
+    def _begin_twophase_impl(self, xid):
+        if self.__connection.is_valid:
+            self.__engine.dialect.do_begin_twophase(self, xid)
+    
+    def _prepare_twophase_impl(self, xid):
+        if self.__connection.is_valid:
+            assert isinstance(self.__transaction, TwoPhaseTransaction)
+            self.__engine.dialect.do_prepare_twophase(self, xid)
+    
+    def _rollback_twophase_impl(self, xid, is_prepared):
+        if self.__connection.is_valid:
+            assert isinstance(self.__transaction, TwoPhaseTransaction)
+            self.__engine.dialect.do_rollback_twophase(self, xid, is_prepared)
+        self.__transaction = None
+
+    def _commit_twophase_impl(self, xid, is_prepared):
+        if self.__connection.is_valid:
+            assert isinstance(self.__transaction, TwoPhaseTransaction)
+            self.__engine.dialect.do_commit_twophase(self, xid, is_prepared)
+        self.__transaction = None
+
     def _autocommit(self, statement):
         """When no Transaction is present, this is called after executions to provide "autocommit" behavior."""
         # TODO: have the dialect determine if autocommit can be set on the connection directly without this
@@ -722,30 +830,71 @@ class Transaction(object):
     """
 
     def __init__(self, connection, parent):
-        self.__connection = connection
-        self.__parent = parent or self
-        self.__is_active = True
-        if self.__parent is self:
-            self.__connection._begin_impl()
+        self._connection = connection
+        self._parent = parent or self
+        self._is_active = True
 
-    connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction")
-    is_active = property(lambda s:s.__is_active)
+    connection = property(lambda s:s._connection, doc="The Connection object referenced by this Transaction")
+    is_active = property(lambda s:s._is_active)
 
     def rollback(self):
-        if not self.__parent.__is_active:
+        if not self._parent._is_active:
             return
-        if self.__parent is self:
-            self.__connection._rollback_impl()
-            self.__is_active = False
-        else:
-            self.__parent.rollback()
+        self._is_active = False
+        self._do_rollback()
+    
+    def _do_rollback(self):
+        self._parent.rollback()
 
     def commit(self):
-        if not self.__parent.__is_active:
+        if not self._parent._is_active:
             raise exceptions.InvalidRequestError("This transaction is inactive")
-        if self.__parent is self:
-            self.__connection._commit_impl()
-            self.__is_active = False
+        self._is_active = False
+        self._do_commit()
+    
+    def _do_commit(self):
+        pass
+
+class RootTransaction(Transaction):
+    def __init__(self, connection):
+        super(RootTransaction, self).__init__(connection, None)
+        self._connection._begin_impl()
+    
+    def _do_rollback(self):
+        self._connection._rollback_impl()
+
+    def _do_commit(self):
+        self._connection._commit_impl()
+
+class NestedTransaction(Transaction):
+    def __init__(self, connection, parent):
+        super(NestedTransaction, self).__init__(connection, parent)
+        self._savepoint = self._connection._savepoint_impl()
+    
+    def _do_rollback(self):
+        self._connection._rollback_to_savepoint_impl(self._savepoint, self._parent)
+
+    def _do_commit(self):
+        self._connection._release_savepoint_impl(self._savepoint, self._parent)
+
+class TwoPhaseTransaction(Transaction):
+    def __init__(self, connection, xid):
+        super(TwoPhaseTransaction, self).__init__(connection, None)
+        self._is_prepared = False
+        self.xid = xid
+        self._connection._begin_twophase_impl(self.xid)
+    
+    def prepare(self):
+        if not self._parent._is_active:
+            raise exceptions.InvalidRequestError("This transaction is inactive")
+        self._connection._prepare_twophase_impl(self.xid)
+        self._is_prepared = True
+    
+    def _do_rollback(self):
+        self._connection._rollback_twophase_impl(self.xid, self._is_prepared)
+    
+    def commit(self):
+        self._connection._commit_twophase_impl(self.xid, self._is_prepared)
 
 class Engine(Connectable):
     """
index 25cfad11ee19364ae0e054a35413db02feeaca5f..c5e1e76ee3e92fd288a4e9032b401ffd321924a7 100644 (file)
@@ -95,6 +95,15 @@ class DefaultDialect(base.Dialect):
 
         #print "ENGINE COMMIT ON ", connection.connection
         connection.commit()
+        
+    def do_savepoint(self, connection, name):
+        connection.execute(sql.SavepointClause(name))
+
+    def do_rollback_to_savepoint(self, connection, name):
+        connection.execute(sql.RollbackToSavepointClause(name))
+
+    def do_release_savepoint(self, connection, name):
+        connection.execute(sql.ReleaseSavepointClause(name))
 
     def do_executemany(self, cursor, statement, parameters, **kwargs):
         cursor.executemany(statement, parameters)
index 35313271bad39a62bd7a46864a56c07b3fb585f3..d87bc6f859ad159f181a49b7546bc1faa073ab1b 100644 (file)
@@ -70,11 +70,8 @@ class TLConnection(base.Connection):
         self.__opencount += 1
         return self
 
-    def _create_transaction(self, parent):
-        return TLTransaction(self, parent)
-
     def _begin(self):
-        return base.Connection.begin(self)
+        return TLTransaction(self)
 
     def in_transaction(self):
         return self.session.in_transaction()
@@ -91,7 +88,7 @@ class TLConnection(base.Connection):
         self.__opencount = 0
         base.Connection.close(self)
 
-class TLTransaction(base.Transaction):
+class TLTransaction(base.RootTransaction):
     def _commit_impl(self):
         base.Transaction.commit(self)
 
index 60b2a3d3265237c12c98c0a35cc6ef74360da2ba..32c20bc10ff47ab0fe2874a876b3bf614585e9d8 100644 (file)
@@ -3114,3 +3114,18 @@ class Delete(_UpdateBase):
             return self._whereclause,
         else:
             return ()
+
+class _IdentifiedClause(ClauseElement):
+    def __init__(self, ident):
+        self.ident = ident
+    def supports_execution(self):
+        return True
+
+class SavepointClause(_IdentifiedClause):
+    pass
+
+class RollbackToSavepointClause(_IdentifiedClause):
+    pass
+
+class ReleaseSavepointClause(_IdentifiedClause):
+    pass
index 96fe7acf4da1b4ff3752472f3325c01293c48198..246d5cea50157f6ae953d68fd89068132e1702e7 100644 (file)
@@ -115,7 +115,153 @@ class TransactionTest(testbase.PersistTest):
         result = connection.execute("select * from query_users")
         assert len(result.fetchall()) == 0
         connection.close()
+    
+    @testbase.unsupported('sqlite')
+    def testnestedsubtransactionrollback(self):
+        connection = testbase.db.connect()
+        transaction = connection.begin()
+        connection.execute(users.insert(), user_id=1, user_name='user1')
+        trans2 = connection.begin_nested()
+        connection.execute(users.insert(), user_id=2, user_name='user2')
+        trans2.rollback()
+        connection.execute(users.insert(), user_id=3, user_name='user3')
+        transaction.commit()
+        
+        self.assertEquals(
+            connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,),(3,)]
+        )
+        connection.close()
+
+    @testbase.unsupported('sqlite')
+    def testnestedsubtransactioncommit(self):
+        connection = testbase.db.connect()
+        transaction = connection.begin()
+        connection.execute(users.insert(), user_id=1, user_name='user1')
+        trans2 = connection.begin_nested()
+        connection.execute(users.insert(), user_id=2, user_name='user2')
+        trans2.commit()
+        connection.execute(users.insert(), user_id=3, user_name='user3')
+        transaction.commit()
+        
+        self.assertEquals(
+            connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,),(2,),(3,)]
+        )
+        connection.close()
+
+    @testbase.unsupported('sqlite')
+    def testrollbacktosubtransaction(self):
+        connection = testbase.db.connect()
+        transaction = connection.begin()
+        connection.execute(users.insert(), user_id=1, user_name='user1')
+        trans2 = connection.begin_nested()
+        connection.execute(users.insert(), user_id=2, user_name='user2')
+        trans3 = connection.begin()
+        connection.execute(users.insert(), user_id=3, user_name='user3')
+        trans3.rollback()
+        connection.execute(users.insert(), user_id=4, user_name='user4')
+        transaction.commit()
+        
+        self.assertEquals(
+            connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,),(4,)]
+        )
+        connection.close()
+    
+    @testbase.supported('postgres', 'mysql')
+    def testtwophasetransaction(self):
+        connection = testbase.db.connect()
+        
+        transaction = connection.begin_twophase()
+        connection.execute(users.insert(), user_id=1, user_name='user1')
+        transaction.prepare()
+        transaction.commit()
+        
+        transaction = connection.begin_twophase()
+        connection.execute(users.insert(), user_id=2, user_name='user2')
+        transaction.commit()
         
+        transaction = connection.begin_twophase()
+        connection.execute(users.insert(), user_id=3, user_name='user3')
+        transaction.rollback()
+        
+        transaction = connection.begin_twophase()
+        connection.execute(users.insert(), user_id=4, user_name='user4')
+        transaction.prepare()
+        transaction.rollback()
+        
+        self.assertEquals(
+            connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,),(2,)]
+        )
+        connection.close()
+
+    @testbase.supported('postgres', 'mysql')
+    def testmixedtransaction(self):
+        connection = testbase.db.connect()
+        
+        transaction = connection.begin_twophase()
+        connection.execute(users.insert(), user_id=1, user_name='user1')
+        
+        transaction2 = connection.begin()
+        connection.execute(users.insert(), user_id=2, user_name='user2')
+        
+        transaction3 = connection.begin_nested()
+        connection.execute(users.insert(), user_id=3, user_name='user3')
+        
+        transaction4 = connection.begin()
+        connection.execute(users.insert(), user_id=4, user_name='user4')
+        transaction4.commit()
+        
+        transaction3.rollback()
+        
+        connection.execute(users.insert(), user_id=5, user_name='user5')
+        
+        transaction2.commit()
+        
+        transaction.prepare()
+        
+        transaction.commit()
+        
+        self.assertEquals(
+            connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,),(2,),(5,)]
+        )
+        connection.close()
+        
+    @testbase.supported('postgres')
+    def testtwophaserecover(self):
+        # MySQL recovery doesn't currently seem to work correctly
+        # Prepared transactions disappear when connections are closed and even
+        # when they aren't it doesn't seem possible to use the recovery id.
+        connection = testbase.db.connect()
+        
+        transaction = connection.begin_twophase()
+        connection.execute(users.insert(), user_id=1, user_name='user1')
+        transaction.prepare()
+        
+        connection.close()
+        connection2 = testbase.db.connect()
+        
+        self.assertEquals(
+            connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            []
+        )
+        
+        recoverables = connection2.recover_twophase()
+        self.assertTrue(
+            transaction.xid in recoverables
+        )
+        
+        connection2.commit_prepared(transaction.xid, recover=True)
+
+        self.assertEquals(
+            connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,)]
+        )
+        connection2.close()
+
 class AutoRollbackTest(testbase.PersistTest):
     def setUpAll(self):
         global metadata