]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add two-phase commit DBAPI Connection methods
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 26 Oct 2021 23:50:45 +0000 (01:50 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 28 Nov 2021 17:04:31 +0000 (18:04 +0100)
psycopg/psycopg/__init__.py
psycopg/psycopg/_tpc.py [new file with mode: 0644]
psycopg/psycopg/connection.py
psycopg/setup.cfg
tests/test_psycopg_dbapi20.py
tests/test_tpc.py [new file with mode: 0644]

index 20d675e25832cd9793c5f67b4cb83f944dbf455c..1a5c5687c06706d1636b044c09e925631dab48da 100644 (file)
@@ -9,6 +9,7 @@ import logging
 from . import pq  # noqa: F401 import early to stabilize side effects
 from . import types
 from . import postgres
+from ._tpc import Xid
 from .copy import Copy, AsyncCopy
 from ._enums import IsolationLevel
 from .cursor import Cursor
@@ -71,6 +72,7 @@ __all__ = [
     "Rollback",
     "ServerCursor",
     "Transaction",
+    "Xid",
     # DBAPI exports
     "connect",
     "apilevel",
diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py
new file mode 100644 (file)
index 0000000..1c778c1
--- /dev/null
@@ -0,0 +1,113 @@
+"""
+psycopg two-phase commit support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+import datetime as dt
+from base64 import b64encode, b64decode
+from typing import Optional, Union
+from dataclasses import dataclass, replace
+
+_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
+
+
+@dataclass(frozen=True)
+class Xid:
+    """A two-phase commit transaction identifier."""
+
+    format_id: Optional[int]
+    gtrid: str
+    bqual: Optional[str]
+    prepared: Optional[dt.datetime] = None
+    owner: Optional[str] = None
+    database: Optional[str] = None
+
+    @classmethod
+    def from_string(cls, s: str) -> "Xid":
+        """Try to parse an XA triple from the string.
+
+        This may fail for several reasons. In such case return an unparsed Xid.
+        """
+        try:
+            return cls._parse_string(s)
+        except Exception:
+            return Xid(None, s, None)
+
+    def __str__(self) -> str:
+        return self._as_tid()
+
+    def __len__(self) -> int:
+        return 3
+
+    def __getitem__(self, index: int) -> Union[int, str, None]:
+        return (self.format_id, self.gtrid, self.bqual)[index]
+
+    @classmethod
+    def _parse_string(cls, s: str) -> "Xid":
+        m = _re_xid.match(s)
+        if not m:
+            raise ValueError("bad Xid format")
+
+        format_id = int(m.group(1))
+        gtrid = b64decode(m.group(2)).decode()
+        bqual = b64decode(m.group(3)).decode()
+        return cls.from_parts(format_id, gtrid, bqual)
+
+    @classmethod
+    def from_parts(
+        cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
+    ) -> "Xid":
+        if format_id is not None:
+            if bqual is None:
+                raise TypeError("if format_id is specified, bqual must be too")
+            if not 0 <= format_id < 0x80000000:
+                raise ValueError(
+                    "format_id must be a non-negative 32-bit integer"
+                )
+            if len(bqual) > 64:
+                raise ValueError("bqual must be not longer than 64 chars")
+            if len(gtrid) > 64:
+                raise ValueError("gtrid must be not longer than 64 chars")
+
+        elif bqual is None:
+            raise TypeError("if format_id is None, bqual must be None too")
+
+        return Xid(format_id, gtrid, bqual)
+
+    def _as_tid(self) -> str:
+        """
+        Return the PostgreSQL transaction_id for this XA xid.
+
+        PostgreSQL wants just a string, while the DBAPI supports the XA
+        standard and thus a triple. We use the same conversion algorithm
+        implemented by JDBC in order to allow some form of interoperation.
+
+        see also: the pgjdbc implementation
+          http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
+            postgresql/xa/RecoveredXid.java?rev=1.2
+        """
+        if self.format_id is None or self.bqual is None:
+            # Unparsed xid: return the gtrid.
+            return self.gtrid
+
+        # XA xid: mash together the components.
+        egtrid = b64encode(self.gtrid.encode()).decode()
+        ebqual = b64encode(self.bqual.encode()).decode()
+
+        return f"{self.format_id}_{egtrid}_{ebqual}"
+
+    @classmethod
+    def _get_recover_query(cls) -> str:
+        return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"
+
+    @classmethod
+    def _from_record(
+        cls, gid: str, prepared: dt.datetime, owner: str, database: str
+    ) -> "Xid":
+        xid = Xid.from_string(gid)
+        return replace(xid, prepared=prepared, owner=owner, database=database)
+
+
+Xid.__module__ = "psycopg"
index ad6a5a9008fd8b37e79ab1dce71ddbb3dbb32423..8a7ce44503231212b8ce64430312eaef5525fe4b 100644 (file)
@@ -9,7 +9,7 @@ import warnings
 import threading
 from types import TracebackType
 from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
-from typing import List, NamedTuple, Optional, Type, TypeVar, Union
+from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
 from typing import overload, TYPE_CHECKING
 from weakref import ref, ReferenceType
 from functools import partial
@@ -22,8 +22,9 @@ from . import postgres
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .abc import AdaptContext, ConnectionType, Params, Query, RV
 from .abc import PQGen, PQGenConn
-from .sql import Composable
-from .rows import Row, RowFactory, tuple_row, TupleRow
+from .sql import Composable, SQL
+from ._tpc import Xid
+from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
@@ -116,6 +117,7 @@ class BaseConnection(Generic[Row]):
 
         self._closed = False  # closed by an explicit close()
         self._prepared: PrepareManager = PrepareManager()
+        self._tpc: Optional[Tuple[Xid, bool]] = None  # xid, prepared
 
         wself = ref(self)
         pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
@@ -284,6 +286,11 @@ class BaseConnection(Generic[Row]):
 
     def cancel(self) -> None:
         """Cancel the current operation on the connection."""
+        if self._tpc and self._tpc[1]:
+            raise e.ProgrammingError(
+                "cancel() cannot be used with a prepared two-phase transaction"
+            )
+
         c = self.pgconn.get_cancel()
         c.cancel()
 
@@ -477,6 +484,10 @@ class BaseConnection(Generic[Row]):
                 "context. (Transaction will be automatically committed "
                 "on successful exit from context.)"
             )
+        if self._tpc:
+            raise e.ProgrammingError(
+                "commit() cannot be used during a two-phase transaction"
+            )
         if self.pgconn.transaction_status == TransactionStatus.IDLE:
             return
 
@@ -490,6 +501,10 @@ class BaseConnection(Generic[Row]):
                 "context. (Either raise Rollback() or allow "
                 "an exception to propagate out of the context.)"
             )
+        if self._tpc:
+            raise e.ProgrammingError(
+                "rollback() cannot be used during a two-phase transaction"
+            )
         if self.pgconn.transaction_status == TransactionStatus.IDLE:
             return
 
@@ -498,6 +513,74 @@ class BaseConnection(Generic[Row]):
         if cmd:
             yield from self._exec_command(cmd)
 
+    def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
+        return Xid.from_parts(format_id, gtrid, bqual)
+
+    def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
+        if not isinstance(xid, Xid):
+            xid = Xid.from_string(xid)
+
+        if self.pgconn.transaction_status != TransactionStatus.IDLE:
+            raise e.ProgrammingError(
+                f"can't start two-phase transaction: connection in status"
+                f" {TransactionStatus(self.pgconn.transaction_status).name}"
+            )
+
+        if self._autocommit:
+            raise e.ProgrammingError(
+                "can't use two-phase transctions in autocommit mode"
+            )
+
+        self._tpc = (xid, False)
+        yield from self._exec_command(self._get_tx_start_command())
+
+    def _tpc_prepare_gen(self) -> PQGen[None]:
+        if not self._tpc:
+            raise e.ProgrammingError(
+                "'tpc_prepare()' must be called inside a two-phase transaction"
+            )
+        if self._tpc[1]:
+            raise e.ProgrammingError(
+                "'tpc_prepare()' cannot be used during a prepared"
+                " two-phase transaction"
+            )
+        xid = self._tpc[0]
+        self._tpc = (xid, True)
+        yield from self._exec_command(
+            SQL("PREPARE TRANSACTION {}").format(str(xid))
+        )
+
+    def _tpc_finish_gen(
+        self, action: str, xid: Union[Xid, str, None]
+    ) -> PQGen[None]:
+        fname = f"tpc_{action}()"
+        if xid is None:
+            if not self._tpc:
+                raise e.ProgrammingError(
+                    f"{fname} without xid must must be"
+                    " called inside a two-phase transaction"
+                )
+            xid = self._tpc[0]
+        else:
+            if self._tpc:
+                raise e.ProgrammingError(
+                    f"{fname} with xid must must be called"
+                    " outside a two-phase transaction"
+                )
+            if not isinstance(xid, Xid):
+                xid = Xid.from_string(xid)
+
+        if self._tpc and not self._tpc[1]:
+            meth: Callable[[], PQGen[None]]
+            meth = getattr(self, f"_{action}_gen")
+            self._tpc = None
+            yield from meth()
+        else:
+            yield from self._exec_command(
+                SQL("{} PREPARED {}").format(SQL(action.upper()), str(xid))
+            )
+            self._tpc = None
+
 
 class Connection(BaseConnection[Row]):
     """
@@ -786,3 +869,36 @@ class Connection(BaseConnection[Row]):
     def _set_deferrable(self, value: Optional[bool]) -> None:
         with self.lock:
             super()._set_deferrable(value)
+
+    def tpc_begin(self, xid: Union[Xid, str]) -> None:
+        with self.lock:
+            self.wait(self._tpc_begin_gen(xid))
+
+    def tpc_prepare(self) -> None:
+        try:
+            with self.lock:
+                self.wait(self._tpc_prepare_gen())
+        except e.ObjectNotInPrerequisiteState as ex:
+            raise e.NotSupportedError(str(ex)) from None
+
+    def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+        with self.lock:
+            self.wait(self._tpc_finish_gen("commit", xid))
+
+    def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+        with self.lock:
+            self.wait(self._tpc_finish_gen("rollback", xid))
+
+    def tpc_recover(self) -> List[Xid]:
+        status = self.info.transaction_status
+        with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+            cur.execute(Xid._get_recover_query())
+            res = cur.fetchall()
+
+        if (
+            status == TransactionStatus.IDLE
+            and self.info.transaction_status == TransactionStatus.INTRANS
+        ):
+            self.rollback()
+
+        return res
index 48082ff8f3084198366af86301e508969d301763..45f05f48d61b33b9f73274f86be4c7cd1b99ddfe 100644 (file)
@@ -40,6 +40,7 @@ packages = find:
 zip_safe = False
 install_requires =
     backports.zoneinfo; python_version < "3.9"
+    dataclasses; python_version < "3.7"
     typing_extensions; python_version < "3.8"
 
 [options.package_data]
index 04efa1ecb4ff404b1f923dd3405756e8c9554158..0a40e4a2035534b6c305762876c21f6631af19df 100644 (file)
@@ -8,6 +8,8 @@ from psycopg.conninfo import conninfo_to_dict
 from . import dbapi20
 from . import dbapi20_tpc
 
+from .test_tpc import tpc  # noqa F401  # fixture
+
 
 @pytest.fixture(scope="class")
 def with_dsn(request, dsn):
@@ -29,7 +31,7 @@ class PsycopgTests(dbapi20.DatabaseAPI20Test):
         pass
 
 
-# @skip_if_tpc_disabled
+@pytest.mark.usefixtures("tpc")
 @pytest.mark.usefixtures("with_dsn")
 class PsycopgTPCTests(dbapi20_tpc.TwoPhaseCommitTests):
     driver = psycopg
diff --git a/tests/test_tpc.py b/tests/test_tpc.py
new file mode 100644 (file)
index 0000000..127e70a
--- /dev/null
@@ -0,0 +1,383 @@
+from operator import attrgetter
+
+import pytest
+
+import psycopg
+from psycopg import sql
+
+
+def test_tpc_disabled(conn):
+    val = int(conn.execute("show max_prepared_transactions").fetchone()[0])
+    if val:
+        pytest.skip("prepared transactions enabled")
+
+    conn.rollback()
+    conn.tpc_begin("x")
+    with pytest.raises(psycopg.NotSupportedError):
+        conn.tpc_prepare()
+
+
+class TestTPC:
+    def test_tpc_commit(self, conn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_commit')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_prepare()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 1
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_commit()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 1
+
+    def test_tpc_commit_one_phase(self, conn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_commit_1p')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_commit()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 1
+
+    def test_tpc_commit_recovered(self, conn, dsn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_prepare()
+        conn.close()
+        assert tpc.count_xacts() == 1
+        assert tpc.count_test_records() == 0
+
+        conn = psycopg.connect(dsn)
+        xid = conn.xid(1, "gtrid", "bqual")
+        conn.tpc_commit(xid)
+
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 1
+
+    def test_tpc_rollback(self, conn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_rollback')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_prepare()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 1
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_rollback()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+    def test_tpc_rollback_one_phase(self, conn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_rollback()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+    def test_tpc_rollback_recovered(self, conn, dsn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_prepare()
+        conn.close()
+        assert tpc.count_xacts() == 1
+        assert tpc.count_test_records() == 0
+
+        conn = psycopg.connect(dsn)
+        xid = conn.xid(1, "gtrid", "bqual")
+        conn.tpc_rollback(xid)
+
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+    def test_status_after_recover(self, conn, tpc):
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        conn.tpc_recover()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        cur = conn.cursor()
+        cur.execute("select 1")
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+        conn.tpc_recover()
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+    def test_recovered_xids(self, conn, tpc):
+        # insert a few test xns
+        conn.autocommit = True
+        cur = conn.cursor()
+        cur.execute("begin; prepare transaction '1-foo'")
+        cur.execute("begin; prepare transaction '2-bar'")
+
+        # read the values to return
+        cur.execute(
+            """
+            select gid, prepared, owner, database from pg_prepared_xacts
+            where database = %s
+            """,
+            (conn.info.dbname,),
+        )
+        okvals = cur.fetchall()
+        okvals.sort()
+
+        xids = conn.tpc_recover()
+        xids = [xid for xid in xids if xid.database == conn.info.dbname]
+        xids.sort(key=attrgetter("gtrid"))
+
+        # check the values returned
+        assert len(okvals) == len(xids)
+        for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
+            assert xid.gtrid == gid
+            assert xid.prepared == prepared
+            assert xid.owner == owner
+            assert xid.database == database
+
+    def test_xid_encoding(self, conn, tpc):
+        xid = conn.xid(42, "gtrid", "bqual")
+        conn.tpc_begin(xid)
+        conn.tpc_prepare()
+
+        cur = conn.cursor()
+        cur.execute(
+            "select gid from pg_prepared_xacts where database = %s",
+            (conn.info.dbname,),
+        )
+        assert "42_Z3RyaWQ=_YnF1YWw=" == cur.fetchone()[0]
+
+    @pytest.mark.parametrize(
+        "fid, gtrid, bqual",
+        [
+            (0, "", ""),
+            (42, "gtrid", "bqual"),
+            (0x7FFFFFFF, "x" * 64, "y" * 64),
+        ],
+    )
+    def test_xid_roundtrip(self, conn, dsn, tpc, fid, gtrid, bqual):
+        xid = conn.xid(fid, gtrid, bqual)
+        conn.tpc_begin(xid)
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xids = [
+            x for x in conn.tpc_recover() if x.database == conn.info.dbname
+        ]
+        assert len(xids) == 1
+        xid = xids[0]
+        assert xid.format_id == fid
+        assert xid.gtrid == gtrid
+        assert xid.bqual == bqual
+
+        conn.tpc_rollback(xid)
+
+    @pytest.mark.parametrize(
+        "tid",
+        [
+            "",
+            "hello, world!",
+            "x" * 199,  # PostgreSQL's limit in transaction id length
+        ],
+    )
+    def test_unparsed_roundtrip(self, conn, dsn, tpc, tid):
+        conn.tpc_begin(tid)
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xids = [
+            x for x in conn.tpc_recover() if x.database == conn.info.dbname
+        ]
+        assert len(xids) == 1
+        xid = xids[0]
+        assert xid.format_id is None
+        assert xid.gtrid == tid
+        assert xid.bqual is None
+
+        conn.tpc_rollback(xid)
+
+    def test_xid_unicode(self, conn, dsn, tpc):
+        x1 = conn.xid(10, "uni", "code")
+        conn.tpc_begin(x1)
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xid = [
+            x for x in conn.tpc_recover() if x.database == conn.info.dbname
+        ][0]
+        assert 10 == xid.format_id
+        assert "uni" == xid.gtrid
+        assert "code" == xid.bqual
+
+    def test_xid_unicode_unparsed(self, conn, dsn, tpc):
+        # We don't expect people shooting snowmen as transaction ids,
+        # so if something explodes in an encode error I don't mind.
+        # Let's just check unicode is accepted as type.
+        conn.execute("set client_encoding to utf8")
+        conn.commit()
+
+        conn.tpc_begin("transaction-id")
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xid = [
+            x for x in conn.tpc_recover() if x.database == conn.info.dbname
+        ][0]
+        assert xid.format_id is None
+        assert xid.gtrid == "transaction-id"
+        assert xid.bqual is None
+
+    def test_cancel_fails_prepared(self, conn, tpc):
+        conn.tpc_begin("cancel")
+        conn.tpc_prepare()
+        with pytest.raises(psycopg.ProgrammingError):
+            conn.cancel()
+
+    def test_tpc_recover_non_dbapi_connection(self, conn, dsn, tpc):
+        conn.row_factory = psycopg.rows.dict_row
+        conn.tpc_begin("dict-connection")
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xids = conn.tpc_recover()
+        xid = [x for x in xids if x.database == conn.info.dbname][0]
+        assert xid.format_id is None
+        assert xid.gtrid == "dict-connection"
+        assert xid.bqual is None
+
+
+class TestXidObject:
+    def test_xid_construction(self):
+        x1 = psycopg.Xid(74, "foo", "bar")
+        74 == x1.format_id
+        "foo" == x1.gtrid
+        "bar" == x1.bqual
+
+    def test_xid_from_string(self):
+        x2 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+        42 == x2.format_id
+        "gtrid" == x2.gtrid
+        "bqual" == x2.bqual
+
+        x3 = psycopg.Xid.from_string("99_xxx_yyy")
+        None is x3.format_id
+        "99_xxx_yyy" == x3.gtrid
+        None is x3.bqual
+
+    def test_xid_to_string(self):
+        x1 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+        str(x1) == "42_Z3RyaWQ=_YnF1YWw="
+
+        x2 = psycopg.Xid.from_string("99_xxx_yyy")
+        str(x2) == "99_xxx_yyy"
+
+
+@pytest.fixture
+def tpc(svcconn):
+    tpc = Tpc(svcconn)
+    tpc.check_tpc()
+    tpc.clear_test_xacts()
+    tpc.make_test_table()
+    yield tpc
+    tpc.clear_test_xacts()
+
+
+class Tpc:
+    """Helper object to test two-phase transactions"""
+
+    def __init__(self, conn):
+        assert conn.autocommit
+        self.conn = conn
+
+    def check_tpc(self):
+        val = int(
+            self.conn.execute("show max_prepared_transactions").fetchone()[0]
+        )
+        if not val:
+            pytest.skip("prepared transactions disabled in the database")
+
+    def clear_test_xacts(self):
+        """Rollback all the prepared transaction in the testing db."""
+        cur = self.conn.execute(
+            "select gid from pg_prepared_xacts where database = %s",
+            (self.conn.info.dbname,),
+        )
+        gids = [r[0] for r in cur]
+        for gid in gids:
+            self.conn.execute(sql.SQL("rollback prepared {}").format(gid))
+
+    def make_test_table(self):
+        self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
+        self.conn.execute("TRUNCATE test_tpc")
+
+    def count_xacts(self):
+        """Return the number of prepared xacts currently in the test db."""
+        cur = self.conn.execute(
+            """
+            select count(*) from pg_prepared_xacts
+            where database = %s""",
+            (self.conn.info.dbname,),
+        )
+        return cur.fetchone()[0]
+
+    def count_test_records(self):
+        """Return the number of records in the test table."""
+        cur = self.conn.execute("select count(*) from test_tpc")
+        return cur.fetchone()[0]