From cf5dc52ada2a5bddcead8e91319344f68a08a4b7 Mon Sep 17 00:00:00 2001
From: Daniele Varrazzo
Date: Wed, 27 Oct 2021 01:50:45 +0200
Subject: [PATCH] Add two-phase commit DBAPI Connection methods

---
 psycopg/psycopg/__init__.py   |   2 +
 psycopg/psycopg/_tpc.py       | 123 +++++++++++
 psycopg/psycopg/connection.py | 137 +++++++++++-
 psycopg/setup.cfg             |   1 +
 tests/test_psycopg_dbapi20.py |   4 +-
 tests/test_tpc.py             | 391 ++++++++++++++++++++++++++++++++++
 6 files changed, 654 insertions(+), 4 deletions(-)
 create mode 100644 psycopg/psycopg/_tpc.py
 create mode 100644 tests/test_tpc.py

diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py
index 20d675e25..1a5c5687c 100644
--- a/psycopg/psycopg/__init__.py
+++ b/psycopg/psycopg/__init__.py
@@ -9,6 +9,7 @@ import logging
 from . import pq  # noqa: F401 import early to stabilize side effects
 from . import types
 from . import postgres
+from ._tpc import Xid
 from .copy import Copy, AsyncCopy
 from ._enums import IsolationLevel
 from .cursor import Cursor
