]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: return rownumber=None if the result has no row to fetch
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 9 Nov 2022 15:25:42 +0000 (15:25 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 9 Nov 2022 16:45:43 +0000 (16:45 +0000)
Close #437.

docs/news.rst
psycopg/psycopg/cursor.py
psycopg/psycopg/server_cursor.py
tests/test_cursor.py
tests/test_cursor_async.py

index 441085f5848eb79cc6c1c59eaa2507bec4d64eb0..d29de2da7ed7081e0c50adb06bc08af79ca58723 100644 (file)
@@ -15,6 +15,8 @@ Psycopg 3.1.5 (unreleased)
 
 - Return `!bytes` instead of `!memoryview` from `pq.Encoding` methods
   (:ticket:`#422`).
+- Fix `Cursor.rownumber` to return `!None` when the result has no row to fetch
+  (:ticket:`#437`).
 
 
 Current release
index 3fbcfdcda7d2e8f570821f974212adbb0ea05ec7..72c128e6f253344f0364df9d546a8133b6273baf 100644 (file)
@@ -137,7 +137,8 @@ class BaseCursor(Generic[ConnectionType, Row]):
 
         `!None` if there is no result to fetch.
         """
-        return self._pos if self.pgresult else None
+        tuples = self.pgresult and self.pgresult.status == TUPLES_OK
+        return self._pos if tuples else None
 
     def setinputsizes(self, sizes: Sequence[Any]) -> None:
         # no-op
index 3a12b5e473174674decd86418479d9ef1ac76b22..b890d7728245e87810d75aeaaa836768cc8045a3 100644 (file)
@@ -27,6 +27,7 @@ TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
 COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
 
 IDLE = pq.TransactionStatus.IDLE
 INTRANS = pq.TransactionStatus.INTRANS
@@ -78,6 +79,19 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
         """
         return self._withhold
 
+    @property
+    def rownumber(self) -> Optional[int]:
+        """Index of the next row to fetch in the current result.
+
+        `!None` if there is no result to fetch.
+        """
+        res = self.pgresult
+        # command_status is empty if the result comes from
+        # describe_portal, which means that we have just executed the DECLARE,
+        # so we can assume we are at the first row.
+        tuples = res and (res.status == TUPLES_OK or res.command_status == b"")
+        return self._pos if tuples else None
+
     def _declare_gen(
         self,
         query: Query,
index bc1c5799ece7d62b76d83e3c971f3b57b2a25ca9..8b044944a0ecc0d9def9d24023f037a5262ba191 100644 (file)
@@ -428,6 +428,35 @@ def test_rownumber(conn):
     assert cur.rownumber == 42
 
 
+@pytest.mark.parametrize("query", ["", "set timezone to utc"])
+def test_rownumber_none(conn, query):
+    cur = conn.cursor()
+    cur.execute(query)
+    assert cur.rownumber is None
+
+
+def test_rownumber_mixed(conn):
+    cur = conn.cursor()
+    cur.execute(
+        """
+select x from generate_series(1, 3) x;
+set timezone to utc;
+select x from generate_series(4, 6) x;
+"""
+    )
+    assert cur.rownumber == 0
+    assert cur.fetchone() == (1,)
+    assert cur.rownumber == 1
+    assert cur.fetchone() == (2,)
+    assert cur.rownumber == 2
+    cur.nextset()
+    assert cur.rownumber is None
+    cur.nextset()
+    assert cur.rownumber == 0
+    assert cur.fetchone() == (4,)
+    assert cur.rownumber == 1
+
+
 def test_iter(conn):
     cur = conn.cursor()
     cur.execute("select generate_series(1, 3)")
index 50de79ee8bdb490c9bc277e92ff63d3c6263dbb8..84b57226037490f8b4c0c523047dc75c0e8dbb19 100644 (file)
@@ -419,6 +419,35 @@ async def test_rownumber(aconn):
     assert cur.rownumber == 42
 
 
+@pytest.mark.parametrize("query", ["", "set timezone to utc"])
+async def test_rownumber_none(aconn, query):
+    cur = aconn.cursor()
+    await cur.execute(query)
+    assert cur.rownumber is None
+
+
+async def test_rownumber_mixed(aconn):
+    cur = aconn.cursor()
+    await cur.execute(
+        """
+select x from generate_series(1, 3) x;
+set timezone to utc;
+select x from generate_series(4, 6) x;
+"""
+    )
+    assert cur.rownumber == 0
+    assert await cur.fetchone() == (1,)
+    assert cur.rownumber == 1
+    assert await cur.fetchone() == (2,)
+    assert cur.rownumber == 2
+    cur.nextset()
+    assert cur.rownumber is None
+    cur.nextset()
+    assert cur.rownumber == 0
+    assert await cur.fetchone() == (4,)
+    assert cur.rownumber == 1
+
+
 async def test_iter(aconn):
     cur = aconn.cursor()
     await cur.execute("select generate_series(1, 3)")