From: Daniele Varrazzo Date: Wed, 10 Feb 2021 00:49:41 +0000 (+0100) Subject: Add named cursor scroll X-Git-Tag: 3.0.dev0~115^2~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=42e2817981d70515c4f78197f1a42ed7dedee6fd;p=thirdparty%2Fpsycopg.git Add named cursor scroll --- diff --git a/psycopg3/psycopg3/named_cursor.py b/psycopg3/psycopg3/named_cursor.py index e8c25d17e..65ee07de3 100644 --- a/psycopg3/psycopg3/named_cursor.py +++ b/psycopg3/psycopg3/named_cursor.py @@ -84,6 +84,19 @@ class NamedCursorHelper(Generic[ConnectionType]): cur.pgresult = res return cur._tx.load_rows(0, res.ntuples) + def _scroll_gen(self, value: int, mode: str) -> PQGen[None]: + if mode not in ("relative", "absolute"): + raise ValueError( + f"bad mode: {mode}. It should be 'relative' or 'absolute'" + ) + query = sql.SQL("move{} {} from {}").format( + sql.SQL(" absolute" if mode == "absolute" else ""), + sql.Literal(value), + sql.Identifier(self.name), + ) + cur = self._cur + yield from cur._conn._exec_command(query) + def _make_declare_statement( self, query: Query, scrollable: bool, hold: bool ) -> sql.Composable: @@ -200,6 +213,15 @@ class NamedCursor(BaseCursor["Connection"]): if len(recs) < self.itersize: break + def scroll(self, value: int, mode: str = "relative") -> None: + with self._conn.lock: + self._conn.wait(self._helper._scroll_gen(value, mode)) + # Postgres doesn't have a reliable way to report a cursor out of bound + if mode == "relative": + self._pos += value + else: + self._pos = value + class AsyncNamedCursor(BaseCursor["AsyncConnection"]): __module__ = "psycopg3" @@ -299,3 +321,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): yield rec if len(recs) < self.itersize: break + + async def scroll(self, value: int, mode: str = "relative") -> None: + async with self._conn.lock: + await self._conn.wait(self._helper._scroll_gen(value, mode)) diff --git a/tests/test_named_cursor.py b/tests/test_named_cursor.py index b18a5446c..2961d96cb 100644 --- a/tests/test_named_cursor.py +++ b/tests/test_named_cursor.py @@ -1,3 +1,6 @@ +import pytest + + def test_funny_name(conn): cur = conn.cursor("1-2-3") cur.execute("select generate_series(1, 3) as bar") @@ -126,3 +129,22 @@ def test_itersize(conn, commands): assert len(cmds) == 2 for cmd in cmds: assert ("fetch forward 2") in cmd.lower() + + +def test_scroll(conn): + cur = conn.cursor("tmp") + with pytest.raises(conn.ProgrammingError): + cur.scroll(0) + + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + + with pytest.raises(ValueError): + cur.scroll(9, mode="wat") diff --git a/tests/test_named_cursor_async.py b/tests/test_named_cursor_async.py index 844b7dfe8..1dd643b41 100644 --- a/tests/test_named_cursor_async.py +++ b/tests/test_named_cursor_async.py @@ -136,3 +136,22 @@ async def test_itersize(aconn, acommands): assert len(cmds) == 2 for cmd in cmds: assert ("fetch forward 2") in cmd.lower() + + +async def test_scroll(aconn): + cur = await aconn.cursor("tmp") + with pytest.raises(aconn.ProgrammingError): + await cur.scroll(0) + + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + + with pytest.raises(ValueError): + await cur.scroll(9, mode="wat")