From: Daniele Varrazzo Date: Wed, 10 Feb 2021 00:26:08 +0000 (+0100) Subject: Add cursor.scroll() X-Git-Tag: 3.0.dev0~115^2~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=aa0507a9df9fdcfe72fc1afb5d7c872b969e3923;p=thirdparty%2Fpsycopg.git Add cursor.scroll() --- diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index a9e18c2f7..05cdc6001 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -439,6 +439,21 @@ class BaseCursor(Generic[ConnectionType]): f" FROM STDIN statements, got {ExecStatus(status).name}" ) + def _scroll(self, value: int, mode: str) -> None: + self._check_result() + assert self.pgresult + if mode == "relative": + newpos = self._pos + value + elif mode == "absolute": + newpos = value + else: + raise ValueError( + f"bad mode: {mode}. It should be 'relative' or 'absolute'" + ) + if not 0 <= newpos < self.pgresult.ntuples: + raise IndexError("position out of bound") + self._pos = newpos + def _close(self) -> None: self._closed = True # however keep the query available, which can be useful for debugging @@ -554,6 +569,9 @@ class Cursor(BaseCursor["Connection"]): self._pos += 1 yield row + def scroll(self, value: int, mode: str = "relative") -> None: + self._scroll(value, mode) + @contextmanager def copy(self, statement: Query) -> Iterator[Copy]: """ @@ -651,6 +669,9 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._pos += 1 yield row + async def scroll(self, value: int, mode: str = "relative") -> None: + self._scroll(value, mode) + @asynccontextmanager async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: async with self._conn.lock: diff --git a/tests/test_cursor.py b/tests/test_cursor.py index f3f1835cb..d22888e1f 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -286,6 +286,48 @@ def test_iter_stop(conn): assert list(cur) == [] +def test_scroll(conn): + cur = conn.cursor() + with pytest.raises(psycopg3.ProgrammingError): + cur.scroll(0) + + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(-1) + assert cur.fetchone() == (8,) + cur.scroll(-2) + assert cur.fetchone() == (7,) + cur.scroll(2, mode="absolute") + assert cur.fetchone() == (2,) + + # on the boundary + cur.scroll(0, mode="absolute") + assert cur.fetchone() == (0,) + with pytest.raises(IndexError): + cur.scroll(-1, mode="absolute") + + cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(-1) + + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + with pytest.raises(IndexError): + cur.scroll(10, mode="absolute") + + cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(1) + + with pytest.raises(ValueError): + cur.scroll(1, "wat") + + def test_query_params_execute(conn): cur = conn.cursor() assert cur.query is None diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index a79af6121..c2344e16b 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -291,6 +291,48 @@ async def test_iter_stop(aconn): assert False +async def test_scroll(aconn): + cur = await aconn.cursor() + with pytest.raises(psycopg3.ProgrammingError): + await cur.scroll(0) + + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(-1) + assert await cur.fetchone() == (8,) + await cur.scroll(-2) + assert await cur.fetchone() == (7,) + await cur.scroll(2, mode="absolute") + assert await cur.fetchone() == (2,) + + # on the boundary + await cur.scroll(0, mode="absolute") + assert await cur.fetchone() == (0,) + with pytest.raises(IndexError): + await cur.scroll(-1, mode="absolute") + + await cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(-1) + + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + with pytest.raises(IndexError): + await cur.scroll(10, mode="absolute") + + await cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(1) + + with pytest.raises(ValueError): + await cur.scroll(1, "wat") + + async def test_query_params_execute(aconn): cur = await aconn.cursor() assert cur.query is None