]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Delete the Postgres cursor when closing a named cursor
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 14:06:06 +0000 (15:06 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 14:06:06 +0000 (15:06 +0100)
psycopg3/psycopg3/named_cursor.py
tests/test_named_cursor.py
tests/test_named_cursor_async.py

index d0344b18314a901db54b84955e0039c0235e7ef7..54e10e890e3bf339517ef381089c93159d483444 100644 (file)
@@ -21,6 +21,11 @@ if TYPE_CHECKING:
 
 class NamedCursorHelper(Generic[ConnectionType]):
     __slots__ = ("name", "_wcur")
+    """Helper object for common NamedCursor code.
+
+    TODO: this should be a mixin, but couldn't find a way to work it
+    correctly with the generic.
+    """
 
     def __init__(
         self,
@@ -54,6 +59,11 @@ class NamedCursorHelper(Generic[ConnectionType]):
         results = yield from execute(cur._conn.pgconn)
         cur._execute_results(results)
 
+    def _close_gen(self) -> PQGen[None]:
+        cur = self._cur
+        query = sql.SQL("close {}").format(sql.Identifier(self.name))
+        yield from cur._conn._exec_command(query)
+
     def _make_declare_statement(
         self, query: Query, scrollable: bool, hold: bool
     ) -> sql.Composable:
@@ -114,7 +124,8 @@ class NamedCursor(BaseCursor["Connection"]):
         """
         Close the current cursor and free associated resources.
         """
-        # TODO close the cursor for real
+        with self._conn.lock:
+            self._conn.wait(self._helper._close_gen())
         self._close()
 
     def execute(
@@ -177,7 +188,8 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
         """
         Close the current cursor and free associated resources.
         """
-        # TODO close the cursor for real
+        async with self._conn.lock:
+            await self._conn.wait(self._helper._close_gen())
         self._close()
 
     async def execute(
index 82ae83f6ce19d8cb629501567990e814c65721c3..b9421a0cb2dc30567e5d73f70bbad060975fe4de 100644 (file)
@@ -6,3 +6,35 @@ def test_description(conn):
     assert cur.description[0].name == "bar"
     assert cur.description[0].type_code == cur.adapters.types["int4"].oid
     assert cur.pgresult.ntuples == 0
+
+
+def test_close(conn, recwarn):
+    cur = conn.cursor("foo")
+    cur.execute("select generate_series(1, 10) as bar")
+    cur.close()
+    assert cur.closed
+
+    assert not conn.execute(
+        "select * from pg_cursors where name = 'foo'"
+    ).fetchone()
+    del cur
+    assert not recwarn
+
+
+def test_context(conn, recwarn):
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, 10) as bar")
+
+    assert cur.closed
+    assert not conn.execute(
+        "select * from pg_cursors where name = 'foo'"
+    ).fetchone()
+    del cur
+    assert not recwarn
+
+
+def test_warn_close(conn, recwarn):
+    cur = conn.cursor("foo")
+    cur.execute("select generate_series(1, 10) as bar")
+    del cur
+    assert ".close()" in str(recwarn.pop(ResourceWarning).message)
index 538be22e9b52bf55a98b2516a1138c53e94f57d2..0ceeee11436ea54ec58336fa6004083a0eda4d17 100644 (file)
@@ -11,3 +11,35 @@ async def test_description(aconn):
     assert cur.description[0].name == "bar"
     assert cur.description[0].type_code == cur.adapters.types["int4"].oid
     assert cur.pgresult.ntuples == 0
+
+
+async def test_close(aconn, recwarn):
+    cur = await aconn.cursor("foo")
+    await cur.execute("select generate_series(1, 10) as bar")
+    await cur.close()
+    assert cur.closed
+
+    assert not await (
+        await aconn.execute("select * from pg_cursors where name = 'foo'")
+    ).fetchone()
+    del cur
+    assert not recwarn
+
+
+async def test_context(aconn, recwarn):
+    async with await aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, 10) as bar")
+
+    assert cur.closed
+    assert not await (
+        await aconn.execute("select * from pg_cursors where name = 'foo'")
+    ).fetchone()
+    del cur
+    assert not recwarn
+
+
+async def test_warn_close(aconn, recwarn):
+    cur = await aconn.cursor("foo")
+    await cur.execute("select generate_series(1, 10) as bar")
+    del cur
+    assert ".close()" in str(recwarn.pop(ResourceWarning).message)