]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve pg two-phase transactions
authorFederico Caselli <cfederico87@gmail.com>
Mon, 13 Apr 2026 21:53:00 +0000 (23:53 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 14 Apr 2026 19:46:28 +0000 (21:46 +0200)
Improve handling of two phase transaction identifiers for PostgreSQL
when the identifier is provided by the user.
As part of this change the psycopg dialect was updated to use the DBAPI
two phase transaction API instead of executing the SQL directly.

Fixes: #13229
Change-Id: If8301a7253b4a0c88e5323c9a052c3a9fa258780

doc/build/changelog/unreleased_20/13229.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/provision.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/base.py
test/engine/test_execute.py

diff --git a/doc/build/changelog/unreleased_20/13229.rst b/doc/build/changelog/unreleased_20/13229.rst
new file mode 100644 (file)
index 0000000..e02886b
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: postgresql, bug
+    :tickets: 13229
+
+    Improve handling of two phase transaction identifiers for PostgreSQL
+    when the identifier is provided by the user.
+    As part of this change the psycopg dialect was updated to use the DBAPI
+    two phase transaction API instead of executing the SQL directly.
index bc4994a976e210a42ce145b81d2846f00aba3aec..97b3d30bbcba35146238dee5b6f0c13a531d5865 100644 (file)
@@ -191,3 +191,39 @@ class _PGDialect_common_psycopg(PGDialect):
                 dbapi_connection.autocommit = before_autocommit
 
         return True
+
+    def do_begin_twophase(self, connection, xid):
+        connection.connection.tpc_begin(xid)
+
+    def do_prepare_twophase(self, connection, xid):
+        connection.connection.tpc_prepare()
+
+    def _do_twophase(self, dbapi_conn, operation, xid, recover=False):
+        if recover:
+            if not self._twophase_idle_check(dbapi_conn):
+                dbapi_conn.rollback()
+            operation(xid)
+        else:
+            operation()
+
+    def _twophase_idle_check(self, dbapi_conn):
+        raise NotImplementedError
+
+    def do_rollback_twophase(
+        self, connection, xid, is_prepared=True, recover=False
+    ):
+        dbapi_conn = connection.connection.dbapi_connection
+        self._do_twophase(
+            dbapi_conn, dbapi_conn.tpc_rollback, xid, recover=recover
+        )
+
+    def do_commit_twophase(
+        self, connection, xid, is_prepared=True, recover=False
+    ):
+        dbapi_conn = connection.connection.dbapi_connection
+        self._do_twophase(
+            dbapi_conn, dbapi_conn.tpc_commit, xid, recover=recover
+        )
+
+    def do_recover_twophase(self, connection):
+        return [row[1] for row in connection.connection.tpc_recover()]
index 25d6b6fbea642652298aa5c046bd1b3c801e7d6d..1702bc70c9731f5d5ba827547d5411c4e2d9cb13 100644 (file)
@@ -3687,7 +3687,11 @@ class PGDialect(default.DefaultDialect):
         self.do_begin(connection.connection)
 
     def do_prepare_twophase(self, connection, xid):
-        connection.exec_driver_sql("PREPARE TRANSACTION '%s'" % xid)
+        connection.execute(
+            sql.text("PREPARE TRANSACTION :xid'").bindparams(
+                sql.bindparam("xid", xid, literal_execute=True)
+            )
+        )
 
     def do_rollback_twophase(
         self, connection, xid, is_prepared=True, recover=False
@@ -3699,7 +3703,11 @@ class PGDialect(default.DefaultDialect):
                 # Must find out a way how to make the dbapi not
                 # open a transaction.
                 connection.exec_driver_sql("ROLLBACK")
-            connection.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
+            connection.execute(
+                sql.text("ROLLBACK PREPARED :xid").bindparams(
+                    sql.bindparam("xid", xid, literal_execute=True)
+                )
+            )
             connection.exec_driver_sql("BEGIN")
             self.do_rollback(connection.connection)
         else:
@@ -3711,7 +3719,11 @@ class PGDialect(default.DefaultDialect):
         if is_prepared:
             if recover:
                 connection.exec_driver_sql("ROLLBACK")
-            connection.exec_driver_sql("COMMIT PREPARED '%s'" % xid)
+            connection.execute(
+                sql.text("COMMIT PREPARED :xid").bindparams(
+                    sql.bindparam("xid", xid, literal_execute=True)
+                )
+            )
             connection.exec_driver_sql("BEGIN")
             self.do_rollback(connection.connection)
         else:
index dfe0f627d2e7c0fe41bcdd7633d19c3b1403e6df..dfe67dce9cafdc66c8014550708402b8b14652a3 100644 (file)
@@ -95,9 +95,10 @@ def _postgresql_set_default_schema_on_connection(
 def drop_all_schema_objects_pre_tables(cfg, eng):
     with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
         for xid in conn.exec_driver_sql(
-            "select gid from pg_prepared_xacts"
+            "SELECT gid FROM pg_prepared_xacts "
+            "WHERE database = current_database()"
         ).scalars():
-            conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
+            eng.dialect.do_rollback_twophase(conn, xid, recover=True)
 
 
 @drop_all_schema_objects_post_tables.for_db("postgresql")
index b23ac6319fc7e70834213db7074e978ee71c65cf..5422849e82d3f26b60b8d3880c3c96c485cd2e69 100644 (file)
@@ -598,45 +598,12 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
                 return True
         return False
 
-    def _do_prepared_twophase(self, connection, command, recover=False):
-        dbapi_conn = connection.connection.dbapi_connection
-        if (
-            recover
-            # don't rely on psycopg providing enum symbols, compare with
-            # eq/ne
-            or dbapi_conn.info.transaction_status
-            != self._psycopg_TransactionStatus.IDLE
-        ):
-            dbapi_conn.rollback()
-        before_autocommit = dbapi_conn.autocommit
-        try:
-            if not before_autocommit:
-                self._do_autocommit(dbapi_conn, True)
-            with dbapi_conn.cursor() as cursor:
-                cursor.execute(command)
-        finally:
-            if not before_autocommit:
-                self._do_autocommit(dbapi_conn, before_autocommit)
-
-    def do_rollback_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
-        if is_prepared:
-            self._do_prepared_twophase(
-                connection, f"ROLLBACK PREPARED '{xid}'", recover=recover
-            )
-        else:
-            self.do_rollback(connection.connection)
-
-    def do_commit_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
-        if is_prepared:
-            self._do_prepared_twophase(
-                connection, f"COMMIT PREPARED '{xid}'", recover=recover
-            )
-        else:
-            self.do_commit(connection.connection)
+    def _twophase_idle_check(self, dbapi_conn):
+        # don't rely on psycopg providing enum symbols, compare with eq/ne
+        return (
+            dbapi_conn.info.transaction_status
+            == self._psycopg_TransactionStatus.IDLE
+        )
 
     @util.memoized_property
     def _dialect_specific_select_one(self):
@@ -759,6 +726,21 @@ class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
         else:
             return AsyncAdapt_psycopg_cursor(self)
 
+    def tpc_begin(self, xid):
+        return await_(self._connection.tpc_begin(xid))
+
+    def tpc_prepare(self):
+        return await_(self._connection.tpc_prepare())
+
+    def tpc_commit(self, xid=None):
+        return await_(self._connection.tpc_commit(xid))
+
+    def tpc_rollback(self, xid=None):
+        return await_(self._connection.tpc_rollback(xid))
+
+    def tpc_recover(self):
+        return await_(self._connection.tpc_recover())
+
 
 class PsycopgAdaptDBAPI(AsyncAdapt_dbapi_module):
     def __init__(self, psycopg, ExecStatus) -> None:
index 2f886c9df612ee4d88bd25316424aef57acd59f0..bde9e1c93e6c4cca8c5dd4a7eeb3f8b47da2eeea 100644 (file)
@@ -794,35 +794,8 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
         else:
             cursor.executemany(statement, parameters)
 
-    def do_begin_twophase(self, connection, xid):
-        connection.connection.tpc_begin(xid)
-
-    def do_prepare_twophase(self, connection, xid):
-        connection.connection.tpc_prepare()
-
-    def _do_twophase(self, dbapi_conn, operation, xid, recover=False):
-        if recover:
-            if dbapi_conn.status != self._psycopg2_extensions.STATUS_READY:
-                dbapi_conn.rollback()
-            operation(xid)
-        else:
-            operation()
-
-    def do_rollback_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
-        dbapi_conn = connection.connection.dbapi_connection
-        self._do_twophase(
-            dbapi_conn, dbapi_conn.tpc_rollback, xid, recover=recover
-        )
-
-    def do_commit_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
-        dbapi_conn = connection.connection.dbapi_connection
-        self._do_twophase(
-            dbapi_conn, dbapi_conn.tpc_commit, xid, recover=recover
-        )
+    def _twophase_idle_check(self, dbapi_conn):
+        return dbapi_conn.status == self._psycopg2_extensions.STATUS_READY
 
     @util.memoized_instancemethod
     def _hstore_oids(self, dbapi_connection):
index 9aaf02f4a2dfbfdf8390e14e65c449594037fc42..7f8af56a8c4c41b18dfae10e7e5732097c4b8cc5 100644 (file)
@@ -967,7 +967,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         :meth:`~.TwoPhaseTransaction.prepare` method.
 
         :param xid: the two phase transaction id.  If not supplied, a
-          random id will be generated.
+          random id will be generated. The accepted type and value depends on
+          the driver in use.
 
         .. seealso::
 
index 66ef5ce189ea331df1c314cdf8b694825960e635..aba37ae14ac60e42b9d6e2350175c68038d0e1c3 100644 (file)
@@ -1565,8 +1565,9 @@ class ExecutionOptionsTest(fixtures.TestBase):
                 "hoho",
             )
             eng.update_execution_options(foo="hoho")
-            conn = eng.connect()
-            eq_(conn._execution_options["foo"], "hoho")
+            conn2 = eng.connect()
+            eq_(conn2._execution_options["foo"], "hoho")
+            conn2.close()
 
     def test_generative_engine_execution_options(self):
         eng = engines.testing_engine(
@@ -2873,6 +2874,7 @@ class EngineEventsTest(fixtures.TestBase):
                 "commit_twophase",
             ],
         )
+        conn.close()
 
 
 class HandleErrorTest(fixtures.TestBase):