From: Daniele Varrazzo Date: Tue, 7 Jun 2022 07:21:03 +0000 (+0200) Subject: fix: restore the connection in usable state after an error in stream() X-Git-Tag: 3.1~66^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=54325178e6e78efd7c20394a68a108e9d48afee8;p=thirdparty%2Fpsycopg.git fix: restore the connection in usable state after an error in stream() --- diff --git a/docs/news.rst b/docs/news.rst index 2ea5328e2..441960e41 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -35,6 +35,7 @@ Psycopg 3.0.15 (unreleased) - Fix wrong escaping of unprintable chars in COPY (nonetheless correctly interpreted by PostgreSQL). +- Restore the connection to usable state after an error in `~Cursor.stream()`. Current release diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index be391bb27..30c41d8ea 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -56,6 +56,8 @@ FATAL_ERROR = pq.ExecStatus.FATAL_ERROR SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED +ACTIVE = pq.TransactionStatus.ACTIVE + class BaseCursor(Generic[ConnectionType, Row]): __slots__ = """ @@ -786,6 +788,14 @@ class Cursor(BaseCursor["Connection[Any]", Row]): yield rec first = False except e.Error as ex: + # try to get out of ACTIVE state. Just do a single attempt, which + # shoud work to recover from an error or query cancelled. + if self._pgconn.transaction_status == ACTIVE: + try: + self._conn.wait(self._stream_fetchone_gen(first)) + except Exception: + pass + raise ex.with_traceback(None) def fetchone(self) -> Optional[Row]: diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 4598bf434..5044732f8 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -9,6 +9,7 @@ from typing import Any, AsyncIterator, Iterable, List from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload from contextlib import asynccontextmanager +from . import pq from . import errors as e from .abc import Query, Params from .copy import AsyncCopy @@ -144,6 +145,14 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): yield rec first = False except e.Error as ex: + # try to get out of ACTIVE state. Just do a single attempt, which + # shoud work to recover from an error or query cancelled. + if self._pgconn.transaction_status == pq.TransactionStatus.ACTIVE: + try: + await self._conn.wait(self._stream_fetchone_gen(first)) + except Exception: + pass + raise ex.with_traceback(None) async def fetchone(self) -> Optional[Row]: diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 0f2f5e840..59920e70f 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -14,6 +14,7 @@ from typing import List import pytest import psycopg +from psycopg import errors as e @pytest.mark.slow @@ -151,23 +152,24 @@ def test_notifies(conn, dsn): t.join() +def canceller(conn, errors): + try: + time.sleep(0.5) + conn.cancel() + except Exception as exc: + errors.append(exc) + + @pytest.mark.slow def test_cancel(conn): - def canceller(): - try: - time.sleep(0.5) - conn.cancel() - except Exception as exc: - errors.append(exc) - errors: List[Exception] = [] cur = conn.cursor() - t = threading.Thread(target=canceller) + t = threading.Thread(target=canceller, args=(conn, errors)) t0 = time.time() t.start() - with pytest.raises(psycopg.DatabaseError): + with pytest.raises(e.QueryCanceled): cur.execute("select pg_sleep(2)") t1 = time.time() @@ -181,6 +183,30 @@ def test_cancel(conn): t.join() +@pytest.mark.slow +def test_cancel_stream(conn): + errors: List[Exception] = [] + + cur = conn.cursor() + t = threading.Thread(target=canceller, args=(conn, errors)) + t0 = time.time() + t.start() + + with pytest.raises(e.QueryCanceled): + for row in cur.stream("select pg_sleep(2)"): + pass + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + conn.rollback() + assert cur.execute("select 1").fetchone()[0] == 1 + + t.join() + + @pytest.mark.slow def test_identify_closure(dsn): def closer(): diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py index 3d1b25963..5008918a2 100644 --- a/tests/test_concurrency_async.py +++ b/tests/test_concurrency_async.py @@ -9,6 +9,7 @@ from typing import List, Tuple import pytest import psycopg +from psycopg import errors as e from psycopg._compat import create_task pytestmark = pytest.mark.asyncio @@ -101,22 +102,48 @@ async def test_notifies(aconn, dsn): assert t1 - t0 == pytest.approx(0.5, abs=0.05) +async def canceller(aconn, errors): + try: + await asyncio.sleep(0.5) + aconn.cancel() + except Exception as exc: + errors.append(exc) + + @pytest.mark.slow async def test_cancel(aconn): - async def canceller(): - try: - await asyncio.sleep(0.5) - aconn.cancel() - except Exception as exc: - errors.append(exc) - async def worker(): cur = aconn.cursor() - with pytest.raises(psycopg.DatabaseError): + with pytest.raises(e.QueryCanceled): await cur.execute("select pg_sleep(2)") errors: List[Exception] = [] - workers = [worker(), canceller()] + workers = [worker(), canceller(aconn, errors)] + + t0 = time.time() + await asyncio.gather(*workers) + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + await aconn.rollback() + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + + +@pytest.mark.slow +async def test_cancel_stream(aconn): + async def worker(): + cur = aconn.cursor() + with pytest.raises(e.QueryCanceled): + async for row in cur.stream("select pg_sleep(2)"): + pass + + errors: List[Exception] = [] + workers = [worker(), canceller(aconn, errors)] t0 = time.time() await asyncio.gather(*workers) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 534f6066e..8e7c78500 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -622,6 +622,23 @@ def test_stream_badquery(conn, query): pass +def test_stream_error_tx(conn): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + for rec in cur.stream("wat"): + pass + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_stream_error_notx(conn): + conn.autocommit = True + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + for rec in cur.stream("wat"): + pass + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + + def test_stream_binary_cursor(conn): cur = conn.cursor(binary=True) recs = [] diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 3af0e5775..3fcbea0ee 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -614,6 +614,23 @@ async def test_stream_badquery(aconn, query): pass +async def test_stream_error_tx(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + async for rec in cur.stream("wat"): + pass + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_stream_error_notx(aconn): + await aconn.set_autocommit(True) + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + async for rec in cur.stream("wat"): + pass + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + async def test_stream_binary_cursor(aconn): cur = aconn.cursor(binary=True) recs = []