From: Daniele Varrazzo Date: Tue, 9 Feb 2021 22:53:54 +0000 (+0100) Subject: Fix rownumber during iteration in named cursors X-Git-Tag: 3.0.dev0~115^2~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d04769d672e57562dfd8cb6e7b03445fe27bba1b;p=thirdparty%2Fpsycopg.git Fix rownumber during iteration in named cursors --- diff --git a/psycopg3/psycopg3/named_cursor.py b/psycopg3/psycopg3/named_cursor.py index 109b2bb23..e8c25d17e 100644 --- a/psycopg3/psycopg3/named_cursor.py +++ b/psycopg3/psycopg3/named_cursor.py @@ -82,9 +82,7 @@ class NamedCursorHelper(Generic[ConnectionType]): # TODO: loaders don't need to be refreshed cur.pgresult = res - nrows = res.ntuples - cur._pos += nrows - return cur._tx.load_rows(0, nrows) + return cur._tx.load_rows(0, res.ntuples) def _make_declare_statement( self, query: Query, scrollable: bool, hold: bool @@ -172,18 +170,24 @@ class NamedCursor(BaseCursor["Connection"]): def fetchone(self) -> Optional[Sequence[Any]]: with self._conn.lock: recs = self._conn.wait(self._helper._fetch_gen(1)) - return recs[0] if recs else None + if recs: + self._pos += 1 + return recs[0] + else: + return None def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]: if not size: size = self.arraysize with self._conn.lock: recs = self._conn.wait(self._helper._fetch_gen(size)) + self._pos += len(recs) return recs def fetchall(self) -> Sequence[Sequence[Any]]: with self._conn.lock: recs = self._conn.wait(self._helper._fetch_gen(None)) + self._pos += len(recs) return recs def __iter__(self) -> Iterator[Sequence[Any]]: @@ -191,6 +195,7 @@ class NamedCursor(BaseCursor["Connection"]): with self._conn.lock: recs = self._conn.wait(self._helper._fetch_gen(self.itersize)) for rec in recs: + self._pos += 1 yield rec if len(recs) < self.itersize: break @@ -263,18 +268,24 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): async def fetchone(self) -> Optional[Sequence[Any]]: async with self._conn.lock: recs = await self._conn.wait(self._helper._fetch_gen(1)) - return recs[0] if recs else None + if recs: + self._pos += 1 + return recs[0] + else: + return None async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]: if not size: size = self.arraysize async with self._conn.lock: recs = await self._conn.wait(self._helper._fetch_gen(size)) + self._pos += len(recs) return recs async def fetchall(self) -> Sequence[Sequence[Any]]: async with self._conn.lock: recs = await self._conn.wait(self._helper._fetch_gen(None)) + self._pos += len(recs) return recs async def __aiter__(self) -> AsyncIterator[Sequence[Any]]: @@ -284,6 +295,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]): self._helper._fetch_gen(self.itersize) ) for rec in recs: + self._pos += 1 yield rec if len(recs) < self.itersize: break diff --git a/tests/test_named_cursor.py b/tests/test_named_cursor.py index 5dd89cbae..b18a5446c 100644 --- a/tests/test_named_cursor.py +++ b/tests/test_named_cursor.py @@ -1,3 +1,10 @@ +def test_funny_name(conn): + cur = conn.cursor("1-2-3") + cur.execute("select generate_series(1, 3) as bar") + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.name == "1-2-3" + + def test_description(conn): cur = conn.cursor("foo") assert cur.name == "foo" @@ -100,6 +107,13 @@ def test_iter(conn): assert recs == [(2,), (3,)] +def test_iter_rownumber(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + for row in cur: + assert cur.rownumber == row[0] + + def test_itersize(conn, commands): with conn.cursor("foo") as cur: assert cur.itersize == 100 diff --git a/tests/test_named_cursor_async.py b/tests/test_named_cursor_async.py index d56951a73..844b7dfe8 100644 --- a/tests/test_named_cursor_async.py +++ b/tests/test_named_cursor_async.py @@ -3,6 +3,13 @@ import pytest pytestmark = pytest.mark.asyncio +async def test_funny_name(aconn): + cur = await aconn.cursor("1-2-3") + await cur.execute("select generate_series(1, 3) as bar") + assert await cur.fetchall() == [(1,), (2,), (3,)] + assert cur.name == "1-2-3" + + async def test_description(aconn): cur = await aconn.cursor("foo") assert cur.name == "foo" @@ -109,6 +116,13 @@ async def test_iter(aconn): assert recs == [(2,), (3,)] +async def test_iter_rownumber(aconn): + async with await aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + async for row in cur: + assert cur.rownumber == row[0] + + async def test_itersize(aconn, acommands): async with await aconn.cursor("foo") as cur: assert cur.itersize == 100