]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add cursor.scroll()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 00:26:08 +0000 (01:26 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 10 Feb 2021 00:26:08 +0000 (01:26 +0100)
psycopg3/psycopg3/cursor.py
tests/test_cursor.py
tests/test_cursor_async.py

index a9e18c2f768103f4577fa12708a25c999233cf84..05cdc600115666fd042f294698ce0049a2ad21cc 100644 (file)
@@ -439,6 +439,21 @@ class BaseCursor(Generic[ConnectionType]):
                 f" FROM STDIN statements, got {ExecStatus(status).name}"
             )
 
+    def _scroll(self, value: int, mode: str) -> None:
+        self._check_result()
+        assert self.pgresult
+        if mode == "relative":
+            newpos = self._pos + value
+        elif mode == "absolute":
+            newpos = value
+        else:
+            raise ValueError(
+                f"bad mode: {mode}. It should be 'relative' or 'absolute'"
+            )
+        if not 0 <= newpos < self.pgresult.ntuples:
+            raise IndexError("position out of bound")
+        self._pos = newpos
+
     def _close(self) -> None:
         self._closed = True
         # however keep the query available, which can be useful for debugging
@@ -554,6 +569,9 @@ class Cursor(BaseCursor["Connection"]):
             self._pos += 1
             yield row
 
+    def scroll(self, value: int, mode: str = "relative") -> None:
+        self._scroll(value, mode)
+
     @contextmanager
     def copy(self, statement: Query) -> Iterator[Copy]:
         """
@@ -651,6 +669,9 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             self._pos += 1
             yield row
 
+    async def scroll(self, value: int, mode: str = "relative") -> None:
+        self._scroll(value, mode)
+
     @asynccontextmanager
     async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
         async with self._conn.lock:
index f3f1835cb8567824ed963b9ae4f9e5a2d6d83ba5..d22888e1f6f232d69e89a4c74b33004ca9d2f1be 100644 (file)
@@ -286,6 +286,48 @@ def test_iter_stop(conn):
     assert list(cur) == []
 
 
+def test_scroll(conn):
+    cur = conn.cursor()
+    with pytest.raises(psycopg3.ProgrammingError):
+        cur.scroll(0)
+
+    cur.execute("select generate_series(0,9)")
+    cur.scroll(2)
+    assert cur.fetchone() == (2,)
+    cur.scroll(2)
+    assert cur.fetchone() == (5,)
+    cur.scroll(2, mode="relative")
+    assert cur.fetchone() == (8,)
+    cur.scroll(-1)
+    assert cur.fetchone() == (8,)
+    cur.scroll(-2)
+    assert cur.fetchone() == (7,)
+    cur.scroll(2, mode="absolute")
+    assert cur.fetchone() == (2,)
+
+    # on the boundary
+    cur.scroll(0, mode="absolute")
+    assert cur.fetchone() == (0,)
+    with pytest.raises(IndexError):
+        cur.scroll(-1, mode="absolute")
+
+    cur.scroll(0, mode="absolute")
+    with pytest.raises(IndexError):
+        cur.scroll(-1)
+
+    cur.scroll(9, mode="absolute")
+    assert cur.fetchone() == (9,)
+    with pytest.raises(IndexError):
+        cur.scroll(10, mode="absolute")
+
+    cur.scroll(9, mode="absolute")
+    with pytest.raises(IndexError):
+        cur.scroll(1)
+
+    with pytest.raises(ValueError):
+        cur.scroll(1, "wat")
+
+
 def test_query_params_execute(conn):
     cur = conn.cursor()
     assert cur.query is None
index a79af612139a370850e66bf02964e38b5ec4362f..c2344e16b3e4409c8076b805380213297238406a 100644 (file)
@@ -291,6 +291,48 @@ async def test_iter_stop(aconn):
         assert False
 
 
+async def test_scroll(aconn):
+    cur = await aconn.cursor()
+    with pytest.raises(psycopg3.ProgrammingError):
+        await cur.scroll(0)
+
+    await cur.execute("select generate_series(0,9)")
+    await cur.scroll(2)
+    assert await cur.fetchone() == (2,)
+    await cur.scroll(2)
+    assert await cur.fetchone() == (5,)
+    await cur.scroll(2, mode="relative")
+    assert await cur.fetchone() == (8,)
+    await cur.scroll(-1)
+    assert await cur.fetchone() == (8,)
+    await cur.scroll(-2)
+    assert await cur.fetchone() == (7,)
+    await cur.scroll(2, mode="absolute")
+    assert await cur.fetchone() == (2,)
+
+    # on the boundary
+    await cur.scroll(0, mode="absolute")
+    assert await cur.fetchone() == (0,)
+    with pytest.raises(IndexError):
+        await cur.scroll(-1, mode="absolute")
+
+    await cur.scroll(0, mode="absolute")
+    with pytest.raises(IndexError):
+        await cur.scroll(-1)
+
+    await cur.scroll(9, mode="absolute")
+    assert await cur.fetchone() == (9,)
+    with pytest.raises(IndexError):
+        await cur.scroll(10, mode="absolute")
+
+    await cur.scroll(9, mode="absolute")
+    with pytest.raises(IndexError):
+        await cur.scroll(1)
+
+    with pytest.raises(ValueError):
+        await cur.scroll(1, "wat")
+
+
 async def test_query_params_execute(aconn):
     cur = await aconn.cursor()
     assert cur.query is None