]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add named cursor scroll
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 00:49:41 +0000 (01:49 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 00:49:41 +0000 (01:49 +0100)
psycopg3/psycopg3/named_cursor.py
tests/test_named_cursor.py
tests/test_named_cursor_async.py

index e8c25d17e7d76ff747fc82536647a1b7b6c39c82..65ee07de3dd8b4a272214f3c6095d6b716c01ee8 100644 (file)
@@ -84,6 +84,19 @@ class NamedCursorHelper(Generic[ConnectionType]):
         cur.pgresult = res
         return cur._tx.load_rows(0, res.ntuples)
 
+    def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
+        if mode not in ("relative", "absolute"):
+            raise ValueError(
+                f"bad mode: {mode}. It should be 'relative' or 'absolute'"
+            )
+        query = sql.SQL("move{} {} from {}").format(
+            sql.SQL(" absolute" if mode == "absolute" else ""),
+            sql.Literal(value),
+            sql.Identifier(self.name),
+        )
+        cur = self._cur
+        yield from cur._conn._exec_command(query)
+
     def _make_declare_statement(
         self, query: Query, scrollable: bool, hold: bool
     ) -> sql.Composable:
@@ -200,6 +213,15 @@ class NamedCursor(BaseCursor["Connection"]):
             if len(recs) < self.itersize:
                 break
 
+    def scroll(self, value: int, mode: str = "relative") -> None:
+        with self._conn.lock:
+            self._conn.wait(self._helper._scroll_gen(value, mode))
+        # Postgres doesn't have a reliable way to report a cursor out of bound
+        if mode == "relative":
+            self._pos += value
+        else:
+            self._pos = value
+
 
 class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
     __module__ = "psycopg3"
@@ -299,3 +321,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
                 yield rec
             if len(recs) < self.itersize:
                 break
+
+    async def scroll(self, value: int, mode: str = "relative") -> None:
+        async with self._conn.lock:
+            await self._conn.wait(self._helper._scroll_gen(value, mode))
index b18a5446c4bf453880977aaf102140ef13e391e0..2961d96cba015df3ad0744d20dd95284443d22d9 100644 (file)
@@ -1,3 +1,6 @@
+import pytest
+
+
 def test_funny_name(conn):
     cur = conn.cursor("1-2-3")
     cur.execute("select generate_series(1, 3) as bar")
@@ -126,3 +129,22 @@ def test_itersize(conn, commands):
         assert len(cmds) == 2
         for cmd in cmds:
             assert ("fetch forward 2") in cmd.lower()
+
+
+def test_scroll(conn):
+    cur = conn.cursor("tmp")
+    with pytest.raises(conn.ProgrammingError):
+        cur.scroll(0)
+
+    cur.execute("select generate_series(0,9)")
+    cur.scroll(2)
+    assert cur.fetchone() == (2,)
+    cur.scroll(2)
+    assert cur.fetchone() == (5,)
+    cur.scroll(2, mode="relative")
+    assert cur.fetchone() == (8,)
+    cur.scroll(9, mode="absolute")
+    assert cur.fetchone() == (9,)
+
+    with pytest.raises(ValueError):
+        cur.scroll(9, mode="wat")
index 844b7dfe8767f73de50d8277943320c4086c8fcc..1dd643b41b8a552c5f8d83c6af5db5d54190f70a 100644 (file)
@@ -136,3 +136,22 @@ async def test_itersize(aconn, acommands):
         assert len(cmds) == 2
         for cmd in cmds:
             assert ("fetch forward 2") in cmd.lower()
+
+
+async def test_scroll(aconn):
+    cur = await aconn.cursor("tmp")
+    with pytest.raises(aconn.ProgrammingError):
+        await cur.scroll(0)
+
+    await cur.execute("select generate_series(0,9)")
+    await cur.scroll(2)
+    assert await cur.fetchone() == (2,)
+    await cur.scroll(2)
+    assert await cur.fetchone() == (5,)
+    await cur.scroll(2, mode="relative")
+    assert await cur.fetchone() == (8,)
+    await cur.scroll(9, mode="absolute")
+    assert await cur.fetchone() == (9,)
+
+    with pytest.raises(ValueError):
+        await cur.scroll(9, mode="wat")