]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make sure you can use a named cursor to "steal" a portal
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 01:47:45 +0000 (02:47 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 01:47:45 +0000 (02:47 +0100)
psycopg3/psycopg3/named_cursor.py
tests/test_named_cursor.py
tests/test_named_cursor_async.py

index 65ee07de3dd8b4a272214f3c6095d6b716c01ee8..434901f9b792515d9848a15d3b51c4527c93439c 100644 (file)
@@ -4,7 +4,6 @@ psycopg3 named cursor objects (server-side cursors)
 
 # Copyright (C) 2020-2021 The Psycopg Team
 
-import weakref
 import warnings
 from types import TracebackType
 from typing import Any, AsyncIterator, Generic, List, Iterator, Optional
@@ -23,32 +22,24 @@ DEFAULT_ITERSIZE = 100
 
 
 class NamedCursorHelper(Generic[ConnectionType]):
-    __slots__ = ("name", "_wcur")
+    __slots__ = ("name", "described")
     """Helper object for common NamedCursor code.
 
     TODO: this should be a mixin, but couldn't find a way to work it
     correctly with the generic.
     """
 
-    def __init__(
-        self,
-        name: str,
-        cursor: BaseCursor[ConnectionType],
-    ):
+    def __init__(self, name: str):
         self.name = name
-        self._wcur = weakref.ref(cursor)
-
-    @property
-    def _cur(self) -> BaseCursor[Any]:
-        cur = self._wcur()
-        assert cur
-        return cur
+        self.described = False
 
     def _declare_gen(
-        self, query: Query, params: Optional[Params] = None
+        self,
+        cur: BaseCursor[ConnectionType],
+        query: Query,
+        params: Optional[Params] = None,
     ) -> PQGen[None]:
         """Generator implementing `NamedCursor.execute()`."""
-        cur = self._cur
         conn = cur._conn
         yield from cur._start_query(query)
         pgq = cur._convert_query(query, params)
@@ -57,24 +48,34 @@ class NamedCursorHelper(Generic[ConnectionType]):
         cur._execute_results(results)
 
         # The above result is an COMMAND_OK. Get the cursor result shape
+        yield from self._describe_gen(cur)
+
+    def _describe_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
+        conn = cur._conn
         conn.pgconn.send_describe_portal(
             self.name.encode(conn.client_encoding)
         )
         results = yield from execute(conn.pgconn)
         cur._execute_results(results)
+        self.described = True
 
-    def _close_gen(self) -> PQGen[None]:
-        cur = self._cur
+    def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
         query = sql.SQL("close {}").format(sql.Identifier(self.name))
         yield from cur._conn._exec_command(query)
 
-    def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Tuple[Any, ...]]]:
+    def _fetch_gen(
+        self, cur: BaseCursor[ConnectionType], num: Optional[int]
+    ) -> PQGen[List[Tuple[Any, ...]]]:
+        # If we are stealing the cursor, make sure we know its shape
+        if not self.described:
+            yield from cur._start_query()
+            yield from self._describe_gen(cur)
+
         if num is not None:
             howmuch: sql.Composable = sql.Literal(num)
         else:
             howmuch = sql.SQL("all")
 
-        cur = self._cur
         query = sql.SQL("fetch forward {} from {}").format(
             howmuch, sql.Identifier(self.name)
         )
@@ -84,7 +85,9 @@ class NamedCursorHelper(Generic[ConnectionType]):
         cur.pgresult = res
         return cur._tx.load_rows(0, res.ntuples)
 
-    def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
+    def _scroll_gen(
+        self, cur: BaseCursor[ConnectionType], value: int, mode: str
+    ) -> PQGen[None]:
         if mode not in ("relative", "absolute"):
             raise ValueError(
                 f"bad mode: {mode}. It should be 'relative' or 'absolute'"
@@ -94,13 +97,15 @@ class NamedCursorHelper(Generic[ConnectionType]):
             sql.Literal(value),
             sql.Identifier(self.name),
         )
