From: Daniele Varrazzo Date: Wed, 9 Nov 2022 15:25:42 +0000 (+0000) Subject: fix: return rownumber=None if the result has no row to fetch X-Git-Tag: 3.1.5~17 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a8168869b762c821e7f07d1fc52d3642f899711e;p=thirdparty%2Fpsycopg.git fix: return rownumber=None if the result has no row to fetch Close #437. --- diff --git a/docs/news.rst b/docs/news.rst index 441085f58..d29de2da7 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -15,6 +15,8 @@ Psycopg 3.1.5 (unreleased) - Return `!bytes` instead of `!memoryview` from `pq.Encoding` methods (:ticket:`#422`). +- Fix `Cursor.rownumber` to return `!None` when the result has no row to fetch + (:ticket:`#437`). Current release diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 3fbcfdcda..72c128e6f 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -137,7 +137,8 @@ class BaseCursor(Generic[ConnectionType, Row]): `!None` if there is no result to fetch. """ - return self._pos if self.pgresult else None + tuples = self.pgresult and self.pgresult.status == TUPLES_OK + return self._pos if tuples else None def setinputsizes(self, sizes: Sequence[Any]) -> None: # no-op diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py index 3a12b5e47..b890d7728 100644 --- a/psycopg/psycopg/server_cursor.py +++ b/psycopg/psycopg/server_cursor.py @@ -27,6 +27,7 @@ TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK IDLE = pq.TransactionStatus.IDLE INTRANS = pq.TransactionStatus.INTRANS @@ -78,6 +79,19 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]): """ return self._withhold + @property + def rownumber(self) -> Optional[int]: + """Index of the next row to fetch in the current result. + + `!None` if there is no result to fetch. + """ + res = self.pgresult + # command_status is empty if the result comes from + # describe_portal, which means that we have just executed the DECLARE, + # so we can assume we are at the first row. + tuples = res and (res.status == TUPLES_OK or res.command_status == b"") + return self._pos if tuples else None + def _declare_gen( self, query: Query, diff --git a/tests/test_cursor.py b/tests/test_cursor.py index bc1c5799e..8b044944a 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -428,6 +428,35 @@ def test_rownumber(conn): assert cur.rownumber == 42 +@pytest.mark.parametrize("query", ["", "set timezone to utc"]) +def test_rownumber_none(conn, query): + cur = conn.cursor() + cur.execute(query) + assert cur.rownumber is None + + +def test_rownumber_mixed(conn): + cur = conn.cursor() + cur.execute( + """ +select x from generate_series(1, 3) x; +set timezone to utc; +select x from generate_series(4, 6) x; +""" + ) + assert cur.rownumber == 0 + assert cur.fetchone() == (1,) + assert cur.rownumber == 1 + assert cur.fetchone() == (2,) + assert cur.rownumber == 2 + cur.nextset() + assert cur.rownumber is None + cur.nextset() + assert cur.rownumber == 0 + assert cur.fetchone() == (4,) + assert cur.rownumber == 1 + + def test_iter(conn): cur = conn.cursor() cur.execute("select generate_series(1, 3)") diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 50de79ee8..84b572260 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -419,6 +419,35 @@ async def test_rownumber(aconn): assert cur.rownumber == 42 +@pytest.mark.parametrize("query", ["", "set timezone to utc"]) +async def test_rownumber_none(aconn, query): + cur = aconn.cursor() + await cur.execute(query) + assert cur.rownumber is None + + +async def test_rownumber_mixed(aconn): + cur = aconn.cursor() + await cur.execute( + """ +select x from generate_series(1, 3) x; +set timezone to utc; +select x from generate_series(4, 6) x; +""" + ) + assert cur.rownumber == 0 + assert await cur.fetchone() == (1,) + assert cur.rownumber == 1 + assert await cur.fetchone() == (2,) + assert cur.rownumber == 2 + cur.nextset() + assert cur.rownumber is None + cur.nextset() + assert cur.rownumber == 0 + assert await cur.fetchone() == (4,) + assert cur.rownumber == 1 + + async def test_iter(aconn): cur = aconn.cursor() await cur.execute("select generate_series(1, 3)")