From: Daniele Varrazzo Date: Wed, 10 Feb 2021 01:47:45 +0000 (+0100) Subject: Make sure you can use a named cursor to "steal" a portal X-Git-Tag: 3.0.dev0~115^2~11 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=73727267c708df1d1b991af919371d09432dd408;p=thirdparty%2Fpsycopg.git Make sure you can use a named cursor to "steal" a portal --- diff --git a/psycopg3/psycopg3/named_cursor.py b/psycopg3/psycopg3/named_cursor.py index 65ee07de3..434901f9b 100644 --- a/psycopg3/psycopg3/named_cursor.py +++ b/psycopg3/psycopg3/named_cursor.py @@ -4,7 +4,6 @@ psycopg3 named cursor objects (server-side cursors) # Copyright (C) 2020-2021 The Psycopg Team -import weakref import warnings from types import TracebackType from typing import Any, AsyncIterator, Generic, List, Iterator, Optional @@ -23,32 +22,24 @@ DEFAULT_ITERSIZE = 100 class NamedCursorHelper(Generic[ConnectionType]): - __slots__ = ("name", "_wcur") + __slots__ = ("name", "described") """Helper object for common NamedCursor code. TODO: this should be a mixin, but couldn't find a way to work it correctly with the generic. """ - def __init__( - self, - name: str, - cursor: BaseCursor[ConnectionType], - ): + def __init__(self, name: str): self.name = name - self._wcur = weakref.ref(cursor) - - @property - def _cur(self) -> BaseCursor[Any]: - cur = self._wcur() - assert cur - return cur + self.described = False def _declare_gen( - self, query: Query, params: Optional[Params] = None + self, + cur: BaseCursor[ConnectionType], + query: Query, + params: Optional[Params] = None, ) -> PQGen[None]: """Generator implementing `NamedCursor.execute()`.""" - cur = self._cur conn = cur._conn yield from cur._start_query(query) pgq = cur._convert_query(query, params) @@ -57,24 +48,34 @@ class NamedCursorHelper(Generic[ConnectionType]): cur._execute_results(results) # The above result is an COMMAND_OK. Get the cursor result shape + yield from self._describe_gen(cur) + + def _describe_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]: + conn = cur._conn conn.pgconn.send_describe_portal( self.name.encode(conn.client_encoding) ) results = yield from execute(conn.pgconn) cur._execute_results(results) + self.described = True - def _close_gen(self) -> PQGen[None]: - cur = self._cur + def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]: query = sql.SQL("close {}").format(sql.Identifier(self.name)) yield from cur._conn._exec_command(query) - def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Tuple[Any, ...]]]: + def _fetch_gen( + self, cur: BaseCursor[ConnectionType], num: Optional[int] + ) -> PQGen[List[Tuple[Any, ...]]]: + # If we are stealing the cursor, make sure we know its shape + if not self.described: + yield from cur._start_query() + yield from self._describe_gen(cur) + if num is not None: howmuch: sql.Composable = sql.Literal(num) else: howmuch = sql.SQL("all") - cur = self._cur query = sql.SQL("fetch forward {} from {}").format( howmuch, sql.Identifier(self.name) ) @@ -84,7 +85,9 @@ class NamedCursorHelper(Generic[ConnectionType]): cur.pgresult = res return cur._tx.load_rows(0, res.ntuples) - def _scroll_gen(self, value: int, mode: str) -> PQGen[None]: + def _scroll_gen( + self, cur: BaseCursor[ConnectionType], value: int, mode: str + ) -> PQGen[None]: if mode not in ("relative", "absolute"): raise ValueError( f"bad mode: {mode}. It should be 'relative' or 'absolute'" @@ -94,13 +97,15 @@ class NamedCursorHelper(Generic[ConnectionType]): sql.Literal(value), sql.Identifier(self.name), ) - cur = self._cur yield from cur._conn._exec_command(query) def _make_declare_statement( - self, query: Query, scrollable: bool, hold: bool + self, + cur: BaseCursor[ConnectionType], + query: Query, + scrollable: bool, + hold: bool, ) -> sql.Composable: - cur = self._cur if isinstance(query, bytes): query = query.decode(cur._conn.client_encoding) if not isinstance(query, sql.Composable): @@ -128,7 +133,7 @@ class NamedCursor(BaseCursor["Connection"]): format: Format = Format.TEXT, ): super().__init__(connection, format=format) - self._helper = NamedCursorHelper(name, self) + self._helper: NamedCursorHelper["Connection"] = NamedCursorHelper(name) self.itersize = DEFAULT_ITERSIZE def __del__(self) -> None: @@ -159,7 +164,7 @@ class NamedCursor(BaseCursor["Connection"]): Close the current cursor and free associated resources. """ with self._conn.lock: - self._conn.wait(self._helper._close_gen()) + self._conn.wait(self._helper._close_gen(self)) self._close() def execute( @@ -174,15 +179,15 @@ class NamedCursor(BaseCursor["Connection"]): Execute a query or command to the database. """ query = self._helper._make_declare_statement( - query, scrollable=scrollable, hold=hold + self, query, scrollable=scrollable, hold=hold ) with self._conn.lock: - self._conn.wait(self._helper._declare_gen(query, params)) + self._conn.wait(self._helper._declare_gen(self, query, params)) return self def fetchone(self) -> Optional[Sequence[Any]]: with self._conn.lock: - recs = self._conn.wait(self._helper._fetch_gen(1)) + recs = self._conn.wait(self._helper._fetch_gen(self, 1)) if recs: self._pos += 1 return recs[0] @@ -193,20 +198,22 @@ class NamedCursor(BaseCursor["Connection"]): if not size: size = self.arraysize with self._conn.lock: - recs = self._conn.wait(self._helper._fetch_gen(size)) + recs = self._conn.wait(self._helper._fetch_gen(self, size)) self._pos += len(recs) return recs def fetchall(self) -> Sequence[Sequence[Any]]: with self._conn.lock: - recs = self._conn.wait(self._helper._fetch_gen(None)) + recs = self._conn.wait(self._helper._fetch_gen(self, None)) self._pos += len(recs) return recs def __iter__(self) -> Iterator[Sequence[Any]]: while True: with self._conn.lock: - recs = self._conn.wait(self._helper._fetch_gen(self.itersize)) + recs = self._conn.wait( + self._helper._fetch_gen(self, self.itersize) + ) for rec in recs: self._pos += 1 yield rec @@ -215,7 +222,7 @@ class NamedCursor(BaseCursor["Connection"]): def scroll(self, value: int, mode: str = "relative") -> None: with self._conn.lock: - self._conn.wait(self._helper._scroll_gen(value, mode)) + self._conn.wait(self._helper._scroll_gen(self, value, mode)) # Postgres doesn't have a reliable way to report a cursor out of bound if mode == "relative": self._pos += value @@ -235,7 +242,8 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): format: Format = Format.TEXT, ): super().__init__(connection, format=format) - self._helper = NamedCursorHelper(name, self) + self._helper: NamedCursorHelper["AsyncConnection"] + self._helper = NamedCursorHelper(name) self.itersize = DEFAULT_ITERSIZE def __del__(self) -> None: @@ -266,7 +274,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): Close the current cursor and free associated resources. """ async with self._conn.lock: - await self._conn.wait(self._helper._close_gen()) + await self._conn.wait(self._helper._close_gen(self)) self._close() async def execute( @@ -281,15 +289,17 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): Execute a query or command to the database. """ query = self._helper._make_declare_statement( - query, scrollable=scrollable, hold=hold + self, query, scrollable=scrollable, hold=hold ) async with self._conn.lock: - await self._conn.wait(self._helper._declare_gen(query, params)) + await self._conn.wait( + self._helper._declare_gen(self, query, params) + ) return self async def fetchone(self) -> Optional[Sequence[Any]]: async with self._conn.lock: - recs = await self._conn.wait(self._helper._fetch_gen(1)) + recs = await self._conn.wait(self._helper._fetch_gen(self, 1)) if recs: self._pos += 1 return recs[0] @@ -300,13 +310,13 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): if not size: size = self.arraysize async with self._conn.lock: - recs = await self._conn.wait(self._helper._fetch_gen(size)) + recs = await self._conn.wait(self._helper._fetch_gen(self, size)) self._pos += len(recs) return recs async def fetchall(self) -> Sequence[Sequence[Any]]: async with self._conn.lock: - recs = await self._conn.wait(self._helper._fetch_gen(None)) + recs = await self._conn.wait(self._helper._fetch_gen(self, None)) self._pos += len(recs) return recs @@ -314,7 +324,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): while True: async with self._conn.lock: recs = await self._conn.wait( - self._helper._fetch_gen(self.itersize) + self._helper._fetch_gen(self, self.itersize) ) for rec in recs: self._pos += 1 @@ -324,4 +334,4 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): async def scroll(self, value: int, mode: str = "relative") -> None: async with self._conn.lock: - await self._conn.wait(self._helper._scroll_gen(value, mode)) + await self._conn.wait(self._helper._scroll_gen(self, value, mode)) diff --git a/tests/test_named_cursor.py b/tests/test_named_cursor.py index 252bb88b6..4b142dd95 100644 --- a/tests/test_named_cursor.py +++ b/tests/test_named_cursor.py @@ -166,3 +166,16 @@ def test_non_scrollable(conn): curs.scroll(5) with pytest.raises(conn.OperationalError): curs.scroll(-1) + + +def test_steal_cursor(conn): + cur1 = conn.cursor() + cur1.execute( + "declare test cursor without hold for select generate_series(1, 6)" + ) + + cur2 = conn.cursor("test") + # can call fetch without execute + assert cur2.fetchone() == (1,) + assert cur2.fetchmany(3) == [(2,), (3,), (4,)] + assert cur2.fetchall() == [(5,), (6,)] diff --git a/tests/test_named_cursor_async.py b/tests/test_named_cursor_async.py index 568b31201..3a2a07843 100644 --- a/tests/test_named_cursor_async.py +++ b/tests/test_named_cursor_async.py @@ -173,3 +173,16 @@ async def test_non_scrollable(aconn): await curs.scroll(5) with pytest.raises(aconn.OperationalError): await curs.scroll(-1) + + +async def test_steal_cursor(aconn): + cur1 = await aconn.cursor() + await cur1.execute( + "declare test cursor without hold for select generate_series(1, 6)" + ) + + cur2 = await aconn.cursor("test") + # can call fetch without execute + assert await cur2.fetchone() == (1,) + assert await cur2.fetchmany(3) == [(2,), (3,), (4,)] + assert await cur2.fetchall() == [(5,), (6,)]