-        cur = self._cur
         yield from cur._conn._exec_command(query)
 
     def _make_declare_statement(
-        self, query: Query, scrollable: bool, hold: bool
+        self,
+        cur: BaseCursor[ConnectionType],
+        query: Query,
+        scrollable: bool,
+        hold: bool,
     ) -> sql.Composable:
-        cur = self._cur
         if isinstance(query, bytes):
             query = query.decode(cur._conn.client_encoding)
         if not isinstance(query, sql.Composable):
@@ -128,7 +133,7 @@ class NamedCursor(BaseCursor["Connection"]):
         format: Format = Format.TEXT,
     ):
         super().__init__(connection, format=format)
-        self._helper = NamedCursorHelper(name, self)
+        self._helper: NamedCursorHelper["Connection"] = NamedCursorHelper(name)
         self.itersize = DEFAULT_ITERSIZE
 
     def __del__(self) -> None:
@@ -159,7 +164,7 @@ class NamedCursor(BaseCursor["Connection"]):
         Close the current cursor and free associated resources.
         """
         with self._conn.lock:
-            self._conn.wait(self._helper._close_gen())
+            self._conn.wait(self._helper._close_gen(self))
         self._close()
 
     def execute(
@@ -174,15 +179,15 @@ class NamedCursor(BaseCursor["Connection"]):
         Execute a query or command to the database.
         """
         query = self._helper._make_declare_statement(
-            query, scrollable=scrollable, hold=hold
+            self, query, scrollable=scrollable, hold=hold
         )
         with self._conn.lock:
-            self._conn.wait(self._helper._declare_gen(query, params))
+            self._conn.wait(self._helper._declare_gen(self, query, params))
         return self
 
     def fetchone(self) -> Optional[Sequence[Any]]:
         with self._conn.lock:
-            recs = self._conn.wait(self._helper._fetch_gen(1))
+            recs = self._conn.wait(self._helper._fetch_gen(self, 1))
         if recs:
             self._pos += 1
             return recs[0]
@@ -193,20 +198,22 @@ class NamedCursor(BaseCursor["Connection"]):
         if not size:
             size = self.arraysize
         with self._conn.lock:
-            recs = self._conn.wait(self._helper._fetch_gen(size))
+            recs = self._conn.wait(self._helper._fetch_gen(self, size))
         self._pos += len(recs)
         return recs
 
     def fetchall(self) -> Sequence[Sequence[Any]]:
         with self._conn.lock:
-            recs = self._conn.wait(self._helper._fetch_gen(None))
+            recs = self._conn.wait(self._helper._fetch_gen(self, None))
         self._pos += len(recs)
         return recs
 
     def __iter__(self) -> Iterator[Sequence[Any]]:
         while True:
             with self._conn.lock:
-                recs = self._conn.wait(self._helper._fetch_gen(self.itersize))
+                recs = self._conn.wait(
+                    self._helper._fetch_gen(self, self.itersize)
+                )
             for rec in recs:
                 self._pos += 1
                 yield rec
@@ -215,7 +222,7 @@ class NamedCursor(BaseCursor["Connection"]):
 
     def scroll(self, value: int, mode: str = "relative") -> None:
         with self._conn.lock:
-            self._conn.wait(self._helper._scroll_gen(value, mode))
+            self._conn.wait(self._helper._scroll_gen(self, value, mode))
         # Postgres doesn't have a reliable way to report a cursor out of bound
         if mode == "relative":
             self._pos += value
@@ -235,7 +242,8 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
         format: Format = Format.TEXT,
     ):
         super().__init__(connection, format=format)
-        self._helper = NamedCursorHelper(name, self)
+        self._helper: NamedCursorHelper["AsyncConnection"]
+        self._helper = NamedCursorHelper(name)
         self.itersize = DEFAULT_ITERSIZE
 
     def __del__(self) -> None:
