]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(crdb): raise NotSupportedError on two-phase commit methods
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Jun 2022 06:01:46 +0000 (08:01 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/crdb/connection.py
tests/crdb/test_connection.py
tests/crdb/test_connection_async.py
tests/crdb/test_no_crdb.py

index abd7149f4ad7e1a4f44445a0bb5e3901142efea8..ee05e8e50fccf3e7b4caf8684be92b5dfec5dc89 100644 (file)
@@ -585,9 +585,12 @@ class BaseConnection(Generic[Row]):
         The values passed to the method will be available on the returned
         object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`.
         """
+        self._check_tpc()
         return Xid.from_parts(format_id, gtrid, bqual)
 
     def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
+        self._check_tpc()
+
         if not isinstance(xid, Xid):
             xid = Xid.from_string(xid)
 
@@ -651,6 +654,11 @@ class BaseConnection(Generic[Row]):
             )
             self._tpc = None
 
+    def _check_tpc(self) -> None:
+        """Raise NotSupportedError if TPC is not supported."""
+        # TPC supported on every supported PostgreSQL version.
+        pass
+
 
 class Connection(BaseConnection[Row]):
     """
@@ -1024,6 +1032,7 @@ class Connection(BaseConnection[Row]):
             self.wait(self._tpc_finish_gen("ROLLBACK", xid))
 
     def tpc_recover(self) -> List[Xid]:
+        self._check_tpc()
         status = self.info.transaction_status
         with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
             cur.execute(Xid._get_recover_query())
index 1545667b74a0697c64cac1f721ef884942a5e5da..89f3b58028b90957718f26a9c6c89f0dbea40d2c 100644 (file)
@@ -424,6 +424,7 @@ class AsyncConnection(BaseConnection[Row]):
             await self.wait(self._tpc_finish_gen("rollback", xid))
 
     async def tpc_recover(self) -> List[Xid]:
+        self._check_tpc()
         status = self.info.transaction_status
         async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
             await cur.execute(Xid._get_recover_query())
index 998ba4720ef5a03fc2b50e392d9ff552859170f7..fa5936e3a8a920334a9df9463e7ed9a851d3907d 100644 (file)
@@ -51,6 +51,10 @@ class _CrdbConnectionMixin:
     def info(self) -> "CrdbConnectionInfo":
         return CrdbConnectionInfo(self.pgconn)
 
+    def _check_tpc(self) -> None:
+        if self.is_crdb(self.pgconn):
+            raise e.NotSupportedError("CockroachDB doesn't support prepared statements")
+
 
 class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
     """
index fb57c6e6a09142bab25cc82dfa687356703d54f4..724728f69d320d9f18179698b4251a872a6b25b4 100644 (file)
@@ -1,4 +1,5 @@
 import psycopg.crdb
+from psycopg import errors as e
 from psycopg.crdb import CrdbConnection
 
 import pytest
@@ -12,8 +13,26 @@ def test_is_crdb(conn):
 
 
 def test_connect(dsn):
-    with psycopg.crdb.CrdbConnection.connect(dsn) as conn:
+    with CrdbConnection.connect(dsn) as conn:
         assert isinstance(conn, CrdbConnection)
 
     with psycopg.crdb.connect(dsn) as conn:
         assert isinstance(conn, CrdbConnection)
+
+
+def test_xid(dsn):
+    with CrdbConnection.connect(dsn) as conn:
+        with pytest.raises(e.NotSupportedError):
+            conn.xid(1, "gtrid", "bqual")
+
+
+def test_tpc_begin(dsn):
+    with CrdbConnection.connect(dsn) as conn:
+        with pytest.raises(e.NotSupportedError):
+            conn.tpc_begin("foo")
+
+
+def test_tpc_recover(dsn):
+    with CrdbConnection.connect(dsn) as conn:
+        with pytest.raises(e.NotSupportedError):
+            conn.tpc_recover()
index 3d6da1b407577b6dee98e6932b83bc15451f15a2..9d3d16e84da0554c843c657230816f287de32106 100644 (file)
@@ -1,4 +1,5 @@
 import psycopg.crdb
+from psycopg import errors as e
 from psycopg.crdb import AsyncCrdbConnection
 
 import pytest
@@ -12,5 +13,23 @@ async def test_is_crdb(aconn):
 
 
 async def test_connect(dsn):
-    async with await psycopg.crdb.AsyncCrdbConnection.connect(dsn) as conn:
+    async with await AsyncCrdbConnection.connect(dsn) as conn:
         assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection)
+
+
+async def test_xid(dsn):
+    async with await AsyncCrdbConnection.connect(dsn) as conn:
+        with pytest.raises(e.NotSupportedError):
+            conn.xid(1, "gtrid", "bqual")
+
+
+async def test_tpc_begin(dsn):
+    async with await AsyncCrdbConnection.connect(dsn) as conn:
+        with pytest.raises(e.NotSupportedError):
+            await conn.tpc_begin("foo")
+
+
+async def test_tpc_recover(dsn):
+    async with await AsyncCrdbConnection.connect(dsn) as conn:
+        with pytest.raises(e.NotSupportedError):
+            await conn.tpc_recover()
index ac1bc18970b2946b1e40ff4393a634ffa6da3048..df43f3bd1fa3b2e7c5728bb094ff5fed9c36cd72 100644 (file)
@@ -1,3 +1,4 @@
+from psycopg.pq import TransactionStatus
 from psycopg.crdb import CrdbConnection
 
 import pytest
@@ -8,3 +9,26 @@ pytestmark = pytest.mark.crdb("skip")
 def test_is_crdb(conn):
     assert not CrdbConnection.is_crdb(conn)
     assert not CrdbConnection.is_crdb(conn.pgconn)
+
+
+def test_tpc_on_pg_connection(conn, tpc):
+    xid = conn.xid(1, "gtrid", "bqual")
+    assert conn.info.transaction_status == TransactionStatus.IDLE
+
+    conn.tpc_begin(xid)
+    assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+    cur = conn.cursor()
+    cur.execute("insert into test_tpc values ('test_tpc_commit')")
+    assert tpc.count_xacts() == 0
+    assert tpc.count_test_records() == 0
+
+    conn.tpc_prepare()
+    assert conn.info.transaction_status == TransactionStatus.IDLE
+    assert tpc.count_xacts() == 1
+    assert tpc.count_test_records() == 0
+
+    conn.tpc_commit()
+    assert conn.info.transaction_status == TransactionStatus.IDLE
+    assert tpc.count_xacts() == 0
+    assert tpc.count_test_records() == 1