]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: restore the connection in usable state after an error in stream()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Jun 2022 07:21:03 +0000 (09:21 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Jun 2022 08:46:27 +0000 (10:46 +0200)
docs/news.rst
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tests/test_concurrency.py
tests/test_concurrency_async.py
tests/test_cursor.py
tests/test_cursor_async.py

index 2ea5328e292f7a0ca5899c02a6f3248a06ae9558..441960e41d4cfc00315058ebdcd828558175300c 100644 (file)
@@ -35,6 +35,7 @@ Psycopg 3.0.15 (unreleased)
 
 - Fix wrong escaping of unprintable chars in COPY (nonetheless correctly
   interpreted by PostgreSQL).
+- Restore the connection to usable state after an error in `~Cursor.stream()`.
 
 
 Current release
index be391bb2778bfe8adca68a46ac8c025ebf6ef94e..30c41d8ea31668e79a0389ca3b6ac60b8b29a42a 100644 (file)
@@ -56,6 +56,8 @@ FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
 SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
 PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
 
+ACTIVE = pq.TransactionStatus.ACTIVE
+
 
 class BaseCursor(Generic[ConnectionType, Row]):
     __slots__ = """
@@ -786,6 +788,14 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                     yield rec
                     first = False
         except e.Error as ex:
+            # try to get out of ACTIVE state. Just do a single attempt, which
+            # shoud work to recover from an error or query cancelled.
+            if self._pgconn.transaction_status == ACTIVE:
+                try:
+                    self._conn.wait(self._stream_fetchone_gen(first))
+                except Exception:
+                    pass
+
             raise ex.with_traceback(None)
 
     def fetchone(self) -> Optional[Row]:
index 4598bf434fa7aabccdab825580da1b4dee7ed7d5..5044732f8219cb3fabfc25c3c41ab71e2d114a9f 100644 (file)
@@ -9,6 +9,7 @@ from typing import Any, AsyncIterator, Iterable, List
 from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload
 from contextlib import asynccontextmanager
 
+from . import pq
 from . import errors as e
 from .abc import Query, Params
 from .copy import AsyncCopy
@@ -144,6 +145,14 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                     yield rec
                     first = False
         except e.Error as ex:
+            # try to get out of ACTIVE state. Just do a single attempt, which
+            # shoud work to recover from an error or query cancelled.
+            if self._pgconn.transaction_status == pq.TransactionStatus.ACTIVE:
+                try:
+                    await self._conn.wait(self._stream_fetchone_gen(first))
+                except Exception:
+                    pass
+
             raise ex.with_traceback(None)
 
     async def fetchone(self) -> Optional[Row]:
index 0f2f5e840e9d2e2150f982a05d9a0c46d88b7b41..59920e70f5ba4f4dc1d348f0163a52c6854d5fac 100644 (file)
@@ -14,6 +14,7 @@ from typing import List
 import pytest
 
 import psycopg
+from psycopg import errors as e
 
 
 @pytest.mark.slow
@@ -151,23 +152,24 @@ def test_notifies(conn, dsn):
     t.join()
 
 
+def canceller(conn, errors):
+    try:
+        time.sleep(0.5)
+        conn.cancel()
+    except Exception as exc:
+        errors.append(exc)
+
+
 @pytest.mark.slow
 def test_cancel(conn):
-    def canceller():
-        try:
-            time.sleep(0.5)
-            conn.cancel()
-        except Exception as exc:
-            errors.append(exc)
-
     errors: List[Exception] = []
 
     cur = conn.cursor()
-    t = threading.Thread(target=canceller)
+    t = threading.Thread(target=canceller, args=(conn, errors))
     t0 = time.time()
     t.start()
 
-    with pytest.raises(psycopg.DatabaseError):
+    with pytest.raises(e.QueryCanceled):
         cur.execute("select pg_sleep(2)")
 
     t1 = time.time()
@@ -181,6 +183,30 @@ def test_cancel(conn):
     t.join()
 
 
+@pytest.mark.slow
+def test_cancel_stream(conn):
+    errors: List[Exception] = []
+
+    cur = conn.cursor()
+    t = threading.Thread(target=canceller, args=(conn, errors))
+    t0 = time.time()
+    t.start()
+
+    with pytest.raises(e.QueryCanceled):
+        for row in cur.stream("select pg_sleep(2)"):
+            pass
+
+    t1 = time.time()
+    assert not errors
+    assert 0.0 < t1 - t0 < 1.0
+
+    # still working
+    conn.rollback()
+    assert cur.execute("select 1").fetchone()[0] == 1
+
+    t.join()
+
+
 @pytest.mark.slow
 def test_identify_closure(dsn):
     def closer():
index 3d1b259630a762cb6f209f888332acf4c9fc95cd..5008918a20fd90725a53bf98f3838fe89c99b4a7 100644 (file)
@@ -9,6 +9,7 @@ from typing import List, Tuple
 import pytest
 
 import psycopg
+from psycopg import errors as e
 from psycopg._compat import create_task
 
 pytestmark = pytest.mark.asyncio
@@ -101,22 +102,48 @@ async def test_notifies(aconn, dsn):
     assert t1 - t0 == pytest.approx(0.5, abs=0.05)
 
 
+async def canceller(aconn, errors):
+    try:
+        await asyncio.sleep(0.5)
+        aconn.cancel()
+    except Exception as exc:
+        errors.append(exc)
+
+
 @pytest.mark.slow
 async def test_cancel(aconn):
-    async def canceller():
-        try:
-            await asyncio.sleep(0.5)
-            aconn.cancel()
-        except Exception as exc:
-            errors.append(exc)
-
     async def worker():
         cur = aconn.cursor()
-        with pytest.raises(psycopg.DatabaseError):
+        with pytest.raises(e.QueryCanceled):
             await cur.execute("select pg_sleep(2)")
 
     errors: List[Exception] = []
-    workers = [worker(), canceller()]
+    workers = [worker(), canceller(aconn, errors)]
+
+    t0 = time.time()
+    await asyncio.gather(*workers)
+
+    t1 = time.time()
+    assert not errors
+    assert 0.0 < t1 - t0 < 1.0
+
+    # still working
+    await aconn.rollback()
+    cur = aconn.cursor()
+    await cur.execute("select 1")
+    assert await cur.fetchone() == (1,)
+
+
+@pytest.mark.slow
+async def test_cancel_stream(aconn):
+    async def worker():
+        cur = aconn.cursor()
+        with pytest.raises(e.QueryCanceled):
+            async for row in cur.stream("select pg_sleep(2)"):
+                pass
+
+    errors: List[Exception] = []
+    workers = [worker(), canceller(aconn, errors)]
 
     t0 = time.time()
     await asyncio.gather(*workers)
index 534f6066e0e1e441d0e41ae53ca8b0e5ade31716..8e7c785008fc946207f5ac92ed1d325d37f50d9e 100644 (file)
@@ -622,6 +622,23 @@ def test_stream_badquery(conn, query):
             pass
 
 
+def test_stream_error_tx(conn):
+    cur = conn.cursor()
+    with pytest.raises(psycopg.ProgrammingError):
+        for rec in cur.stream("wat"):
+            pass
+    assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_stream_error_notx(conn):
+    conn.autocommit = True
+    cur = conn.cursor()
+    with pytest.raises(psycopg.ProgrammingError):
+        for rec in cur.stream("wat"):
+            pass
+    assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+
 def test_stream_binary_cursor(conn):
     cur = conn.cursor(binary=True)
     recs = []
index 3af0e5775ba22804d955d3c051e0471e1725e7a5..3fcbea0ee7154ec19b2ef910d4aaf94fb46ebe41 100644 (file)
@@ -614,6 +614,23 @@ async def test_stream_badquery(aconn, query):
             pass
 
 
+async def test_stream_error_tx(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(psycopg.ProgrammingError):
+        async for rec in cur.stream("wat"):
+            pass
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_stream_error_notx(aconn):
+    await aconn.set_autocommit(True)
+    cur = aconn.cursor()
+    with pytest.raises(psycopg.ProgrammingError):
+        async for rec in cur.stream("wat"):
+            pass
+    assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE
+
+
 async def test_stream_binary_cursor(aconn):
     cur = aconn.cursor(binary=True)
     recs = []