From: Daniele Varrazzo Date: Wed, 21 Sep 2022 10:40:25 +0000 (+0100) Subject: fix: make sure to terminate query on gen.close() from Cursor.stream() X-Git-Tag: 3.1.3~5^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0a9969a2dcb87444a54d8a083656731d25578322;p=thirdparty%2Fpsycopg.git fix: make sure to terminate query on gen.close() from Cursor.stream() Fix #382 --- diff --git a/docs/api/cursors.rst b/docs/api/cursors.rst index 704fe1bb3..7a83a634b 100644 --- a/docs/api/cursors.rst +++ b/docs/api/cursors.rst @@ -156,11 +156,23 @@ The `!Cursor` class to receive further commands (with a message such as *another command is already in progress*). - You can restore the connection to a working state by consuming - the generator entirely: see `this comment`__ to get a few ideas - about how to do it. - - .. __: https://github.com/psycopg/psycopg/issues/382#issuecomment-1253582340 + If there is a chance that the generator is not consumed entirely, + in order to restore the connection to a working state you can call + `~generator.close` on the generator object returned by `!stream()`. The + `contextlib.closing` function might be particularly useful to make + sure that `!close()` is called: + + .. code:: + + with closing(cur.stream("select generate_series(1, 10000)")) as gen: + for rec in gen: + something(rec) # might fail + + Without calling `!close()`, in case of error, the connection will + be `!ACTIVE` and unusable. If `!close()` is called, the connection + might be `!INTRANS` or `!INERROR`, depending on whether the server + managed to send the entire resultset to the client. An autocommit + connection will be `!IDLE` instead. .. attribute:: format diff --git a/docs/news.rst b/docs/news.rst index 5db0accd0..430455d39 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -13,6 +13,8 @@ Future releases Psycopg 3.1.3 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^ +- Restore the state of the connection if `Cursor.stream()` is terminated + prematurely (:ticket:`#382`). - Fix regression introduced in 3.1 with different named tuples mangling rules for non-ascii attribute names (:ticket:`#386`). diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 70eebfbba..7fe4b4773 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -781,16 +781,26 @@ class Cursor(BaseCursor["Connection[Any]", Row]): rec: Row = self._tx.load_row(0, self._make_row) # type: ignore yield rec first = False + except e.Error as ex: - # try to get out of ACTIVE state. Just do a single attempt, which - # should work to recover from an error or query cancelled. + raise ex.with_traceback(None) + + finally: if self._pgconn.transaction_status == ACTIVE: + # Try to cancel the query, then consume the results already received. + self._conn.cancel() try: - self._conn.wait(self._stream_fetchone_gen(first)) + while self._conn.wait(self._stream_fetchone_gen(first=False)): + pass except Exception: pass - raise ex.with_traceback(None) + # Try to get out of ACTIVE state. Just do a single attempt, which + # should work to recover from an error or query cancelled. + try: + self._conn.wait(self._stream_fetchone_gen(first=False)) + except Exception: + pass def fetchone(self) -> Optional[Row]: """ diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 8aa7f71d2..4a108175b 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -20,6 +20,8 @@ from ._pipeline import Pipeline if TYPE_CHECKING: from .connection_async import AsyncConnection +ACTIVE = pq.TransactionStatus.ACTIVE + class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): __module__ = "psycopg" @@ -143,16 +145,26 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): rec: Row = self._tx.load_row(0, self._make_row) # type: ignore yield rec first = False + except e.Error as ex: - # try to get out of ACTIVE state. Just do a single attempt, which - # should work to recover from an error or query cancelled. - if self._pgconn.transaction_status == pq.TransactionStatus.ACTIVE: + raise ex.with_traceback(None) + + finally: + if self._pgconn.transaction_status == ACTIVE: + # Try to cancel the query, then consume the results already received. + self._conn.cancel() try: - await self._conn.wait(self._stream_fetchone_gen(first)) + while await self._conn.wait(self._stream_fetchone_gen(first=False)): + pass except Exception: pass - raise ex.with_traceback(None) + # Try to get out of ACTIVE state. Just do a single attempt, which + # should work to recover from an error or query cancelled. + try: + await self._conn.wait(self._stream_fetchone_gen(first=False)) + except Exception: + pass async def fetchone(self) -> Optional[Row]: await self._fetch_pipeline() diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 75af433cb..bc1c5799e 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -3,6 +3,7 @@ import pickle import weakref import datetime as dt from typing import List, Union +from contextlib import closing import pytest @@ -645,6 +646,28 @@ def test_stream_error_notx(conn): assert conn.info.transaction_status == conn.TransactionStatus.IDLE +def test_stream_error_python_to_consume(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with closing(cur.stream("select generate_series(1, 10000)")) as gen: + for rec in gen: + 1 / 0 + assert conn.info.transaction_status in ( + conn.TransactionStatus.INTRANS, + conn.TransactionStatus.INERROR, + ) + + +def test_stream_error_python_consumed(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + gen = cur.stream("select 1") + for rec in gen: + 1 / 0 + gen.close() + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + def test_stream_close(conn): cur = conn.cursor() with pytest.raises(psycopg.OperationalError): diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 741eba3d9..50de79ee8 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -637,6 +637,31 @@ async def test_stream_error_notx(aconn): assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE +async def test_stream_error_python_to_consume(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + gen = cur.stream("select generate_series(1, 10000)") + async for rec in gen: + 1 / 0 + + await gen.aclose() + assert aconn.info.transaction_status in ( + aconn.TransactionStatus.INTRANS, + aconn.TransactionStatus.INERROR, + ) + + +async def test_stream_error_python_consumed(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + gen = cur.stream("select 1") + async for rec in gen: + 1 / 0 + + await gen.aclose() + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + async def test_stream_close(aconn): await aconn.set_autocommit(True) cur = aconn.cursor()