@@ -71,6 +72,7 @@ __all__ = [
     "Rollback",
     "ServerCursor",
     "Transaction",
+    "Xid",
     # DBAPI exports
     "connect",
     "apilevel",
diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py
new file mode 100644
index 000000000..1c778c1f6
--- /dev/null
+++ b/psycopg/psycopg/_tpc.py
@@ -0,0 +1,123 @@
+"""
+psycopg two-phase commit support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+import datetime as dt
+from base64 import b64encode, b64decode
+from typing import Optional, Union
+from dataclasses import dataclass, replace
+
+_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
+
+
+@dataclass(frozen=True)
+class Xid:
+    """A two-phase commit transaction identifier."""
+
+    format_id: Optional[int]
+    gtrid: str
+    bqual: Optional[str]
+    prepared: Optional[dt.datetime] = None
+    owner: Optional[str] = None
+    database: Optional[str] = None
+
+    @classmethod
+    def from_string(cls, s: str) -> "Xid":
+        """Try to parse an XA triple from the string.
+
+        This may fail for several reasons. In such case return an unparsed Xid.
+        """
+        try:
+            return cls._parse_string(s)
+        except Exception:
+            return Xid(None, s, None)
+
+    def __str__(self) -> str:
+        return self._as_tid()
+
+    def __len__(self) -> int:
+        return 3
+
+    def __getitem__(self, index: int) -> Union[int, str, None]:
+        return (self.format_id, self.gtrid, self.bqual)[index]
+
+    @classmethod
+    def _parse_string(cls, s: str) -> "Xid":
+        m = _re_xid.match(s)
+        if not m:
+            raise ValueError("bad Xid format")
+
+        format_id = int(m.group(1))
+        gtrid = b64decode(m.group(2)).decode()
+        bqual = b64decode(m.group(3)).decode()
+        return cls.from_parts(format_id, gtrid, bqual)
+
+    @classmethod
+    def from_parts(
+        cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
+    ) -> "Xid":
+        """Return an Xid built from its parts, after validating them.
+
+        An XA triple has all the three parts specified; an unparsed
+        PostgreSQL transaction id has only the *gtrid* and both *format_id*
+        and *bqual* set to `None`.
+
+        Raise `TypeError` if only one of *format_id* and *bqual* is `None`;
+        raise `ValueError` if, in an XA triple, *format_id* is out of range
+        or *gtrid* or *bqual* are longer than 64 chars.
+        """
+        if format_id is not None:
+            if bqual is None:
+                raise TypeError("if format_id is specified, bqual must be too")
+            if not 0 <= format_id < 0x80000000:
+                raise ValueError(
+                    "format_id must be a non-negative 32-bit integer"
+                )
+            if len(bqual) > 64:
+                raise ValueError("bqual must be not longer than 64 chars")
+            if len(gtrid) > 64:
+                raise ValueError("gtrid must be not longer than 64 chars")
+
+        elif bqual is not None:
+            raise TypeError("if format_id is None, bqual must be None too")
+
+        return Xid(format_id, gtrid, bqual)
+
+    def _as_tid(self) -> str:
+        """
+        Return the PostgreSQL transaction_id for this XA xid.
+
+        PostgreSQL wants just a string, while the DBAPI supports the XA
+        standard and thus a triple. We use the same conversion algorithm
+        implemented by JDBC in order to allow some form of interoperation.
+
+        see also: the pgjdbc implementation
+          http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
+            postgresql/xa/RecoveredXid.java?rev=1.2
+        """
+        if self.format_id is None or self.bqual is None:
+            # Unparsed xid: return the gtrid.
+            return self.gtrid
+
+        # XA xid: mash together the components.
+        egtrid = b64encode(self.gtrid.encode()).decode()
+        ebqual = b64encode(self.bqual.encode()).decode()
+
+        return f"{self.format_id}_{egtrid}_{ebqual}"
+
+    @classmethod
+    def _get_recover_query(cls) -> str:
+        return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"
+
+    @classmethod
+    def _from_record(
+        cls, gid: str, prepared: dt.datetime, owner: str, database: str
+    ) -> "Xid":
+        xid = Xid.from_string(gid)
+        return replace(xid, prepared=prepared, owner=owner, database=database)
+
+
+Xid.__module__ = "psycopg"
diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
index ad6a5a900..8a7ce4450 100644
--- a/psycopg/psycopg/connection.py
+++ b/psycopg/psycopg/connection.py
@@ -9,7 +9,7 @@ import warnings
 import threading
 from types import TracebackType
 from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
-from typing import List, NamedTuple, Optional, Type, TypeVar, Union
+from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
 from typing import overload, TYPE_CHECKING
 from weakref import ref, ReferenceType
 from functools import partial
@@ -22,8 +22,9 @@ from . import postgres
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .abc import AdaptContext, ConnectionType, Params, Query, RV
 from .abc import PQGen, PQGenConn
-from .sql import Composable
-from .rows import Row, RowFactory, tuple_row, TupleRow
+from .sql import Composable, SQL
+from ._tpc import Xid
+from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
@@ -116,6 +117,7 @@ class BaseConnection(Generic[Row]):
 
         self._closed = False  # closed by an explicit close()
         self._prepared: PrepareManager = PrepareManager()
+        self._tpc: Optional[Tuple[Xid, bool]] = None  # xid, prepared
 
         wself = ref(self)
         pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
@@ -284,6 +286,11 @@ class BaseConnection(Generic[Row]):
 
     def cancel(self) -> None:
         """Cancel the current operation on the connection."""
+        if self._tpc and self._tpc[1]:
+            raise e.ProgrammingError(
+                "cancel() cannot be used with a prepared two-phase transaction"
+            )
+
         c = self.pgconn.get_cancel()
         c.cancel()
 
@@ -477,6 +484,10 @@ class BaseConnection(Generic[Row]):
                 "context. (Transaction will be automatically committed "
                 "on successful exit from context.)"
             )
+        if self._tpc:
+            raise e.ProgrammingError(
+                "commit() cannot be used during a two-phase transaction"
+            )
         if self.pgconn.transaction_status == TransactionStatus.IDLE:
             return
 
@@ -490,6 +501,10 @@ class BaseConnection(Generic[Row]):
                 "context. (Either raise Rollback() or allow "
                 "an exception to propagate out of the context.)"
             )
+        if self._tpc:
+            raise e.ProgrammingError(
+                "rollback() cannot be used during a two-phase transaction"
+            )
         if self.pgconn.transaction_status == TransactionStatus.IDLE:
             return
 
@@ -498,6 +513,81 @@ class BaseConnection(Generic[Row]):
         if cmd:
             yield from self._exec_command(cmd)
 
+    def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
+        """Return a validated `Xid` to pass to the ``tpc_*()`` methods."""
+        return Xid.from_parts(format_id, gtrid, bqual)
+
+    def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
+        """Generator implementing `Connection.tpc_begin()`."""
+        if not isinstance(xid, Xid):
+            xid = Xid.from_string(xid)
+
+        if self.pgconn.transaction_status != TransactionStatus.IDLE:
+            raise e.ProgrammingError(
+                f"can't start two-phase transaction: connection in status"
+                f" {TransactionStatus(self.pgconn.transaction_status).name}"
+            )
+
+        if self._autocommit:
+            raise e.ProgrammingError(
+                "can't use two-phase transactions in autocommit mode"
+            )
+
+        self._tpc = (xid, False)
+        yield from self._exec_command(self._get_tx_start_command())
+
+    def _tpc_prepare_gen(self) -> PQGen[None]:
+        """Generator implementing `Connection.tpc_prepare()`."""
+        if not self._tpc:
+            raise e.ProgrammingError(
+                "'tpc_prepare()' must be called inside a two-phase transaction"
+            )
+        if self._tpc[1]:
+            raise e.ProgrammingError(
+                "'tpc_prepare()' cannot be used during a prepared"
+                " two-phase transaction"
+            )
+        xid = self._tpc[0]
+        self._tpc = (xid, True)
+        yield from self._exec_command(
+            SQL("PREPARE TRANSACTION {}").format(str(xid))
+        )
+
+    def _tpc_finish_gen(
+        self, action: str, xid: Union[Xid, str, None]
+    ) -> PQGen[None]:
+        """Generator implementing `tpc_commit()` and `tpc_rollback()`."""
+        fname = f"tpc_{action}()"
+        if xid is None:
+            if not self._tpc:
+                raise e.ProgrammingError(
+                    f"{fname} without xid must be"
+                    " called inside a two-phase transaction"
+                )
+            xid = self._tpc[0]
+        else:
+            if self._tpc:
+                raise e.ProgrammingError(
+                    f"{fname} with xid must be called"
+                    " outside a two-phase transaction"
+                )
+            if not isinstance(xid, Xid):
+                xid = Xid.from_string(xid)
+
+        if self._tpc and not self._tpc[1]:
+            # Not prepared yet: finish with a regular commit/rollback.
+            # Reset the state first, because those generators refuse to
+            # run while a two-phase transaction is in progress.
+            meth: Callable[[], PQGen[None]]
+            meth = getattr(self, f"_{action}_gen")
+            self._tpc = None
+            yield from meth()
+        else:
+            yield from self._exec_command(
+                SQL("{} PREPARED {}").format(SQL(action.upper()), str(xid))
+            )
+            self._tpc = None
+
 
 class Connection(BaseConnection[Row]):
     """
@@ -786,3 +876,44 @@ class Connection(BaseConnection[Row]):
     def _set_deferrable(self, value: Optional[bool]) -> None:
         with self.lock:
             super()._set_deferrable(value)
+
+    def tpc_begin(self, xid: Union[Xid, str]) -> None:
+        """Begin a two-phase transaction identified by *xid*."""
+        with self.lock:
+            self.wait(self._tpc_begin_gen(xid))
+
+    def tpc_prepare(self) -> None:
+        """Prepare the current two-phase transaction for commit."""
+        try:
+            with self.lock:
+                self.wait(self._tpc_prepare_gen())
+        except e.ObjectNotInPrerequisiteState as ex:
+            # Raised e.g. if prepared transactions are disabled on the server
+            raise e.NotSupportedError(str(ex)) from None
+
+    def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+        """Commit the current, or a recovered, two-phase transaction."""
+        with self.lock:
+            self.wait(self._tpc_finish_gen("commit", xid))
+
+    def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+        """Roll back the current, or a recovered, two-phase transaction."""
+        with self.lock:
+            self.wait(self._tpc_finish_gen("rollback", xid))
+
+    def tpc_recover(self) -> List[Xid]:
+        """Return the prepared transactions found in `pg_prepared_xacts`."""
+        status = self.info.transaction_status
+        with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+            cur.execute(Xid._get_recover_query())
+            res = cur.fetchall()
+
+        # If the query above started a transaction, close it, to leave the
+        # connection in the same state it was found.
+        if (
+            status == TransactionStatus.IDLE
+            and self.info.transaction_status == TransactionStatus.INTRANS
+        ):
+            self.rollback()
+
+        return res
diff --git a/psycopg/setup.cfg b/psycopg/setup.cfg
index 48082ff8f..45f05f48d 100644
--- a/psycopg/setup.cfg
+++ b/psycopg/setup.cfg
@@ -40,6 +40,7 @@ packages = find:
 zip_safe = False
 install_requires =
     backports.zoneinfo; python_version < "3.9"
+    dataclasses; python_version < "3.7"
     typing_extensions; python_version < "3.8"
 
 [options.package_data]
diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py
index 04efa1ecb..0a40e4a20 100644
--- a/tests/test_psycopg_dbapi20.py
+++ b/tests/test_psycopg_dbapi20.py
@@ -8,6 +8,8 @@ from psycopg.conninfo import conninfo_to_dict
 from . import dbapi20
 from . import dbapi20_tpc
 
+from .test_tpc import tpc  # noqa F401  # fixture
+
 
 @pytest.fixture(scope="class")
 def with_dsn(request, dsn):
@@ -29,7 +31,7 @@ class PsycopgTests(dbapi20.DatabaseAPI20Test):
         pass
 
 
-# @skip_if_tpc_disabled
+@pytest.mark.usefixtures("tpc")
 @pytest.mark.usefixtures("with_dsn")
 class PsycopgTPCTests(dbapi20_tpc.TwoPhaseCommitTests):
     driver = psycopg
diff --git a/tests/test_tpc.py b/tests/test_tpc.py
new file mode 100644
index 000000000..127e70a8a
--- /dev/null
+++ b/tests/test_tpc.py
@@ -0,0 +1,391 @@
+from operator import attrgetter
+
+import pytest
+
+import psycopg
+from psycopg import sql
+
+
+def test_tpc_disabled(conn):
+    val = int(conn.execute("show max_prepared_transactions").fetchone()[0])
+    if val:
+        pytest.skip("prepared transactions enabled")
+
+    conn.rollback()
+    conn.tpc_begin("x")
+    with pytest.raises(psycopg.NotSupportedError):
+        conn.tpc_prepare()
+
+
+class TestTPC:
+    def test_tpc_commit(self, conn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_commit')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_prepare()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 1
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_commit()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 1
+
+    def test_tpc_commit_one_phase(self, conn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_commit_1p')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_commit()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 1
+
+    def test_tpc_commit_recovered(self, conn, dsn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_prepare()
+        conn.close()
+        assert tpc.count_xacts() == 1
+        assert tpc.count_test_records() == 0
+
+        conn = psycopg.connect(dsn)
+        xid = conn.xid(1, "gtrid", "bqual")
+        conn.tpc_commit(xid)
+
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 1
+
+    def test_tpc_rollback(self, conn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_rollback')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_prepare()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 1
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_rollback()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+    def test_tpc_rollback_one_phase(self, conn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_rollback()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+    def test_tpc_rollback_recovered(self, conn, dsn, tpc):
+        xid = conn.xid(1, "gtrid", "bqual")
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        conn.tpc_begin(xid)
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+        cur = conn.cursor()
+        cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+        conn.tpc_prepare()
+        conn.close()
+        assert tpc.count_xacts() == 1
+        assert tpc.count_test_records() == 0
+
+        conn = psycopg.connect(dsn)
+        xid = conn.xid(1, "gtrid", "bqual")
+        conn.tpc_rollback(xid)
+
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        assert tpc.count_xacts() == 0
+        assert tpc.count_test_records() == 0
+
+    def test_status_after_recover(self, conn, tpc):
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+        conn.tpc_recover()
+        assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+        cur = conn.cursor()
+        cur.execute("select 1")
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+        conn.tpc_recover()
+        assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+    def test_recovered_xids(self, conn, tpc):
+        # insert a few test xns
+        conn.autocommit = True
+        cur = conn.cursor()
+        cur.execute("begin; prepare transaction '1-foo'")
+        cur.execute("begin; prepare transaction '2-bar'")
+
+        # read the values to return
+        cur.execute(
+            """
+            select gid, prepared, owner, database from pg_prepared_xacts
+            where database = %s
+            """,
+            (conn.info.dbname,),
+        )
+        okvals = cur.fetchall()
+        okvals.sort()
+
+        xids = conn.tpc_recover()
+        xids = [xid for xid in xids if xid.database == conn.info.dbname]
+        xids.sort(key=attrgetter("gtrid"))
+
+        # check the values returned
+        assert len(okvals) == len(xids)
+        for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
+            assert xid.gtrid == gid
+            assert xid.prepared == prepared
+            assert xid.owner == owner
+            assert xid.database == database
+
+    def test_xid_encoding(self, conn, tpc):
+        xid = conn.xid(42, "gtrid", "bqual")
+        conn.tpc_begin(xid)
+        conn.tpc_prepare()
+
+        cur = conn.cursor()
+        cur.execute(
+            "select gid from pg_prepared_xacts where database = %s",
+            (conn.info.dbname,),
+        )
+        assert "42_Z3RyaWQ=_YnF1YWw=" == cur.fetchone()[0]
+
+    @pytest.mark.parametrize(
+        "fid, gtrid, bqual",
+        [
+            (0, "", ""),
+            (42, "gtrid", "bqual"),
+            (0x7FFFFFFF, "x" * 64, "y" * 64),
+        ],
+    )
+    def test_xid_roundtrip(self, conn, dsn, tpc, fid, gtrid, bqual):
+        xid = conn.xid(fid, gtrid, bqual)
+        conn.tpc_begin(xid)
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xids = [
+            x for x in conn.tpc_recover() if x.database == conn.info.dbname
+        ]
+        assert len(xids) == 1
+        xid = xids[0]
+        assert xid.format_id == fid
+        assert xid.gtrid == gtrid
+        assert xid.bqual == bqual
+
+        conn.tpc_rollback(xid)
+
+    @pytest.mark.parametrize(
+        "tid",
+        [
+            "",
+            "hello, world!",
+            "x" * 199,  # PostgreSQL's limit in transaction id length
+        ],
+    )
+    def test_unparsed_roundtrip(self, conn, dsn, tpc, tid):
+        conn.tpc_begin(tid)
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xids = [
+            x for x in conn.tpc_recover() if x.database == conn.info.dbname
+        ]
+        assert len(xids) == 1
+        xid = xids[0]
+        assert xid.format_id is None
+        assert xid.gtrid == tid
+        assert xid.bqual is None
+
+        conn.tpc_rollback(xid)
+
+    def test_xid_unicode(self, conn, dsn, tpc):
+        x1 = conn.xid(10, "uni", "code")
+        conn.tpc_begin(x1)
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xid = [
+            x for x in conn.tpc_recover() if x.database == conn.info.dbname
+        ][0]
+        assert 10 == xid.format_id
+        assert "uni" == xid.gtrid
+        assert "code" == xid.bqual
+
+    def test_xid_unicode_unparsed(self, conn, dsn, tpc):
+        # We don't expect people shooting snowmen as transaction ids,
+        # so if something explodes in an encode error I don't mind.
+        # Let's just check unicode is accepted as type.
+        conn.execute("set client_encoding to utf8")
+        conn.commit()
+
+        conn.tpc_begin("transaction-id")
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xid = [
+            x for x in conn.tpc_recover() if x.database == conn.info.dbname
+        ][0]
+        assert xid.format_id is None
+        assert xid.gtrid == "transaction-id"
+        assert xid.bqual is None
+
+    def test_cancel_fails_prepared(self, conn, tpc):
+        conn.tpc_begin("cancel")
+        conn.tpc_prepare()
+        with pytest.raises(psycopg.ProgrammingError):
+            conn.cancel()
+
+    def test_tpc_recover_non_dbapi_connection(self, conn, dsn, tpc):
+        conn.row_factory = psycopg.rows.dict_row
+        conn.tpc_begin("dict-connection")
+        conn.tpc_prepare()
+        conn.close()
+
+        conn = psycopg.connect(dsn)
+        xids = conn.tpc_recover()
+        xid = [x for x in xids if x.database == conn.info.dbname][0]
+        assert xid.format_id is None
+        assert xid.gtrid == "dict-connection"
+        assert xid.bqual is None
+
+
+class TestXidObject:
+    def test_xid_construction(self):
+        x1 = psycopg.Xid(74, "foo", "bar")
+        assert 74 == x1.format_id
+        assert "foo" == x1.gtrid
+        assert "bar" == x1.bqual
+
+    def test_xid_from_parts_unparsed(self):
+        x1 = psycopg.Xid.from_parts(None, "tid", None)
+        assert x1.format_id is None
+        assert x1.gtrid == "tid"
+        assert x1.bqual is None
+        with pytest.raises(TypeError):
+            psycopg.Xid.from_parts(None, "tid", "bqual")
+
+    def test_xid_from_string(self):
+        x2 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+        assert 42 == x2.format_id
+        assert "gtrid" == x2.gtrid
+        assert "bqual" == x2.bqual
+
+        x3 = psycopg.Xid.from_string("99_xxx_yyy")
+        assert None is x3.format_id
+        assert "99_xxx_yyy" == x3.gtrid
+        assert None is x3.bqual
+
+    def test_xid_to_string(self):
+        x1 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+        assert str(x1) == "42_Z3RyaWQ=_YnF1YWw="
+
+        x2 = psycopg.Xid.from_string("99_xxx_yyy")
+        assert str(x2) == "99_xxx_yyy"
+
+
+@pytest.fixture
+def tpc(svcconn):
+    tpc = Tpc(svcconn)
+    tpc.check_tpc()
+    tpc.clear_test_xacts()
+    tpc.make_test_table()
+    yield tpc
+    tpc.clear_test_xacts()
+
+
+class Tpc:
+    """Helper object to test two-phase transactions"""
+
+    def __init__(self, conn):
+        assert conn.autocommit
+        self.conn = conn
+
+    def check_tpc(self):
+        val = int(
+            self.conn.execute("show max_prepared_transactions").fetchone()[0]
+        )
+        if not val:
+            pytest.skip("prepared transactions disabled in the database")
+
+    def clear_test_xacts(self):
+        """Rollback all the prepared transaction in the testing db."""
+        cur = self.conn.execute(
+            "select gid from pg_prepared_xacts where database = %s",
+            (self.conn.info.dbname,),
+        )
+        gids = [r[0] for r in cur]
+        for gid in gids:
+            self.conn.execute(sql.SQL("rollback prepared {}").format(gid))
+
+    def make_test_table(self):
+        self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
+        self.conn.execute("TRUNCATE test_tpc")
+
+    def count_xacts(self):
+        """Return the number of prepared xacts currently in the test db."""
+        cur = self.conn.execute(
+            """
+            select count(*) from pg_prepared_xacts
+            where database = %s""",
+            (self.conn.info.dbname,),
+        )
+        return cur.fetchone()[0]
+
+    def count_test_records(self):
+        """Return the number of records in the test table."""
+        cur = self.conn.execute("select count(*) from test_tpc")
+        return cur.fetchone()[0]
-- 
2.47.2