From d4ec028728a747a2b594ccc730f4e50f643447f0 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sat, 30 Oct 2021 19:35:28 +0200 Subject: [PATCH] Add two-phase support method to async connection --- psycopg/psycopg/connection_async.py | 40 +++- tests/test_tpc_async.py | 322 ++++++++++++++++++++++++++++ 2 files changed, 359 insertions(+), 3 deletions(-) create mode 100644 tests/test_tpc_async.py diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 783a8ad0d..d892a974f 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -8,14 +8,15 @@ import sys import asyncio import logging from types import TracebackType -from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional from typing import Type, Union, cast, overload, TYPE_CHECKING from . import errors as e from . import waiting -from .pq import Format +from .pq import Format, TransactionStatus from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV -from .rows import Row, AsyncRowFactory, tuple_row, TupleRow +from ._tpc import Xid +from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel from ._compat import asynccontextmanager, get_running_loop @@ -334,3 +335,36 @@ class AsyncConnection(BaseConnection[Row]): f"'the {attribute!r} property is read-only on async connections:" f" please use 'await .set_{attribute}()' instead." ) + + async def tpc_begin(self, xid: Union[Xid, str]) -> None: + async with self.lock: + await self.wait(self._tpc_begin_gen(xid)) + + async def tpc_prepare(self) -> None: + try: + async with self.lock: + await self.wait(self._tpc_prepare_gen()) + except e.ObjectNotInPrerequisiteState as ex: + raise e.NotSupportedError(str(ex)) from None + + async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None: + async with self.lock: + await self.wait(self._tpc_finish_gen("commit", xid)) + + async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None: + async with self.lock: + await self.wait(self._tpc_finish_gen("rollback", xid)) + + async def tpc_recover(self) -> List[Xid]: + status = self.info.transaction_status + async with self.cursor(row_factory=args_row(Xid._from_record)) as cur: + await cur.execute(Xid._get_recover_query()) + res = await cur.fetchall() + + if ( + status == TransactionStatus.IDLE + and self.info.transaction_status == TransactionStatus.INTRANS + ): + await self.rollback() + + return res diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py new file mode 100644 index 000000000..d8fdaf433 --- /dev/null +++ b/tests/test_tpc_async.py @@ -0,0 +1,322 @@ +from operator import attrgetter + +import pytest + +import psycopg + +from .test_tpc import tpc # noqa: F401 # fixture + +pytestmark = [pytest.mark.asyncio] +tpc = tpc # Silence F811 in the rest of the file + + +async def test_tpc_disabled(aconn): + cur = await aconn.execute("show max_prepared_transactions") + val = int((await cur.fetchone())[0]) + if val: + pytest.skip("prepared transactions enabled") + + await aconn.rollback() + await aconn.tpc_begin("x") + with pytest.raises(psycopg.NotSupportedError): + await aconn.tpc_prepare() + + +class TestTPC: + async def test_tpc_commit(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + await aconn.tpc_commit() + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_commit_one_phase(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_commit() + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_commit_recovered(self, aconn, dsn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute( + "insert into test_tpc values ('test_tpc_commit_rec')" + ) + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + await aconn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + aconn = await psycopg.AsyncConnection.connect(dsn) + xid = aconn.xid(1, "gtrid", "bqual") + await aconn.tpc_commit(xid) + + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_rollback(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_rollback')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + await aconn.tpc_rollback() + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_tpc_rollback_one_phase(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute( + "insert into test_tpc values ('test_tpc_rollback_1p')" + ) + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_rollback() + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_tpc_rollback_recovered(self, aconn, dsn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute( + "insert into test_tpc values ('test_tpc_commit_rec')" + ) + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + await aconn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + aconn = await psycopg.AsyncConnection.connect(dsn) + xid = aconn.xid(1, "gtrid", "bqual") + await aconn.tpc_rollback(xid) + + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_status_after_recover(self, aconn, tpc): + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + await aconn.tpc_recover() + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + cur = aconn.cursor() + await cur.execute("select 1") + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + await aconn.tpc_recover() + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + async def test_recovered_xids(self, aconn, tpc): + # insert a few test xns + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("begin; prepare transaction '1-foo'") + await cur.execute("begin; prepare transaction '2-bar'") + + # read the values to return + await cur.execute( + """ + select gid, prepared, owner, database from pg_prepared_xacts + where database = %s + """, + (aconn.info.dbname,), + ) + okvals = await cur.fetchall() + okvals.sort() + + xids = await aconn.tpc_recover() + xids = [xid for xid in xids if xid.database == aconn.info.dbname] + xids.sort(key=attrgetter("gtrid")) + + # check the values returned + assert len(okvals) == len(xids) + for (xid, (gid, prepared, owner, database)) in zip(xids, okvals): + assert xid.gtrid == gid + assert xid.prepared == prepared + assert xid.owner == owner + assert xid.database == database + + async def test_xid_encoding(self, aconn, tpc): + xid = aconn.xid(42, "gtrid", "bqual") + await aconn.tpc_begin(xid) + await aconn.tpc_prepare() + + cur = aconn.cursor() + await cur.execute( + "select gid from pg_prepared_xacts where database = %s", + (aconn.info.dbname,), + ) + assert "42_Z3RyaWQ=_YnF1YWw=" == (await cur.fetchone())[0] + + @pytest.mark.parametrize( + "fid, gtrid, bqual", + [ + (0, "", ""), + (42, "gtrid", "bqual"), + (0x7FFFFFFF, "x" * 64, "y" * 64), + ], + ) + async def test_xid_roundtrip(self, aconn, dsn, tpc, fid, gtrid, bqual): + xid = aconn.xid(fid, gtrid, bqual) + await aconn.tpc_begin(xid) + await aconn.tpc_prepare() + await aconn.close() + + aconn = await psycopg.AsyncConnection.connect(dsn) + xids = [ + x + for x in await aconn.tpc_recover() + if x.database == aconn.info.dbname + ] + assert len(xids) == 1 + xid = xids[0] + assert xid.format_id == fid + assert xid.gtrid == gtrid + assert xid.bqual == bqual + + await aconn.tpc_rollback(xid) + + @pytest.mark.parametrize( + "tid", + [ + "", + "hello, world!", + "x" * 199, # PostgreSQL's limit in transaction id length + ], + ) + async def test_unparsed_roundtrip(self, aconn, dsn, tpc, tid): + await aconn.tpc_begin(tid) + await aconn.tpc_prepare() + await aconn.close() + + aconn = await psycopg.AsyncConnection.connect(dsn) + xids = [ + x + for x in await aconn.tpc_recover() + if x.database == aconn.info.dbname + ] + assert len(xids) == 1 + xid = xids[0] + assert xid.format_id is None + assert xid.gtrid == tid + assert xid.bqual is None + + await aconn.tpc_rollback(xid) + + async def test_xid_unicode(self, aconn, dsn, tpc): + x1 = aconn.xid(10, "uni", "code") + await aconn.tpc_begin(x1) + await aconn.tpc_prepare() + await aconn.close() + + aconn = await psycopg.AsyncConnection.connect(dsn) + xid = [ + x + for x in await aconn.tpc_recover() + if x.database == aconn.info.dbname + ][0] + assert 10 == xid.format_id + assert "uni" == xid.gtrid + assert "code" == xid.bqual + + async def test_xid_unicode_unparsed(self, aconn, dsn, tpc): + # We don't expect people shooting snowmen as transaction ids, + # so if something explodes in an encode error I don't mind. + # Let's just check unicode is accepted as type. + await aconn.execute("set client_encoding to utf8") + await aconn.commit() + + await aconn.tpc_begin("transaction-id") + await aconn.tpc_prepare() + await aconn.close() + + aconn = await psycopg.AsyncConnection.connect(dsn) + xid = [ + x + for x in await aconn.tpc_recover() + if x.database == aconn.info.dbname + ][0] + assert xid.format_id is None + assert xid.gtrid == "transaction-id" + assert xid.bqual is None + + async def test_cancel_fails_prepared(self, aconn, tpc): + await aconn.tpc_begin("cancel") + await aconn.tpc_prepare() + with pytest.raises(psycopg.ProgrammingError): + aconn.cancel() + + async def test_tpc_recover_non_dbapi_connection(self, aconn, dsn, tpc): + aconn.row_factory = psycopg.rows.dict_row + await aconn.tpc_begin("dict-connection") + await aconn.tpc_prepare() + await aconn.close() + + aconn = await psycopg.AsyncConnection.connect(dsn) + xids = await aconn.tpc_recover() + xid = [x for x in xids if x.database == aconn.info.dbname][0] + assert xid.format_id is None + assert xid.gtrid == "dict-connection" + assert xid.bqual is None -- 2.47.2