]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: restore the connection in usable state after an error in stream()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Jun 2022 07:21:03 +0000 (09:21 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Jun 2022 07:32:19 +0000 (09:32 +0200)
docs/news.rst
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tests/test_concurrency.py
tests/test_concurrency_async.py
tests/test_cursor.py
tests/test_cursor_async.py

index 51b78a48dbdcfd21ace97cffb94bcd9d472eb1f7..56cf42e24a6aae1c922f444c4ab967dfe7d087d2 100644 (file)
@@ -15,6 +15,7 @@ Psycopg 3.0.15 (unreleased)
 
 - Fix wrong escaping of unprintable chars in COPY (nonetheless correctly
   interpreted by PostgreSQL).
+- Restore the connection to usable state after an error in `~Cursor.stream()`.
 
 
 Current release
index c64eab2bd4441006d26a65e2a02b5f8f45559ddd..bd83801291c180c87f5e423d7f2523df81d6a51e 100644 (file)
@@ -43,6 +43,8 @@ else:
 
 _C = TypeVar("_C", bound="Cursor[Any]")
 
+ACTIVE = pq.TransactionStatus.ACTIVE
+
 
 class BaseCursor(Generic[ConnectionType, Row]):
     # Slots with __weakref__ and generic bases don't work on Py 3.6
@@ -585,6 +587,14 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                     yield rec
                     first = False
         except e.Error as ex:
+            # try to get out of ACTIVE state. Just do a single attempt, which
+            # shoud work to recover from an error or query cancelled.
+            if self._pgconn.transaction_status == ACTIVE:
+                try:
+                    self._conn.wait(self._stream_fetchone_gen(first))
+                except Exception:
+                    pass
+
             raise ex.with_traceback(None)
 
     def fetchone(self) -> Optional[Row]:
index 0b665b43534b34a708b294fe2067dac6e26b6771..eab8ce4a0c994cdd374c638422e69baa7664f0dd 100644 (file)
@@ -8,6 +8,7 @@ from types import TracebackType
 from typing import Any, AsyncIterator, Iterable, List
 from typing import Optional, Type, TypeVar, TYPE_CHECKING
 
+from . import pq
 from . import errors as e
 
 from .abc import Query, Params
@@ -105,6 +106,14 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                     yield rec
                     first = False
         except e.Error as ex:
+            # try to get out of ACTIVE state. Just do a single attempt, which
+            # shoud work to recover from an error or query cancelled.
+            if self._pgconn.transaction_status == pq.TransactionStatus.ACTIVE:
+                try:
+                    await self._conn.wait(self._stream_fetchone_gen(first))
+                except Exception:
+                    pass
+
             raise ex.with_traceback(None)
 
     async def fetchone(self) -> Optional[Row]:
index d09cc20db8b3fc1247921d3099122d7cb8169da3..0d9bc972f927144cb40a592d7ed4456bf84da51e 100644 (file)
@@ -14,6 +14,7 @@ from typing import List
 import pytest
 
 import psycopg
+from psycopg import errors as e
 
 
 @pytest.mark.slow
@@ -151,23 +152,24 @@ def test_notifies(conn, dsn):
     t.join()
 
 
+def canceller(conn, errors):
+    try:
+        time.sleep(0.5)
+        conn.cancel()
+    except Exception as exc:
+        errors.append(exc)
+
+
 @pytest.mark.slow
 def test_cancel(conn):
-    def canceller():
-        try:
-            time.sleep(0.5)
-            conn.cancel()
-        except Exception as exc:
-            errors.append(exc)
-
     errors: List[Exception] = []
 
     cur = conn.cursor()
-    t = threading.Thread(target=canceller)
+    t = threading.Thread(target=canceller, args=(conn, errors))
     t0 = time.time()
     t.start()
 
-    with pytest.raises(psycopg.DatabaseError):
+    with pytest.raises(e.QueryCanceled):
         cur.execute("select pg_sleep(2)")
 
     t1 = time.time()
@@ -181,6 +183,30 @@ def test_cancel(conn):
     t.join()
 
 
+@pytest.mark.slow
+def test_cancel_stream(conn):
+    errors: List[Exception] = []
+
+    cur = conn.cursor()
+    t = threading.Thread(target=canceller, args=(conn, errors))
+    t0 = time.time()
+    t.start()
+
+    with pytest.raises(e.QueryCanceled):
+        for row in cur.stream("select pg_sleep(2)"):
+            pass
+
+    t1 = time.time()
+    assert not errors
+    assert 0.0 < t1 - t0 < 1.0
+
+    # still working
+    conn.rollback()
+    assert cur.execute("select 1").fetchone()[0] == 1
+
+    t.join()
+
+
 @pytest.mark.slow
 def test_identify_closure(dsn):
     def closer():
index d98808f0e3436f6a0d29dbbeab73a03f646fec96..3d3e326464652fc37e8c9d396a1d2789f45f7536 100644 (file)
@@ -9,6 +9,7 @@ from typing import List, Tuple
 import pytest
 
 import psycopg
+from psycopg import errors as e
 from psycopg._compat import create_task
 
 pytestmark = pytest.mark.asyncio
@@ -101,22 +102,48 @@ async def test_notifies(aconn, dsn):
     assert t1 - t0 == pytest.approx(0.5, abs=0.05)
 
 
+async def canceller(aconn, errors):
+    try:
+        await asyncio.sleep(0.5)
+        aconn.cancel()
+    except Exception as exc:
+        errors.append(exc)
+
+
 @pytest.mark.slow
 async def test_cancel(aconn):
-    async def canceller():
-        try:
-            await asyncio.sleep(0.5)
-            aconn.cancel()
-        except Exception as exc:
-            errors.append(exc)
-
     async def worker():
         cur = aconn.cursor()
-        with pytest.raises(psycopg.DatabaseError):
+        with pytest.raises(e.QueryCanceled):
             await cur.execute("select pg_sleep(2)")
 
     errors: List[Exception] = []
-    workers = [worker(), canceller()]
+    workers = [worker(), canceller(aconn, errors)]
+
+    t0 = time.time()
+    await asyncio.gather(*workers)
+
+    t1 = time.time()
+    assert not errors
+    assert 0.0 < t1 - t0 < 1.0
+
+    # still working
+    await aconn.rollback()
+    cur = aconn.cursor()
+    await cur.execute("select 1")
+    assert await cur.fetchone() == (1,)
+
+
+@pytest.mark.slow
+async def test_cancel_stream(aconn):
+    async def worker():
+        cur = aconn.cursor()
+        with pytest.raises(e.QueryCanceled):
+            async for row in cur.stream("select pg_sleep(2)"):
+                pass
+
+    errors: List[Exception] = []
+    workers = [worker(), canceller(aconn, errors)]
 
     t0 = time.time()
     await asyncio.gather(*workers)
index 21114100aa0035a77532b8c5f6d123920afd602d..9f02710c9a11b26feb5fe0831a4bfcf204b1fb87 100644 (file)
@@ -570,6 +570,23 @@ def test_stream_badquery(conn, query):
             pass
 
 
+def test_stream_error_tx(conn):
+    cur = conn.cursor()
+    with pytest.raises(psycopg.ProgrammingError):
+        for rec in cur.stream("wat"):
+            pass
+    assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_stream_error_notx(conn):
+    conn.autocommit = True
+    cur = conn.cursor()
+    with pytest.raises(psycopg.ProgrammingError):
+        for rec in cur.stream("wat"):
+            pass
+    assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+
 def test_stream_binary_cursor(conn):
     cur = conn.cursor(binary=True)
     recs = []
index 08b09ebee4e2d48c286bea92851054dc10c8b4f7..29c1a583a7409538ed81cc32719762d29a97e6e3 100644 (file)
@@ -562,6 +562,23 @@ async def test_stream_badquery(aconn, query):
             pass
 
 
+async def test_stream_error_tx(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(psycopg.ProgrammingError):
+        async for rec in cur.stream("wat"):
+            pass
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_stream_error_notx(aconn):
+    await aconn.set_autocommit(True)
+    cur = aconn.cursor()
+    with pytest.raises(psycopg.ProgrammingError):
+        async for rec in cur.stream("wat"):
+            pass
+    assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE
+
+
 async def test_stream_binary_cursor(aconn):
     cur = aconn.cursor(binary=True)
     recs = []