]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add Cursor.rownumber attribute
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 17:47:41 +0000 (18:47 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 20:09:14 +0000 (21:09 +0100)
docs/cursor.rst
psycopg3/psycopg3/cursor.py
tests/test_cursor.py
tests/test_cursor_async.py

index 162165c5b96ecb7dc5b081af165de04a443dde11..a1f0d9515db3e8e6037abd411443e05513f6bcef 100644 (file)
@@ -141,6 +141,9 @@ The `!Cursor` class
     .. autoattribute:: rowcount
         :annotation: int
 
+    .. autoattribute:: rownumber
+        :annotation: int
+
     .. autoattribute:: query
         :annotation: Optional[bytes]
 
index b6028195da25134fe5735caf43204c482be6cce8..a9e18c2f768103f4577fa12708a25c999233cf84 100644 (file)
@@ -146,6 +146,14 @@ class BaseCursor(Generic[ConnectionType]):
         """Number of records affected by the precedent operation."""
         return self._rowcount
 
+    @property
+    def rownumber(self) -> Optional[int]:
+        """Index of the next row to fetch in the current result.
+
+        `!None` if there is no result to fetch.
+        """
+        return self._pos if self._pgresult else None
+
     def setinputsizes(self, sizes: Sequence[Any]) -> None:
         # no-op
         pass
@@ -531,7 +539,7 @@ class Cursor(BaseCursor["Connection"]):
         self._check_result()
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
-        self._pos += self.pgresult.ntuples
+        self._pos = self.pgresult.ntuples
         return records
 
     def __iter__(self) -> Iterator[Sequence[Any]]:
@@ -628,7 +636,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         self._check_result()
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
-        self._pos += self.pgresult.ntuples
+        self._pos = self.pgresult.ntuples
         return records
 
     async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
index 481dbf5e496098c3208bc8d038714a862d1bc938..f3f1835cb8567824ed963b9ae4f9e5a2d6d83ba5 100644 (file)
@@ -242,6 +242,29 @@ def test_rowcount(conn):
     assert cur.rowcount == -1
 
 
+def test_rownumber(conn):
+    cur = conn.cursor()
+    assert cur.rownumber is None
+
+    cur.execute("select 1 from generate_series(1, 42)")
+    assert cur.rownumber == 0
+
+    cur.fetchone()
+    assert cur.rownumber == 1
+    cur.fetchone()
+    assert cur.rownumber == 2
+    cur.fetchmany(10)
+    assert cur.rownumber == 12
+    rns = []
+    for i in cur:
+        rns.append(cur.rownumber)
+        if len(rns) >= 3:
+            break
+    assert rns == [13, 14, 15]
+    assert len(cur.fetchall()) == 42 - rns[-1]
+    assert cur.rownumber == 42
+
+
 def test_iter(conn):
     cur = conn.cursor()
     cur.execute("select generate_series(1, 3)")
index 6285aa5b57157ebc9375d0e0741f465a1e747ed7..a79af612139a370850e66bf02964e38b5ec4362f 100644 (file)
@@ -243,6 +243,29 @@ async def test_rowcount(aconn):
     assert cur.rowcount == -1
 
 
+async def test_rownumber(aconn):
+    cur = await aconn.cursor()
+    assert cur.rownumber is None
+
+    await cur.execute("select 1 from generate_series(1, 42)")
+    assert cur.rownumber == 0
+
+    await cur.fetchone()
+    assert cur.rownumber == 1
+    await cur.fetchone()
+    assert cur.rownumber == 2
+    await cur.fetchmany(10)
+    assert cur.rownumber == 12
+    rns = []
+    async for i in cur:
+        rns.append(cur.rownumber)
+        if len(rns) >= 3:
+            break
+    assert rns == [13, 14, 15]
+    assert len(await cur.fetchall()) == 42 - rns[-1]
+    assert cur.rownumber == 42
+
+
 async def test_iter(aconn):
     cur = await aconn.cursor()
     await cur.execute("select generate_series(1, 3)")