@@ -266,7 +274,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
         Close the current cursor and free associated resources.
         """
         async with self._conn.lock:
-            await self._conn.wait(self._helper._close_gen())
+            await self._conn.wait(self._helper._close_gen(self))
         self._close()
 
     async def execute(
@@ -281,15 +289,17 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
         Execute a query or command to the database.
         """
         query = self._helper._make_declare_statement(
-            query, scrollable=scrollable, hold=hold
+            self, query, scrollable=scrollable, hold=hold
         )
         async with self._conn.lock:
-            await self._conn.wait(self._helper._declare_gen(query, params))
+            await self._conn.wait(
+                self._helper._declare_gen(self, query, params)
+            )
         return self
 
     async def fetchone(self) -> Optional[Sequence[Any]]:
         async with self._conn.lock:
-            recs = await self._conn.wait(self._helper._fetch_gen(1))
+            recs = await self._conn.wait(self._helper._fetch_gen(self, 1))
         if recs:
             self._pos += 1
             return recs[0]
@@ -300,13 +310,13 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
         if not size:
             size = self.arraysize
         async with self._conn.lock:
-            recs = await self._conn.wait(self._helper._fetch_gen(size))
+            recs = await self._conn.wait(self._helper._fetch_gen(self, size))
         self._pos += len(recs)
         return recs
 
     async def fetchall(self) -> Sequence[Sequence[Any]]:
         async with self._conn.lock:
-            recs = await self._conn.wait(self._helper._fetch_gen(None))
+            recs = await self._conn.wait(self._helper._fetch_gen(self, None))
         self._pos += len(recs)
         return recs
 
@@ -314,7 +324,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
         while True:
             async with self._conn.lock:
                 recs = await self._conn.wait(
-                    self._helper._fetch_gen(self.itersize)
+                    self._helper._fetch_gen(self, self.itersize)
                 )
             for rec in recs:
                 self._pos += 1
@@ -324,4 +334,4 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
 
     async def scroll(self, value: int, mode: str = "relative") -> None:
         async with self._conn.lock:
-            await self._conn.wait(self._helper._scroll_gen(value, mode))
+            await self._conn.wait(self._helper._scroll_gen(self, value, mode))
index 252bb88b6ac4515092d4f2f3dd44c51ef0c29803..4b142dd954629879627967058b7a845cd01fb6e6 100644 (file)
@@ -166,3 +166,16 @@ def test_non_scrollable(conn):
     curs.scroll(5)
     with pytest.raises(conn.OperationalError):
         curs.scroll(-1)
+
+
+def test_steal_cursor(conn):
+    cur1 = conn.cursor()
+    cur1.execute(
+        "declare test cursor without hold for select generate_series(1, 6)"
+    )
+
+    cur2 = conn.cursor("test")
+    # can call fetch without execute
+    assert cur2.fetchone() == (1,)
+    assert cur2.fetchmany(3) == [(2,), (3,), (4,)]
+    assert cur2.fetchall() == [(5,), (6,)]
index 568b3120126ec25aeae569c83962a5f0a86a2c5e..3a2a078430663ed58a1db2ecbd6e03e2a00c0752 100644 (file)
@@ -173,3 +173,16 @@ async def test_non_scrollable(aconn):
     await curs.scroll(5)
     with pytest.raises(aconn.OperationalError):
         await curs.scroll(-1)
+
+
+async def test_steal_cursor(aconn):
+    cur1 = await aconn.cursor()
+    await cur1.execute(
+        "declare test cursor without hold for select generate_series(1, 6)"
+    )
+
+    cur2 = await aconn.cursor("test")
+    # can call fetch without execute
+    assert await cur2.fetchone() == (1,)
+    assert await cur2.fetchmany(3) == [(2,), (3,), (4,)]
+    assert await cur2.fetchall() == [(5,), (6,)]