--- /dev/null
+"""
+psycopg two-phase commit support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+import datetime as dt
+from base64 import b64encode, b64decode
+from typing import Optional, Union
+from dataclasses import dataclass, replace
+
+_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
+
+
+@dataclass(frozen=True)
+class Xid:
+ """A two-phase commit transaction identifier."""
+
+ format_id: Optional[int]
+ gtrid: str
+ bqual: Optional[str]
+ prepared: Optional[dt.datetime] = None
+ owner: Optional[str] = None
+ database: Optional[str] = None
+
+ @classmethod
+ def from_string(cls, s: str) -> "Xid":
+ """Try to parse an XA triple from the string.
+
+ This may fail for several reasons. In such case return an unparsed Xid.
+ """
+ try:
+ return cls._parse_string(s)
+ except Exception:
+ return Xid(None, s, None)
+
+ def __str__(self) -> str:
+ return self._as_tid()
+
+ def __len__(self) -> int:
+ return 3
+
+ def __getitem__(self, index: int) -> Union[int, str, None]:
+ return (self.format_id, self.gtrid, self.bqual)[index]
+
+ @classmethod
+ def _parse_string(cls, s: str) -> "Xid":
+ m = _re_xid.match(s)
+ if not m:
+ raise ValueError("bad Xid format")
+
+ format_id = int(m.group(1))
+ gtrid = b64decode(m.group(2)).decode()
+ bqual = b64decode(m.group(3)).decode()
+ return cls.from_parts(format_id, gtrid, bqual)
+
+ @classmethod
+ def from_parts(
+ cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
+ ) -> "Xid":
+ if format_id is not None:
+ if bqual is None:
+ raise TypeError("if format_id is specified, bqual must be too")
+ if not 0 <= format_id < 0x80000000:
+ raise ValueError(
+ "format_id must be a non-negative 32-bit integer"
+ )
+ if len(bqual) > 64:
+ raise ValueError("bqual must be not longer than 64 chars")
+ if len(gtrid) > 64:
+ raise ValueError("gtrid must be not longer than 64 chars")
+
+ elif bqual is None:
+ raise TypeError("if format_id is None, bqual must be None too")
+
+ return Xid(format_id, gtrid, bqual)
+
+ def _as_tid(self) -> str:
+ """
+ Return the PostgreSQL transaction_id for this XA xid.
+
+ PostgreSQL wants just a string, while the DBAPI supports the XA
+ standard and thus a triple. We use the same conversion algorithm
+ implemented by JDBC in order to allow some form of interoperation.
+
+ see also: the pgjdbc implementation
+ http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
+ postgresql/xa/RecoveredXid.java?rev=1.2
+ """
+ if self.format_id is None or self.bqual is None:
+ # Unparsed xid: return the gtrid.
+ return self.gtrid
+
+ # XA xid: mash together the components.
+ egtrid = b64encode(self.gtrid.encode()).decode()
+ ebqual = b64encode(self.bqual.encode()).decode()
+
+ return f"{self.format_id}_{egtrid}_{ebqual}"
+
+ @classmethod
+ def _get_recover_query(cls) -> str:
+ return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"
+
+ @classmethod
+ def _from_record(
+ cls, gid: str, prepared: dt.datetime, owner: str, database: str
+ ) -> "Xid":
+ xid = Xid.from_string(gid)
+ return replace(xid, prepared=prepared, owner=owner, database=database)
+
+
+Xid.__module__ = "psycopg"
import threading
from types import TracebackType
from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
-from typing import List, NamedTuple, Optional, Type, TypeVar, Union
+from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
from typing import overload, TYPE_CHECKING
from weakref import ref, ReferenceType
from functools import partial
from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
from .abc import AdaptContext, ConnectionType, Params, Query, RV
from .abc import PQGen, PQGenConn
-from .sql import Composable
-from .rows import Row, RowFactory, tuple_row, TupleRow
+from .sql import Composable, SQL
+from ._tpc import Xid
+from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from .cursor import Cursor
self._closed = False # closed by an explicit close()
self._prepared: PrepareManager = PrepareManager()
+ self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared
wself = ref(self)
pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
def cancel(self) -> None:
"""Cancel the current operation on the connection."""
+ if self._tpc and self._tpc[1]:
+ raise e.ProgrammingError(
+ "cancel() cannot be used with a prepared two-phase transaction"
+ )
+
c = self.pgconn.get_cancel()
c.cancel()
"context. (Transaction will be automatically committed "
"on successful exit from context.)"
)
+ if self._tpc:
+ raise e.ProgrammingError(
+ "commit() cannot be used during a two-phase transaction"
+ )
if self.pgconn.transaction_status == TransactionStatus.IDLE:
return
"context. (Either raise Rollback() or allow "
"an exception to propagate out of the context.)"
)
+ if self._tpc:
+ raise e.ProgrammingError(
+ "rollback() cannot be used during a two-phase transaction"
+ )
if self.pgconn.transaction_status == TransactionStatus.IDLE:
return
if cmd:
yield from self._exec_command(cmd)
+ def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
+ return Xid.from_parts(format_id, gtrid, bqual)
+
+ def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
+ if not isinstance(xid, Xid):
+ xid = Xid.from_string(xid)
+
+ if self.pgconn.transaction_status != TransactionStatus.IDLE:
+ raise e.ProgrammingError(
+ f"can't start two-phase transaction: connection in status"
+ f" {TransactionStatus(self.pgconn.transaction_status).name}"
+ )
+
+ if self._autocommit:
+ raise e.ProgrammingError(
+ "can't use two-phase transctions in autocommit mode"
+ )
+
+ self._tpc = (xid, False)
+ yield from self._exec_command(self._get_tx_start_command())
+
+ def _tpc_prepare_gen(self) -> PQGen[None]:
+ if not self._tpc:
+ raise e.ProgrammingError(
+ "'tpc_prepare()' must be called inside a two-phase transaction"
+ )
+ if self._tpc[1]:
+ raise e.ProgrammingError(
+ "'tpc_prepare()' cannot be used during a prepared"
+ " two-phase transaction"
+ )
+ xid = self._tpc[0]
+ self._tpc = (xid, True)
+ yield from self._exec_command(
+ SQL("PREPARE TRANSACTION {}").format(str(xid))
+ )
+
+ def _tpc_finish_gen(
+ self, action: str, xid: Union[Xid, str, None]
+ ) -> PQGen[None]:
+ fname = f"tpc_{action}()"
+ if xid is None:
+ if not self._tpc:
+ raise e.ProgrammingError(
+ f"{fname} without xid must must be"
+ " called inside a two-phase transaction"
+ )
+ xid = self._tpc[0]
+ else:
+ if self._tpc:
+ raise e.ProgrammingError(
+ f"{fname} with xid must must be called"
+ " outside a two-phase transaction"
+ )
+ if not isinstance(xid, Xid):
+ xid = Xid.from_string(xid)
+
+ if self._tpc and not self._tpc[1]:
+ meth: Callable[[], PQGen[None]]
+ meth = getattr(self, f"_{action}_gen")
+ self._tpc = None
+ yield from meth()
+ else:
+ yield from self._exec_command(
+ SQL("{} PREPARED {}").format(SQL(action.upper()), str(xid))
+ )
+ self._tpc = None
+
class Connection(BaseConnection[Row]):
"""
def _set_deferrable(self, value: Optional[bool]) -> None:
with self.lock:
super()._set_deferrable(value)
+
+ def tpc_begin(self, xid: Union[Xid, str]) -> None:
+ with self.lock:
+ self.wait(self._tpc_begin_gen(xid))
+
+ def tpc_prepare(self) -> None:
+ try:
+ with self.lock:
+ self.wait(self._tpc_prepare_gen())
+ except e.ObjectNotInPrerequisiteState as ex:
+ raise e.NotSupportedError(str(ex)) from None
+
+ def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+ with self.lock:
+ self.wait(self._tpc_finish_gen("commit", xid))
+
+ def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+ with self.lock:
+ self.wait(self._tpc_finish_gen("rollback", xid))
+
+ def tpc_recover(self) -> List[Xid]:
+ status = self.info.transaction_status
+ with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+ cur.execute(Xid._get_recover_query())
+ res = cur.fetchall()
+
+ if (
+ status == TransactionStatus.IDLE
+ and self.info.transaction_status == TransactionStatus.INTRANS
+ ):
+ self.rollback()
+
+ return res
--- /dev/null
+from operator import attrgetter
+
+import pytest
+
+import psycopg
+from psycopg import sql
+
+
+def test_tpc_disabled(conn):
+ val = int(conn.execute("show max_prepared_transactions").fetchone()[0])
+ if val:
+ pytest.skip("prepared transactions enabled")
+
+ conn.rollback()
+ conn.tpc_begin("x")
+ with pytest.raises(psycopg.NotSupportedError):
+ conn.tpc_prepare()
+
+
+class TestTPC:
+ def test_tpc_commit(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_commit_one_phase(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_commit_recovered(self, conn, dsn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ conn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn = psycopg.connect(dsn)
+ xid = conn.xid(1, "gtrid", "bqual")
+ conn.tpc_commit(xid)
+
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_rollback(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_rollback')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_rollback()
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_tpc_rollback_one_phase(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_rollback()
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_tpc_rollback_recovered(self, conn, dsn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ conn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn = psycopg.connect(dsn)
+ xid = conn.xid(1, "gtrid", "bqual")
+ conn.tpc_rollback(xid)
+
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_status_after_recover(self, conn, tpc):
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ conn.tpc_recover()
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+ conn.tpc_recover()
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+ def test_recovered_xids(self, conn, tpc):
+ # insert a few test xns
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("begin; prepare transaction '1-foo'")
+ cur.execute("begin; prepare transaction '2-bar'")
+
+ # read the values to return
+ cur.execute(
+ """
+ select gid, prepared, owner, database from pg_prepared_xacts
+ where database = %s
+ """,
+ (conn.info.dbname,),
+ )
+ okvals = cur.fetchall()
+ okvals.sort()
+
+ xids = conn.tpc_recover()
+ xids = [xid for xid in xids if xid.database == conn.info.dbname]
+ xids.sort(key=attrgetter("gtrid"))
+
+ # check the values returned
+ assert len(okvals) == len(xids)
+ for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
+ assert xid.gtrid == gid
+ assert xid.prepared == prepared
+ assert xid.owner == owner
+ assert xid.database == database
+
+ def test_xid_encoding(self, conn, tpc):
+ xid = conn.xid(42, "gtrid", "bqual")
+ conn.tpc_begin(xid)
+ conn.tpc_prepare()
+
+ cur = conn.cursor()
+ cur.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (conn.info.dbname,),
+ )
+ assert "42_Z3RyaWQ=_YnF1YWw=" == cur.fetchone()[0]
+
+ @pytest.mark.parametrize(
+ "fid, gtrid, bqual",
+ [
+ (0, "", ""),
+ (42, "gtrid", "bqual"),
+ (0x7FFFFFFF, "x" * 64, "y" * 64),
+ ],
+ )
+ def test_xid_roundtrip(self, conn, dsn, tpc, fid, gtrid, bqual):
+ xid = conn.xid(fid, gtrid, bqual)
+ conn.tpc_begin(xid)
+ conn.tpc_prepare()
+ conn.close()
+
+ conn = psycopg.connect(dsn)
+ xids = [
+ x for x in conn.tpc_recover() if x.database == conn.info.dbname
+ ]
+ assert len(xids) == 1
+ xid = xids[0]
+ assert xid.format_id == fid
+ assert xid.gtrid == gtrid
+ assert xid.bqual == bqual
+
+ conn.tpc_rollback(xid)
+
+ @pytest.mark.parametrize(
+ "tid",
+ [
+ "",
+ "hello, world!",
+ "x" * 199, # PostgreSQL's limit in transaction id length
+ ],
+ )
+ def test_unparsed_roundtrip(self, conn, dsn, tpc, tid):
+ conn.tpc_begin(tid)
+ conn.tpc_prepare()
+ conn.close()
+
+ conn = psycopg.connect(dsn)
+ xids = [
+ x for x in conn.tpc_recover() if x.database == conn.info.dbname
+ ]
+ assert len(xids) == 1
+ xid = xids[0]
+ assert xid.format_id is None
+ assert xid.gtrid == tid
+ assert xid.bqual is None
+
+ conn.tpc_rollback(xid)
+
+ def test_xid_unicode(self, conn, dsn, tpc):
+ x1 = conn.xid(10, "uni", "code")
+ conn.tpc_begin(x1)
+ conn.tpc_prepare()
+ conn.close()
+
+ conn = psycopg.connect(dsn)
+ xid = [
+ x for x in conn.tpc_recover() if x.database == conn.info.dbname
+ ][0]
+ assert 10 == xid.format_id
+ assert "uni" == xid.gtrid
+ assert "code" == xid.bqual
+
+ def test_xid_unicode_unparsed(self, conn, dsn, tpc):
+ # We don't expect people shooting snowmen as transaction ids,
+ # so if something explodes in an encode error I don't mind.
+ # Let's just check unicode is accepted as type.
+ conn.execute("set client_encoding to utf8")
+ conn.commit()
+
+ conn.tpc_begin("transaction-id")
+ conn.tpc_prepare()
+ conn.close()
+
+ conn = psycopg.connect(dsn)
+ xid = [
+ x for x in conn.tpc_recover() if x.database == conn.info.dbname
+ ][0]
+ assert xid.format_id is None
+ assert xid.gtrid == "transaction-id"
+ assert xid.bqual is None
+
+ def test_cancel_fails_prepared(self, conn, tpc):
+ conn.tpc_begin("cancel")
+ conn.tpc_prepare()
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.cancel()
+
+ def test_tpc_recover_non_dbapi_connection(self, conn, dsn, tpc):
+ conn.row_factory = psycopg.rows.dict_row
+ conn.tpc_begin("dict-connection")
+ conn.tpc_prepare()
+ conn.close()
+
+ conn = psycopg.connect(dsn)
+ xids = conn.tpc_recover()
+ xid = [x for x in xids if x.database == conn.info.dbname][0]
+ assert xid.format_id is None
+ assert xid.gtrid == "dict-connection"
+ assert xid.bqual is None
+
+
+class TestXidObject:
+ def test_xid_construction(self):
+ x1 = psycopg.Xid(74, "foo", "bar")
+ 74 == x1.format_id
+ "foo" == x1.gtrid
+ "bar" == x1.bqual
+
+ def test_xid_from_string(self):
+ x2 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+ 42 == x2.format_id
+ "gtrid" == x2.gtrid
+ "bqual" == x2.bqual
+
+ x3 = psycopg.Xid.from_string("99_xxx_yyy")
+ None is x3.format_id
+ "99_xxx_yyy" == x3.gtrid
+ None is x3.bqual
+
+ def test_xid_to_string(self):
+ x1 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+ str(x1) == "42_Z3RyaWQ=_YnF1YWw="
+
+ x2 = psycopg.Xid.from_string("99_xxx_yyy")
+ str(x2) == "99_xxx_yyy"
+
+
+@pytest.fixture
+def tpc(svcconn):
+ tpc = Tpc(svcconn)
+ tpc.check_tpc()
+ tpc.clear_test_xacts()
+ tpc.make_test_table()
+ yield tpc
+ tpc.clear_test_xacts()
+
+
+class Tpc:
+ """Helper object to test two-phase transactions"""
+
+ def __init__(self, conn):
+ assert conn.autocommit
+ self.conn = conn
+
+ def check_tpc(self):
+ val = int(
+ self.conn.execute("show max_prepared_transactions").fetchone()[0]
+ )
+ if not val:
+ pytest.skip("prepared transactions disabled in the database")
+
+ def clear_test_xacts(self):
+ """Rollback all the prepared transaction in the testing db."""
+ cur = self.conn.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (self.conn.info.dbname,),
+ )
+ gids = [r[0] for r in cur]
+ for gid in gids:
+ self.conn.execute(sql.SQL("rollback prepared {}").format(gid))
+
+ def make_test_table(self):
+ self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
+ self.conn.execute("TRUNCATE test_tpc")
+
+ def count_xacts(self):
+ """Return the number of prepared xacts currently in the test db."""
+ cur = self.conn.execute(
+ """
+ select count(*) from pg_prepared_xacts
+ where database = %s""",
+ (self.conn.info.dbname,),
+ )
+ return cur.fetchone()[0]
+
+ def count_test_records(self):
+ """Return the number of records in the test table."""
+ cur = self.conn.execute("select count(*) from test_tpc")
+ return cur.fetchone()[0]