From: Daniele Varrazzo Date: Tue, 7 Jun 2022 06:01:46 +0000 (+0200) Subject: fix(crdb): raise NotSupportedError on two-phase commit methods X-Git-Tag: 3.1~49^2~20 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=474e4fa8c4cdcd6fb8e8996355b232475f59de6f;p=thirdparty%2Fpsycopg.git fix(crdb): raise NotSupportedError on two-phase commit methods --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index abd7149f4..ee05e8e50 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -585,9 +585,12 @@ class BaseConnection(Generic[Row]): The values passed to the method will be available on the returned object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`. """ + self._check_tpc() return Xid.from_parts(format_id, gtrid, bqual) def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]: + self._check_tpc() + if not isinstance(xid, Xid): xid = Xid.from_string(xid) @@ -651,6 +654,11 @@ class BaseConnection(Generic[Row]): ) self._tpc = None + def _check_tpc(self) -> None: + """Raise NotSupportedError if TPC is not supported.""" + # TPC supported on every supported PostgreSQL version. + pass + class Connection(BaseConnection[Row]): """ @@ -1024,6 +1032,7 @@ class Connection(BaseConnection[Row]): self.wait(self._tpc_finish_gen("ROLLBACK", xid)) def tpc_recover(self) -> List[Xid]: + self._check_tpc() status = self.info.transaction_status with self.cursor(row_factory=args_row(Xid._from_record)) as cur: cur.execute(Xid._get_recover_query()) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 1545667b7..89f3b5802 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -424,6 +424,7 @@ class AsyncConnection(BaseConnection[Row]): await self.wait(self._tpc_finish_gen("rollback", xid)) async def tpc_recover(self) -> List[Xid]: + self._check_tpc() status = self.info.transaction_status async with self.cursor(row_factory=args_row(Xid._from_record)) as cur: await cur.execute(Xid._get_recover_query()) diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py index 998ba4720..fa5936e3a 100644 --- a/psycopg/psycopg/crdb/connection.py +++ b/psycopg/psycopg/crdb/connection.py @@ -51,6 +51,10 @@ class _CrdbConnectionMixin: def info(self) -> "CrdbConnectionInfo": return CrdbConnectionInfo(self.pgconn) + def _check_tpc(self) -> None: + if self.is_crdb(self.pgconn): + raise e.NotSupportedError("CockroachDB doesn't support prepared statements") + class CrdbConnection(_CrdbConnectionMixin, Connection[Row]): """ diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py index fb57c6e6a..724728f69 100644 --- a/tests/crdb/test_connection.py +++ b/tests/crdb/test_connection.py @@ -1,4 +1,5 @@ import psycopg.crdb +from psycopg import errors as e from psycopg.crdb import CrdbConnection import pytest @@ -12,8 +13,26 @@ def test_is_crdb(conn): def test_connect(dsn): - with psycopg.crdb.CrdbConnection.connect(dsn) as conn: + with CrdbConnection.connect(dsn) as conn: assert isinstance(conn, CrdbConnection) with psycopg.crdb.connect(dsn) as conn: assert isinstance(conn, CrdbConnection) + + +def test_xid(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.xid(1, "gtrid", "bqual") + + +def test_tpc_begin(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.tpc_begin("foo") + + +def test_tpc_recover(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.tpc_recover() diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py index 3d6da1b40..9d3d16e84 100644 --- a/tests/crdb/test_connection_async.py +++ b/tests/crdb/test_connection_async.py @@ -1,4 +1,5 @@ import psycopg.crdb +from psycopg import errors as e from psycopg.crdb import AsyncCrdbConnection import pytest @@ -12,5 +13,23 @@ async def test_is_crdb(aconn): async def test_connect(dsn): - async with await psycopg.crdb.AsyncCrdbConnection.connect(dsn) as conn: + async with await AsyncCrdbConnection.connect(dsn) as conn: assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection) + + +async def test_xid(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.xid(1, "gtrid", "bqual") + + +async def test_tpc_begin(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + await conn.tpc_begin("foo") + + +async def test_tpc_recover(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + await conn.tpc_recover() diff --git a/tests/crdb/test_no_crdb.py b/tests/crdb/test_no_crdb.py index ac1bc1897..df43f3bd1 100644 --- a/tests/crdb/test_no_crdb.py +++ b/tests/crdb/test_no_crdb.py @@ -1,3 +1,4 @@ +from psycopg.pq import TransactionStatus from psycopg.crdb import CrdbConnection import pytest @@ -8,3 +9,26 @@ pytestmark = pytest.mark.crdb("skip") def test_is_crdb(conn): assert not CrdbConnection.is_crdb(conn) assert not CrdbConnection.is_crdb(conn.pgconn) + + +def test_tpc_on_pg_connection(conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + conn.tpc_commit() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1