From: Daniele Varrazzo Date: Tue, 9 Feb 2021 17:47:41 +0000 (+0100) Subject: Add Cursor.rownumber attribute X-Git-Tag: 3.0.dev0~115^2~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b89085ac2125982b25ea2e6a33939f1c6e143576;p=thirdparty%2Fpsycopg.git Add Cursor.rownumber attribute --- diff --git a/docs/cursor.rst b/docs/cursor.rst index 162165c5b..a1f0d9515 100644 --- a/docs/cursor.rst +++ b/docs/cursor.rst @@ -141,6 +141,9 @@ The `!Cursor` class .. autoattribute:: rowcount :annotation: int + .. autoattribute:: rownumber + :annotation: int + .. autoattribute:: query :annotation: Optional[bytes] diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index b6028195d..a9e18c2f7 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -146,6 +146,14 @@ class BaseCursor(Generic[ConnectionType]): """Number of records affected by the precedent operation.""" return self._rowcount + @property + def rownumber(self) -> Optional[int]: + """Index of the next row to fetch in the current result. + + `!None` if there is no result to fetch. + """ + return self._pos if self._pgresult else None + def setinputsizes(self, sizes: Sequence[Any]) -> None: # no-op pass @@ -531,7 +539,7 @@ class Cursor(BaseCursor["Connection"]): self._check_result() assert self.pgresult records = self._tx.load_rows(self._pos, self.pgresult.ntuples) - self._pos += self.pgresult.ntuples + self._pos = self.pgresult.ntuples return records def __iter__(self) -> Iterator[Sequence[Any]]: @@ -628,7 +636,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._check_result() assert self.pgresult records = self._tx.load_rows(self._pos, self.pgresult.ntuples) - self._pos += self.pgresult.ntuples + self._pos = self.pgresult.ntuples return records async def __aiter__(self) -> AsyncIterator[Sequence[Any]]: diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 481dbf5e4..f3f1835cb 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -242,6 +242,29 @@ def test_rowcount(conn): assert cur.rowcount == -1 +def test_rownumber(conn): + cur = conn.cursor() + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + rns = [] + for i in cur: + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + def test_iter(conn): cur = conn.cursor() cur.execute("select generate_series(1, 3)") diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 6285aa5b5..a79af6121 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -243,6 +243,29 @@ async def test_rowcount(aconn): assert cur.rowcount == -1 +async def test_rownumber(aconn): + cur = await aconn.cursor() + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + rns = [] + async for i in cur: + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(await cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + async def test_iter(aconn): cur = await aconn.cursor() await cur.execute("select generate_series(1, 3)")