]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve pg two-phase transactions
authorFederico Caselli <cfederico87@gmail.com>
Mon, 13 Apr 2026 21:53:00 +0000 (23:53 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 14 Apr 2026 19:48:15 +0000 (21:48 +0200)
Improve handling of two phase transaction identifiers for PostgreSQL
when the identifier is provided by the user.
As part of this change the psycopg dialect was updated to use the DBAPI
two phase transaction API instead of executing the SQL directly.

Fixes: #13229
Change-Id: If8301a7253b4a0c88e5323c9a052c3a9fa258780
(cherry picked from commit 08cef20f4a2bfbeda61abfe6caee975190f0794c)

doc/build/changelog/unreleased_20/13229.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/provision.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/base.py
test/engine/test_execute.py

diff --git a/doc/build/changelog/unreleased_20/13229.rst b/doc/build/changelog/unreleased_20/13229.rst
new file mode 100644 (file)
index 0000000..e02886b
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: postgresql, bug
+    :tickets: 13229
+
+    Improve handling of two phase transaction identifiers for PostgreSQL
+    when the identifier is provided by the user.
+    As part of this change the psycopg dialect was updated to use the DBAPI
+    two phase transaction API instead of executing the SQL directly.
index 03b0f76ec3e76924888b3b8a7ebcc2726e61b795..dcd93b12ccec5b8d1afcbb1038b39d086be16591 100644 (file)
@@ -187,3 +187,39 @@ class _PGDialect_common_psycopg(PGDialect):
                 dbapi_connection.autocommit = before_autocommit
 
         return True
+
+    def do_begin_twophase(self, connection, xid):
+        connection.connection.tpc_begin(xid)
+
+    def do_prepare_twophase(self, connection, xid):
+        connection.connection.tpc_prepare()
+
+    def _do_twophase(self, dbapi_conn, operation, xid, recover=False):
+        if recover:
+            if not self._twophase_idle_check(dbapi_conn):
+                dbapi_conn.rollback()
+            operation(xid)
+        else:
+            operation()
+
+    def _twophase_idle_check(self, dbapi_conn):
+        raise NotImplementedError
+
+    def do_rollback_twophase(
+        self, connection, xid, is_prepared=True, recover=False
+    ):
+        dbapi_conn = connection.connection.dbapi_connection
+        self._do_twophase(
+            dbapi_conn, dbapi_conn.tpc_rollback, xid, recover=recover
+        )
+
+    def do_commit_twophase(
+        self, connection, xid, is_prepared=True, recover=False
+    ):
+        dbapi_conn = connection.connection.dbapi_connection
+        self._do_twophase(
+            dbapi_conn, dbapi_conn.tpc_commit, xid, recover=recover
+        )
+
+    def do_recover_twophase(self, connection):
+        return [row[1] for row in connection.connection.tpc_recover()]
index e2b6f257e9c6cb448b77352fa21705c898bf727f..8637f35d4fba142ee97cbfc71b5bf7b0480e3137 100644 (file)
@@ -3478,7 +3478,11 @@ class PGDialect(default.DefaultDialect):
         self.do_begin(connection.connection)
 
     def do_prepare_twophase(self, connection, xid):
-        connection.exec_driver_sql("PREPARE TRANSACTION '%s'" % xid)
+        connection.execute(
+            sql.text("PREPARE TRANSACTION :xid'").bindparams(
+                sql.bindparam("xid", xid, literal_execute=True)
+            )
+        )
 
     def do_rollback_twophase(
         self, connection, xid, is_prepared=True, recover=False
@@ -3490,7 +3494,11 @@ class PGDialect(default.DefaultDialect):
                 # Must find out a way how to make the dbapi not
                 # open a transaction.
                 connection.exec_driver_sql("ROLLBACK")
-            connection.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
+            connection.execute(
+                sql.text("ROLLBACK PREPARED :xid").bindparams(
+                    sql.bindparam("xid", xid, literal_execute=True)
+                )
+            )
             connection.exec_driver_sql("BEGIN")
             self.do_rollback(connection.connection)
         else:
@@ -3502,7 +3510,11 @@ class PGDialect(default.DefaultDialect):
         if is_prepared:
             if recover:
                 connection.exec_driver_sql("ROLLBACK")
-            connection.exec_driver_sql("COMMIT PREPARED '%s'" % xid)
+            connection.execute(
+                sql.text("COMMIT PREPARED :xid").bindparams(
+                    sql.bindparam("xid", xid, literal_execute=True)
+                )
+            )
             connection.exec_driver_sql("BEGIN")
             self.do_rollback(connection.connection)
         else:
index dfe0f627d2e7c0fe41bcdd7633d19c3b1403e6df..dfe67dce9cafdc66c8014550708402b8b14652a3 100644 (file)
@@ -95,9 +95,10 @@ def _postgresql_set_default_schema_on_connection(
 def drop_all_schema_objects_pre_tables(cfg, eng):
     with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
         for xid in conn.exec_driver_sql(
-            "select gid from pg_prepared_xacts"
+            "SELECT gid FROM pg_prepared_xacts "
+            "WHERE database = current_database()"
         ).scalars():
-            conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
+            eng.dialect.do_rollback_twophase(conn, xid, recover=True)
 
 
 @drop_all_schema_objects_post_tables.for_db("postgresql")
index 67dc5ca86cf4237fd6b454dfa4b0d49ff18cf5b5..4af214dabbb66c5a4e5a7b0e5011e3b8f08814f1 100644 (file)
@@ -596,44 +596,12 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
                 return True
         return False
 
-    def _do_prepared_twophase(self, connection, command, recover=False):
-        dbapi_conn = connection.connection.dbapi_connection
-        if (
-            recover
-            # don't rely on psycopg providing enum symbols, compare with
-            # eq/ne
-            or dbapi_conn.info.transaction_status
-            != self._psycopg_TransactionStatus.IDLE
-        ):
-            dbapi_conn.rollback()
-        before_autocommit = dbapi_conn.autocommit
-        try:
-            if not before_autocommit:
-                self._do_autocommit(dbapi_conn, True)
-            dbapi_conn.execute(command)
-        finally:
-            if not before_autocommit:
-                self._do_autocommit(dbapi_conn, before_autocommit)
-
-    def do_rollback_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
-        if is_prepared:
-            self._do_prepared_twophase(
-                connection, f"ROLLBACK PREPARED '{xid}'", recover=recover
-            )
-        else:
-            self.do_rollback(connection.connection)
-
-    def do_commit_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
-        if is_prepared:
-            self._do_prepared_twophase(
-                connection, f"COMMIT PREPARED '{xid}'", recover=recover
-            )
-        else:
-            self.do_commit(connection.connection)
+    def _twophase_idle_check(self, dbapi_conn):
+        # don't rely on psycopg providing enum symbols, compare with eq/ne
+        return (
+            dbapi_conn.info.transaction_status
+            == self._psycopg_TransactionStatus.IDLE
+        )
 
     @util.memoized_property
     def _dialect_specific_select_one(self):
@@ -784,6 +752,21 @@ class AsyncAdapt_psycopg_connection(AdaptedConnection):
     def set_deferrable(self, value):
         self.await_(self._connection.set_deferrable(value))
 
+    def tpc_begin(self, xid):
+        return self.await_(self._connection.tpc_begin(xid))
+
+    def tpc_prepare(self):
+        return self.await_(self._connection.tpc_prepare())
+
+    def tpc_commit(self, xid=None):
+        return self.await_(self._connection.tpc_commit(xid))
+
+    def tpc_rollback(self, xid=None):
+        return self.await_(self._connection.tpc_rollback(xid))
+
+    def tpc_recover(self):
+        return self.await_(self._connection.tpc_recover())
+
 
 class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection):
     __slots__ = ()
index 189e6566cfca85d481d5840f5a39b82ca3f42612..48ed9bf1e136f01ec52bc0b860b88dbb367cc0be 100644 (file)
@@ -799,35 +799,8 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
         else:
             cursor.executemany(statement, parameters)
 
-    def do_begin_twophase(self, connection, xid):
-        connection.connection.tpc_begin(xid)
-
-    def do_prepare_twophase(self, connection, xid):
-        connection.connection.tpc_prepare()
-
-    def _do_twophase(self, dbapi_conn, operation, xid, recover=False):
-        if recover:
-            if dbapi_conn.status != self._psycopg2_extensions.STATUS_READY:
-                dbapi_conn.rollback()
-            operation(xid)
-        else:
-            operation()
-
-    def do_rollback_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
-        dbapi_conn = connection.connection.dbapi_connection
-        self._do_twophase(
-            dbapi_conn, dbapi_conn.tpc_rollback, xid, recover=recover
-        )
-
-    def do_commit_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
-        dbapi_conn = connection.connection.dbapi_connection
-        self._do_twophase(
-            dbapi_conn, dbapi_conn.tpc_commit, xid, recover=recover
-        )
+    def _twophase_idle_check(self, dbapi_conn):
+        return dbapi_conn.status == self._psycopg2_extensions.STATUS_READY
 
     @util.memoized_instancemethod
     def _hstore_oids(self, dbapi_connection):
index 26f162b88c2ac44d5bfe1cb54f2f16c40d6fbc2e..e3b71fe5b53fa6334d752814080ab2fdc78706f4 100644 (file)
@@ -951,7 +951,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         :meth:`~.TwoPhaseTransaction.prepare` method.
 
         :param xid: the two phase transaction id.  If not supplied, a
-          random id will be generated.
+          random id will be generated. The accepted type and value depends on
+          the driver in use.
 
         .. seealso::
 
index 4c6bf8a28fb3e34941150bc08a4b563670cb70c0..e96cfd06dc2bd993a7ce025a4b603120a9616555 100644 (file)
@@ -1502,8 +1502,9 @@ class ExecutionOptionsTest(fixtures.TestBase):
                 "hoho",
             )
             eng.update_execution_options(foo="hoho")
-            conn = eng.connect()
-            eq_(conn._execution_options["foo"], "hoho")
+            conn2 = eng.connect()
+            eq_(conn2._execution_options["foo"], "hoho")
+            conn2.close()
 
     def test_generative_engine_execution_options(self):
         eng = engines.testing_engine(
@@ -2811,6 +2812,7 @@ class EngineEventsTest(fixtures.TestBase):
                 "commit_twophase",
             ],
         )
+        conn.close()
 
 
 class HandleErrorTest(fixtures.TestBase):