From: Daniele Varrazzo Date: Wed, 10 Feb 2021 02:04:00 +0000 (+0100) Subject: Close correctly named cursors in corner cases X-Git-Tag: 3.0.dev0~115^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c8e4eb2a240612d36f85e322fa6fec2660aa1dd9;p=thirdparty%2Fpsycopg.git Close correctly named cursors in corner cases --- diff --git a/psycopg3/psycopg3/named_cursor.py b/psycopg3/psycopg3/named_cursor.py index 434901f9b..6e6ec71ee 100644 --- a/psycopg3/psycopg3/named_cursor.py +++ b/psycopg3/psycopg3/named_cursor.py @@ -60,6 +60,16 @@ class NamedCursorHelper(Generic[ConnectionType]): self.described = True def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]: + # if we didn't declare the cursor ourselves we still have to close it + # but we must make sure it exists. + if not self.described: + query = sql.SQL( + "select 1 from pg_catalog.pg_cursors where name = {}" + ).format(sql.Literal(self.name)) + res = yield from cur._conn._exec_command(query) + if res.ntuples == 0: + return + query = sql.SQL("close {}").format(sql.Identifier(self.name)) yield from cur._conn._exec_command(query) diff --git a/tests/test_named_cursor.py b/tests/test_named_cursor.py index 4b142dd95..e3bbca10a 100644 --- a/tests/test_named_cursor.py +++ b/tests/test_named_cursor.py @@ -31,6 +31,12 @@ def test_close(conn, recwarn): assert not recwarn +def test_close_noop(conn, recwarn): + cur = conn.cursor("foo") + cur.close() + assert not recwarn + + def test_context(conn, recwarn): with conn.cursor("foo") as cur: cur.execute("select generate_series(1, 10) as bar") @@ -170,12 +176,21 @@ def test_non_scrollable(conn): def test_steal_cursor(conn): cur1 = conn.cursor() - cur1.execute( - "declare test cursor without hold for select generate_series(1, 6)" - ) + cur1.execute("declare test cursor for select generate_series(1, 6)") cur2 = conn.cursor("test") # can call fetch without execute assert cur2.fetchone() == (1,) assert cur2.fetchmany(3) == [(2,), (3,), (4,)] assert cur2.fetchall() == [(5,), (6,)] + + +def test_stolen_cursor_close(conn): + cur1 = conn.cursor() + cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = conn.cursor("test") + cur2.close() + + cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = conn.cursor("test") + cur2.close() diff --git a/tests/test_named_cursor_async.py b/tests/test_named_cursor_async.py index 3a2a07843..0ae150229 100644 --- a/tests/test_named_cursor_async.py +++ b/tests/test_named_cursor_async.py @@ -33,6 +33,12 @@ async def test_close(aconn, recwarn): assert not recwarn +async def test_close_noop(aconn, recwarn): + cur = await aconn.cursor("foo") + await cur.close() + assert not recwarn + + async def test_context(aconn, recwarn): async with await aconn.cursor("foo") as cur: await cur.execute("select generate_series(1, 10) as bar")