]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix server-side cursor close() with committed transactions.
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 22 Jul 2021 15:43:39 +0000 (17:43 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 23 Jul 2021 14:39:05 +0000 (16:39 +0200)
psycopg/psycopg/server_cursor.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py

index 2bf703eaad0cf428428f6592040e27506457dc20..2739c0878b6ff1304c8dc903daa4ddef9b171ead 100644 (file)
@@ -90,11 +90,14 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         self.described = True
 
     def _close_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
+        ts = cur._conn.pgconn.transaction_status
+
         # if the connection is not in a sane state, don't even try
-        if cur._conn.pgconn.transaction_status not in (
-            pq.TransactionStatus.IDLE,
-            pq.TransactionStatus.INTRANS,
-        ):
+        if ts not in (pq.TransactionStatus.IDLE, pq.TransactionStatus.INTRANS):
+            return
+
+        # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
+        if not self.withhold and ts == pq.TransactionStatus.IDLE:
             return
 
         # if we didn't declare the cursor ourselves we still have to close it
index 6b2974802abd099cd899b74cfc1f37cb3c505200..57c58b9a70313ed4de3493370689df46a163adf3 100644 (file)
@@ -76,6 +76,15 @@ def test_close_noop(conn, recwarn, retries):
             assert not recwarn, [str(w.message) for w in recwarn.list]
 
 
+def test_close_on_error(conn):
+    cur = conn.cursor("foo")
+    cur.execute("select 1")
+    with pytest.raises(e.ProgrammingError):
+        conn.execute("wat")
+    assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+    cur.close()
+
+
 def test_context(conn, recwarn, retries):
     for retry in retries:
         with retry:
@@ -292,12 +301,12 @@ def test_non_scrollable(conn):
 
 @pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
 def test_no_hold(conn, kwargs):
-    with pytest.raises(e.InvalidCursorName):
-        with conn.cursor("foo", **kwargs) as curs:
-            assert curs.withhold is False
-            curs.execute("select generate_series(0, 2)")
-            assert curs.fetchone() == (0,)
-            conn.commit()
+    with conn.cursor("foo", **kwargs) as curs:
+        assert curs.withhold is False
+        curs.execute("select generate_series(0, 2)")
+        assert curs.fetchone() == (0,)
+        conn.commit()
+        with pytest.raises(e.InvalidCursorName):
             curs.fetchone()
 
 
index ca64590b90db126246ee490aebbd06e0fc307d17..996c70a5c749ef7d73cb22c6db4f3c0dacf66528 100644 (file)
@@ -80,6 +80,15 @@ async def test_close_noop(aconn, recwarn, retries):
             assert not recwarn, [str(w.message) for w in recwarn.list]
 
 
+async def test_close_on_error(aconn):
+    cur = aconn.cursor("foo")
+    await cur.execute("select 1")
+    with pytest.raises(e.ProgrammingError):
+        await aconn.execute("wat")
+    assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+    await cur.close()
+
+
 async def test_context(aconn, recwarn, retries):
     async for retry in retries:
         with retry:
@@ -303,12 +312,12 @@ async def test_non_scrollable(aconn):
 
 @pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
 async def test_no_hold(aconn, kwargs):
-    with pytest.raises(e.InvalidCursorName):
-        async with aconn.cursor("foo", **kwargs) as curs:
-            assert curs.withhold is False
-            await curs.execute("select generate_series(0, 2)")
-            assert await curs.fetchone() == (0,)
-            await aconn.commit()
+    async with aconn.cursor("foo", **kwargs) as curs:
+        assert curs.withhold is False
+        await curs.execute("select generate_series(0, 2)")
+        assert await curs.fetchone() == (0,)
+        await aconn.commit()
+        with pytest.raises(e.InvalidCursorName):
             await curs.fetchone()