Close #437.
- Return `!bytes` instead of `!memoryview` from `pq.Encoding` methods
(:ticket:`#422`).
+- Fix `Cursor.rownumber` to return `!None` when the result has no row to fetch
+ (:ticket:`#437`).
Current release
`!None` if there is no result to fetch.
"""
- return self._pos if self.pgresult else None
+ tuples = self.pgresult and self.pgresult.status == TUPLES_OK
+ return self._pos if tuples else None
def setinputsizes(self, sizes: Sequence[Any]) -> None:
# no-op
BINARY = pq.Format.BINARY
COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS
"""
return self._withhold
+ @property
+ def rownumber(self) -> Optional[int]:
+ """Index of the next row to fetch in the current result.
+
+ `!None` if there is no result to fetch.
+ """
+ res = self.pgresult
+ # command_status is empty if the result comes from
+ # describe_portal, which means that we have just executed the DECLARE,
+ # so we can assume we are at the first row.
+ tuples = res and (res.status == TUPLES_OK or res.command_status == b"")
+ return self._pos if tuples else None
+
def _declare_gen(
self,
query: Query,
assert cur.rownumber == 42
+@pytest.mark.parametrize("query", ["", "set timezone to utc"])
+def test_rownumber_none(conn, query):
+ cur = conn.cursor()
+ cur.execute(query)
+ assert cur.rownumber is None
+
+
+def test_rownumber_mixed(conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+select x from generate_series(1, 3) x;
+set timezone to utc;
+select x from generate_series(4, 6) x;
+"""
+ )
+ assert cur.rownumber == 0
+ assert cur.fetchone() == (1,)
+ assert cur.rownumber == 1
+ assert cur.fetchone() == (2,)
+ assert cur.rownumber == 2
+ cur.nextset()
+ assert cur.rownumber is None
+ cur.nextset()
+ assert cur.rownumber == 0
+ assert cur.fetchone() == (4,)
+ assert cur.rownumber == 1
+
+
def test_iter(conn):
cur = conn.cursor()
cur.execute("select generate_series(1, 3)")
assert cur.rownumber == 42
+@pytest.mark.parametrize("query", ["", "set timezone to utc"])
+async def test_rownumber_none(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute(query)
+ assert cur.rownumber is None
+
+
+async def test_rownumber_mixed(aconn):
+ cur = aconn.cursor()
+ await cur.execute(
+ """
+select x from generate_series(1, 3) x;
+set timezone to utc;
+select x from generate_series(4, 6) x;
+"""
+ )
+ assert cur.rownumber == 0
+ assert await cur.fetchone() == (1,)
+ assert cur.rownumber == 1
+ assert await cur.fetchone() == (2,)
+ assert cur.rownumber == 2
+ cur.nextset()
+ assert cur.rownumber is None
+ cur.nextset()
+ assert cur.rownumber == 0
+ assert await cur.fetchone() == (4,)
+ assert cur.rownumber == 1
+
+
async def test_iter(aconn):
cur = aconn.cursor()
await cur.execute("select generate_series(1, 3)")