]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't clobber error leaving a ServerCursor block with a broken connection
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 17:49:06 +0000 (18:49 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 17:49:06 +0000 (18:49 +0100)
psycopg3/psycopg3/server_cursor.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py

index 065967de54ef9b5d464769123f2618ee9b9276e4..289485edc0fe86bd2a2e5d8235306cd599aaeb16 100644 (file)
@@ -72,6 +72,13 @@ class ServerCursorHelper(Generic[ConnectionType]):
         self.described = True
 
     def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
+        # if the connection is not in a sane state, don't even try
+        if cur._conn.pgconn.transaction_status not in (
+            pq.TransactionStatus.IDLE,
+            pq.TransactionStatus.INTRANS,
+        ):
+            return
+
         # if we didn't declare the cursor ourselves we still have to close it
         # but we must make sure it exists.
         if not self.described:
index fa9bbd12af86772aba2fffcbeb621abfc826b8ca..ff9c2559bfbb80ba87fcb4903995902e84bf79c8 100644 (file)
@@ -1,5 +1,6 @@
 import pytest
 
+from psycopg3 import errors as e
 from psycopg3.pq import Format
 
 
@@ -80,6 +81,12 @@ def test_context(conn, recwarn):
     assert not recwarn
 
 
+def test_close_no_clobber(conn):
+    with pytest.raises(e.DivisionByZero):
+        with conn.cursor("foo") as cur:
+            cur.execute("select 1 / %s", (0,))
+
+
 def test_warn_close(conn, recwarn):
     cur = conn.cursor("foo")
     cur.execute("select generate_series(1, 10) as bar")
@@ -89,7 +96,7 @@ def test_warn_close(conn, recwarn):
 
 def test_executemany(conn):
     cur = conn.cursor("foo")
-    with pytest.raises(conn.NotSupportedError):
+    with pytest.raises(e.NotSupportedError):
         cur.executemany("select %s", [(1,), (2,)])
 
 
@@ -182,7 +189,7 @@ def test_itersize(conn, commands):
 
 def test_scroll(conn):
     cur = conn.cursor("tmp")
-    with pytest.raises(conn.ProgrammingError):
+    with pytest.raises(e.ProgrammingError):
         cur.scroll(0)
 
     cur.execute("select generate_series(0,9)")
@@ -213,7 +220,7 @@ def test_non_scrollable(conn):
     curs = conn.cursor("foo")
     curs.execute("select generate_series(0, 5)", scrollable=False)
     curs.scroll(5)
-    with pytest.raises(conn.OperationalError):
+    with pytest.raises(e.OperationalError):
         curs.scroll(-1)
 
 
index ace61771b2cfddc8fdb4f52efba46c7a3375ef6d..685c559285cf8a0ec88d309685d559bd8a048ed5 100644 (file)
@@ -1,5 +1,6 @@
 import pytest
 
+from psycopg3 import errors as e
 from psycopg3.pq import Format
 
 pytestmark = pytest.mark.asyncio
@@ -82,6 +83,12 @@ async def test_context(aconn, recwarn):
     assert not recwarn
 
 
+async def test_close_no_clobber(aconn):
+    with pytest.raises(e.DivisionByZero):
+        async with aconn.cursor("foo") as cur:
+            await cur.execute("select 1 / %s", (0,))
+
+
 async def test_warn_close(aconn, recwarn):
     cur = aconn.cursor("foo")
     await cur.execute("select generate_series(1, 10) as bar")
@@ -91,7 +98,7 @@ async def test_warn_close(aconn, recwarn):
 
 async def test_executemany(aconn):
     cur = aconn.cursor("foo")
-    with pytest.raises(aconn.NotSupportedError):
+    with pytest.raises(e.NotSupportedError):
         await cur.executemany("select %s", [(1,), (2,)])
 
 
@@ -189,7 +196,7 @@ async def test_itersize(aconn, acommands):
 
 async def test_scroll(aconn):
     cur = aconn.cursor("tmp")
-    with pytest.raises(aconn.ProgrammingError):
+    with pytest.raises(e.ProgrammingError):
         await cur.scroll(0)
 
     await cur.execute("select generate_series(0,9)")
@@ -220,7 +227,7 @@ async def test_non_scrollable(aconn):
     curs = aconn.cursor("foo")
     await curs.execute("select generate_series(0, 5)", scrollable=False)
     await curs.scroll(5)
-    with pytest.raises(aconn.OperationalError):
+    with pytest.raises(e.OperationalError):
         await curs.scroll(-1)