]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Close correctly named cursors in corner cases
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 02:04:00 +0000 (03:04 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 02:04:00 +0000 (03:04 +0100)
psycopg3/psycopg3/named_cursor.py
tests/test_named_cursor.py
tests/test_named_cursor_async.py

index 434901f9b792515d9848a15d3b51c4527c93439c..6e6ec71eeb9676a398133a3ccd45d433dc9fa010 100644 (file)
@@ -60,6 +60,16 @@ class NamedCursorHelper(Generic[ConnectionType]):
         self.described = True
 
     def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
+        # if we didn't declare the cursor ourselves we still have to close it
+        # but we must make sure it exists.
+        if not self.described:
+            query = sql.SQL(
+                "select 1 from pg_catalog.pg_cursors where name = {}"
+            ).format(sql.Literal(self.name))
+            res = yield from cur._conn._exec_command(query)
+            if res.ntuples == 0:
+                return
+
         query = sql.SQL("close {}").format(sql.Identifier(self.name))
         yield from cur._conn._exec_command(query)
 
index 4b142dd954629879627967058b7a845cd01fb6e6..e3bbca10a390d0fd8f83d38431a8dd8efac5f5cc 100644 (file)
@@ -31,6 +31,12 @@ def test_close(conn, recwarn):
     assert not recwarn
 
 
+def test_close_noop(conn, recwarn):
+    cur = conn.cursor("foo")
+    cur.close()
+    assert not recwarn
+
+
 def test_context(conn, recwarn):
     with conn.cursor("foo") as cur:
         cur.execute("select generate_series(1, 10) as bar")
@@ -170,12 +176,21 @@ def test_non_scrollable(conn):
 
 def test_steal_cursor(conn):
     cur1 = conn.cursor()
-    cur1.execute(
-        "declare test cursor without hold for select generate_series(1, 6)"
-    )
+    cur1.execute("declare test cursor for select generate_series(1, 6)")
 
     cur2 = conn.cursor("test")
     # can call fetch without execute
     assert cur2.fetchone() == (1,)
     assert cur2.fetchmany(3) == [(2,), (3,), (4,)]
     assert cur2.fetchall() == [(5,), (6,)]
+
+
+def test_stolen_cursor_close(conn):
+    cur1 = conn.cursor()
+    cur1.execute("declare test cursor for select generate_series(1, 6)")
+    cur2 = conn.cursor("test")
+    cur2.close()
+
+    cur1.execute("declare test cursor for select generate_series(1, 6)")
+    cur2 = conn.cursor("test")
+    cur2.close()
index 3a2a078430663ed58a1db2ecbd6e03e2a00c0752..0ae150229d1ead8059b980ae13e34fec58ac51d6 100644 (file)
@@ -33,6 +33,12 @@ async def test_close(aconn, recwarn):
     assert not recwarn
 
 
+async def test_close_noop(aconn, recwarn):
+    cur = await aconn.cursor("foo")
+    await cur.close()
+    assert not recwarn
+
+
 async def test_context(aconn, recwarn):
     async with await aconn.cursor("foo") as cur:
         await cur.execute("select generate_series(1, 10) as bar")