]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix rownumber during iteration in named cursors
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 22:53:54 +0000 (23:53 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 22:53:54 +0000 (23:53 +0100)
psycopg3/psycopg3/named_cursor.py
tests/test_named_cursor.py
tests/test_named_cursor_async.py

index 109b2bb236feb5cc127bb6281cc275c2e4d10d37..e8c25d17e7d76ff747fc82536647a1b7b6c39c82 100644 (file)
@@ -82,9 +82,7 @@ class NamedCursorHelper(Generic[ConnectionType]):
 
         # TODO: loaders don't need to be refreshed
         cur.pgresult = res
-        nrows = res.ntuples
-        cur._pos += nrows
-        return cur._tx.load_rows(0, nrows)
+        return cur._tx.load_rows(0, res.ntuples)
 
     def _make_declare_statement(
         self, query: Query, scrollable: bool, hold: bool
@@ -172,18 +170,24 @@ class NamedCursor(BaseCursor["Connection"]):
     def fetchone(self) -> Optional[Sequence[Any]]:
         with self._conn.lock:
             recs = self._conn.wait(self._helper._fetch_gen(1))
-        return recs[0] if recs else None
+        if recs:
+            self._pos += 1
+            return recs[0]
+        else:
+            return None
 
     def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
         if not size:
             size = self.arraysize
         with self._conn.lock:
             recs = self._conn.wait(self._helper._fetch_gen(size))
+        self._pos += len(recs)
         return recs
 
     def fetchall(self) -> Sequence[Sequence[Any]]:
         with self._conn.lock:
             recs = self._conn.wait(self._helper._fetch_gen(None))
+        self._pos += len(recs)
         return recs
 
     def __iter__(self) -> Iterator[Sequence[Any]]:
@@ -191,6 +195,7 @@ class NamedCursor(BaseCursor["Connection"]):
             with self._conn.lock:
                 recs = self._conn.wait(self._helper._fetch_gen(self.itersize))
             for rec in recs:
+                self._pos += 1
                 yield rec
             if len(recs) < self.itersize:
                 break
@@ -263,18 +268,24 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
     async def fetchone(self) -> Optional[Sequence[Any]]:
         async with self._conn.lock:
             recs = await self._conn.wait(self._helper._fetch_gen(1))
-        return recs[0] if recs else None
+        if recs:
+            self._pos += 1
+            return recs[0]
+        else:
+            return None
 
     async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
         if not size:
             size = self.arraysize
         async with self._conn.lock:
             recs = await self._conn.wait(self._helper._fetch_gen(size))
+        self._pos += len(recs)
         return recs
 
     async def fetchall(self) -> Sequence[Sequence[Any]]:
         async with self._conn.lock:
             recs = await self._conn.wait(self._helper._fetch_gen(None))
+        self._pos += len(recs)
         return recs
 
     async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
@@ -284,6 +295,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
                     self._helper._fetch_gen(self.itersize)
                 )
             for rec in recs:
+                self._pos += 1
                 yield rec
             if len(recs) < self.itersize:
                 break
index 5dd89cbae48cbd204d5be15b2a7030e9ac7faf2e..b18a5446c4bf453880977aaf102140ef13e391e0 100644 (file)
@@ -1,3 +1,10 @@
+def test_funny_name(conn):
+    cur = conn.cursor("1-2-3")
+    cur.execute("select generate_series(1, 3) as bar")
+    assert cur.fetchall() == [(1,), (2,), (3,)]
+    assert cur.name == "1-2-3"
+
+
 def test_description(conn):
     cur = conn.cursor("foo")
     assert cur.name == "foo"
@@ -100,6 +107,13 @@ def test_iter(conn):
     assert recs == [(2,), (3,)]
 
 
+def test_iter_rownumber(conn):
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, %s) as bar", (3,))
+        for row in cur:
+            assert cur.rownumber == row[0]
+
+
 def test_itersize(conn, commands):
     with conn.cursor("foo") as cur:
         assert cur.itersize == 100
index d56951a736d4a84859264de5eace517b77c944a9..844b7dfe8767f73de50d8277943320c4086c8fcc 100644 (file)
@@ -3,6 +3,13 @@ import pytest
 pytestmark = pytest.mark.asyncio
 
 
+async def test_funny_name(aconn):
+    cur = await aconn.cursor("1-2-3")
+    await cur.execute("select generate_series(1, 3) as bar")
+    assert await cur.fetchall() == [(1,), (2,), (3,)]
+    assert cur.name == "1-2-3"
+
+
 async def test_description(aconn):
     cur = await aconn.cursor("foo")
     assert cur.name == "foo"
@@ -109,6 +116,13 @@ async def test_iter(aconn):
     assert recs == [(2,), (3,)]
 
 
+async def test_iter_rownumber(aconn):
+    async with await aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, %s) as bar", (3,))
+        async for row in cur:
+            assert cur.rownumber == row[0]
+
+
 async def test_itersize(aconn, acommands):
     async with await aconn.cursor("foo") as cur:
         assert cur.itersize == 100