]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: sync pipeline state on rollback 267/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 4 Apr 2022 01:41:49 +0000 (03:41 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 7 May 2022 13:41:25 +0000 (15:41 +0200)
psycopg/psycopg/connection.py
psycopg/psycopg/transaction.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 65bc2f658ef2d68aa645a4fdb2d7b4ee6e616e8b..fc92c0512f3a2cf31da2a03ea2c7031a03554995 100644 (file)
@@ -531,6 +531,11 @@ class BaseConnection(Generic[Row]):
             raise e.ProgrammingError(
                 "rollback() cannot be used during a two-phase transaction"
             )
+
+        # Get out of a "pipeline aborted" state
+        if self._pipeline and self.pgconn.pipeline_status == pq.PipelineStatus.ABORTED:
+            yield from self._pipeline._sync_gen()
+
         if self.pgconn.transaction_status == TransactionStatus.IDLE:
             return
 
@@ -539,6 +544,9 @@ class BaseConnection(Generic[Row]):
         for cmd in self._prepared.get_maintenance_commands():
             yield from self._exec_command(cmd)
 
+        if self._pipeline:
+            yield from self._pipeline._sync_gen()
+
     def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
         """
         Returns a `Xid` to pass to the `!tpc_*()` methods of this connection.
index b8e36867e1cc0f9eca2541c150e7cb041eb67478..35ef973d861093cdd540e0e1a90de2fe8924f1db 100644 (file)
@@ -56,6 +56,7 @@ class BaseTransaction(Generic[ConnectionType]):
         force_rollback: bool = False,
     ):
         self._conn = connection
+        self.pgconn = self._conn.pgconn
         self._savepoint_name = savepoint_name or ""
         self.force_rollback = force_rollback
         self._entered = self._exited = False
@@ -73,7 +74,7 @@ class BaseTransaction(Generic[ConnectionType]):
 
     def __repr__(self) -> str:
         cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
-        info = pq.misc.connection_summary(self._conn.pgconn)
+        info = pq.misc.connection_summary(self.pgconn)
         if not self._entered:
             status = "inactive"
         elif not self._exited:
@@ -136,6 +137,13 @@ class BaseTransaction(Generic[ConnectionType]):
         if ex:
             raise ex
 
+        # Get out of a "pipeline aborted" state
+        if (
+            self._conn._pipeline
+            and self.pgconn.pipeline_status == pq.PipelineStatus.ABORTED
+        ):
+            yield from self._conn._pipeline._sync_gen()
+
         for command in self._get_rollback_commands():
             yield from self._conn._exec_command(command)
 
@@ -196,7 +204,7 @@ class BaseTransaction(Generic[ConnectionType]):
         Also set the internal state of the object and verify consistency.
         """
         self._outer_transaction = (
-            self._conn.pgconn.transaction_status == TransactionStatus.IDLE
+            self.pgconn.transaction_status == TransactionStatus.IDLE
         )
         if self._outer_transaction:
             # outer transaction: if no name it's only a begin, else
@@ -248,7 +256,7 @@ class Transaction(BaseTransaction["Connection[Any]"]):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> bool:
-        if self._conn.pgconn.status == ConnStatus.OK:
+        if self.pgconn.status == ConnStatus.OK:
             with self._conn.lock:
                 return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
         else:
@@ -277,7 +285,7 @@ class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> bool:
-        if self._conn.pgconn.status == ConnStatus.OK:
+        if self.pgconn.status == ConnStatus.OK:
             async with self._conn.lock:
                 return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
         else:
index aba1f9b048ef3dd36e49cda0fac35cfbecac9949..d5ef4185ecbafe9ff918dd78d952fd3dd3679242 100644 (file)
@@ -360,6 +360,26 @@ def test_outer_transaction_error(conn):
                 conn.execute("create table voila ()")
 
 
+def test_rollback_explicit(conn):
+    conn.autocommit = True
+    with conn.pipeline():
+        with pytest.raises(e.DivisionByZero):
+            cur = conn.execute("select 1 / %s", [0])
+            cur.fetchone()
+        conn.rollback()
+        conn.execute("select 1")
+
+
+def test_rollback_transaction(conn):
+    conn.autocommit = True
+    with pytest.raises(e.DivisionByZero):
+        with conn.pipeline():
+            with conn.transaction():
+                cur = conn.execute("select 1 / %s", [0])
+                cur.fetchone()
+    conn.execute("select 1")
+
+
 def test_concurrency(conn):
     with conn.transaction():
         conn.execute("drop table if exists pipeline_concurrency")
index 269af39e675cda4572a1d2c3293374fe26eeceed..4207348248c0156da1a679da2641d91adbe690c7 100644 (file)
@@ -364,6 +364,26 @@ async def test_outer_transaction_error(aconn):
                 await aconn.execute("create table voila ()")
 
 
+async def test_rollback_explicit(aconn):
+    await aconn.set_autocommit(True)
+    async with aconn.pipeline():
+        with pytest.raises(e.DivisionByZero):
+            cur = await aconn.execute("select 1 / %s", [0])
+            await cur.fetchone()
+        await aconn.rollback()
+        await aconn.execute("select 1")
+
+
+async def test_rollback_transaction(aconn):
+    await aconn.set_autocommit(True)
+    with pytest.raises(e.DivisionByZero):
+        async with aconn.pipeline():
+            async with aconn.transaction():
+                cur = await aconn.execute("select 1 / %s", [0])
+                await cur.fetchone()
+    await aconn.execute("select 1")
+
+
 async def test_concurrency(aconn):
     async with aconn.transaction():
         await aconn.execute("drop table if exists pipeline_concurrency")