]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: make sure to terminate query on gen.close() from Cursor.stream()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 21 Sep 2022 10:40:25 +0000 (11:40 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Sep 2022 02:48:41 +0000 (03:48 +0100)
Fix #382

docs/api/cursors.rst
docs/news.rst
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tests/test_cursor.py
tests/test_cursor_async.py

index 704fe1bb39546ee58eb0b7f961da81de756a9e38..7a83a634bec37c6878c797951437d510a381e583 100644 (file)
@@ -156,11 +156,23 @@ The `!Cursor` class
             to receive further commands (with a message such as *another
             command is already in progress*).
 
-            You can restore the connection to a working state by consuming
-            the generator entirely: see `this comment`__ to get a few ideas
-            about how to do it.
-
-            .. __: https://github.com/psycopg/psycopg/issues/382#issuecomment-1253582340
+            If there is a chance that the generator is not consumed entirely,
+            in order to restore the connection to a working state you can call
+            `~generator.close` on the generator object returned by `!stream()`. The
+            `contextlib.closing` function might be particularly useful to make
+            sure that `!close()` is called:
+
+            .. code::
+
+                with closing(cur.stream("select generate_series(1, 10000)")) as gen:
+                    for rec in gen:
+                        something(rec)  # might fail
+
+            Without calling `!close()`, in case of error, the connection will
+            be `!ACTIVE` and unusable. If `!close()` is called, the connection
+            might be `!INTRANS` or `!INERROR`, depending on whether the server
+            managed to send the entire resultset to the client. An autocommit
+            connection will be `!IDLE` instead.
 
 
     .. attribute:: format
index 5db0accd0cbecaaf908680db9b75ee967fa803d5..430455d3924c8c95ab45c248d737332ca5b7481d 100644 (file)
@@ -13,6 +13,8 @@ Future releases
 Psycopg 3.1.3 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+- Restore the state of the connection if `Cursor.stream()` is terminated
+  prematurely (:ticket:`#382`).
 - Fix regression introduced in 3.1 with different named tuples mangling rules
   for non-ascii attribute names (:ticket:`#386`).
 
index 70eebfbbac78b65a236290fb4db62fef0e1aa82a..7fe4b4773f606f834f6700c2ddba06b364740354 100644 (file)
@@ -781,16 +781,26 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                     rec: Row = self._tx.load_row(0, self._make_row)  # type: ignore
                     yield rec
                     first = False
+
         except e.Error as ex:
-            # try to get out of ACTIVE state. Just do a single attempt, which
-            # should work to recover from an error or query cancelled.
+            raise ex.with_traceback(None)
+
+        finally:
             if self._pgconn.transaction_status == ACTIVE:
+                # Try to cancel the query, then consume the results already received.
+                self._conn.cancel()
                 try:
-                    self._conn.wait(self._stream_fetchone_gen(first))
+                    while self._conn.wait(self._stream_fetchone_gen(first=False)):
+                        pass
                 except Exception:
                     pass
 
-            raise ex.with_traceback(None)
+                # Try to get out of ACTIVE state. Just do a single attempt, which
+                # should work to recover from an error or query cancelled.
+                try:
+                    self._conn.wait(self._stream_fetchone_gen(first=False))
+                except Exception:
+                    pass
 
     def fetchone(self) -> Optional[Row]:
         """
index 8aa7f71d2d3a9d460115362a6a8ac944826f2e78..4a108175b09d3b6a711fdcadbf325e1834f9618e 100644 (file)
@@ -20,6 +20,8 @@ from ._pipeline import Pipeline
 if TYPE_CHECKING:
     from .connection_async import AsyncConnection
 
+ACTIVE = pq.TransactionStatus.ACTIVE
+
 
 class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __module__ = "psycopg"
@@ -143,16 +145,26 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                     rec: Row = self._tx.load_row(0, self._make_row)  # type: ignore
                     yield rec
                     first = False
+
         except e.Error as ex:
-            # try to get out of ACTIVE state. Just do a single attempt, which
-            # should work to recover from an error or query cancelled.
-            if self._pgconn.transaction_status == pq.TransactionStatus.ACTIVE:
+            raise ex.with_traceback(None)
+
+        finally:
+            if self._pgconn.transaction_status == ACTIVE:
+                # Try to cancel the query, then consume the results already received.
+                self._conn.cancel()
                 try:
-                    await self._conn.wait(self._stream_fetchone_gen(first))
+                    while await self._conn.wait(self._stream_fetchone_gen(first=False)):
+                        pass
                 except Exception:
                     pass
 
-            raise ex.with_traceback(None)
+                # Try to get out of ACTIVE state. Just do a single attempt, which
+                # should work to recover from an error or query cancelled.
+                try:
+                    await self._conn.wait(self._stream_fetchone_gen(first=False))
+                except Exception:
+                    pass
 
     async def fetchone(self) -> Optional[Row]:
         await self._fetch_pipeline()
index 75af433cbe9ad9191a796e1685b140970c0bdf6b..bc1c5799ece7d62b76d83e3c971f3b57b2a25ca9 100644 (file)
@@ -3,6 +3,7 @@ import pickle
 import weakref
 import datetime as dt
 from typing import List, Union
+from contextlib import closing
 
 import pytest
 
@@ -645,6 +646,28 @@ def test_stream_error_notx(conn):
     assert conn.info.transaction_status == conn.TransactionStatus.IDLE
 
 
+def test_stream_error_python_to_consume(conn):
+    cur = conn.cursor()
+    with pytest.raises(ZeroDivisionError):
+        with closing(cur.stream("select generate_series(1, 10000)")) as gen:
+            for rec in gen:
+                1 / 0
+    assert conn.info.transaction_status in (
+        conn.TransactionStatus.INTRANS,
+        conn.TransactionStatus.INERROR,
+    )
+
+
+def test_stream_error_python_consumed(conn):
+    cur = conn.cursor()
+    with pytest.raises(ZeroDivisionError):
+        gen = cur.stream("select 1")
+        for rec in gen:
+            1 / 0
+    gen.close()
+    assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
 def test_stream_close(conn):
     cur = conn.cursor()
     with pytest.raises(psycopg.OperationalError):
index 741eba3d96cf5a911c7605a664ba081b9dfb034c..50de79ee8bdb490c9bc277e92ff63d3c6263dbb8 100644 (file)
@@ -637,6 +637,31 @@ async def test_stream_error_notx(aconn):
     assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE
 
 
+async def test_stream_error_python_to_consume(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(ZeroDivisionError):
+        gen = cur.stream("select generate_series(1, 10000)")
+        async for rec in gen:
+            1 / 0
+
+    await gen.aclose()
+    assert aconn.info.transaction_status in (
+        aconn.TransactionStatus.INTRANS,
+        aconn.TransactionStatus.INERROR,
+    )
+
+
+async def test_stream_error_python_consumed(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(ZeroDivisionError):
+        gen = cur.stream("select 1")
+        async for rec in gen:
+            1 / 0
+
+    await gen.aclose()
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
 async def test_stream_close(aconn):
     await aconn.set_autocommit(True)
     cur = aconn.cursor()