fix: fix loading of rows when row maker returns None
author Martin Baláž <balaz@brightpick.ai>
Wed, 7 May 2025 15:25:13 +0000 (17:25 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 8 May 2025 21:00:01 +0000 (23:00 +0200)
Close #1073

psycopg/psycopg/_cursor_base.py
psycopg/psycopg/_py_transformer.py
psycopg/psycopg/abc.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/transform.pyx

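The bug: a row maker may legitimately return None (for example, a row factory that discards values), but load_row() also used None as its end-of-data sentinel, so fetching stopped at the first None row. A minimal reproduction sketch, assuming a reachable test database (the DSN is a placeholder):

```python
import psycopg

def null_row(cursor):
    # Row factory whose row maker legitimately returns None for every row.
    def make_row(values):
        return None
    return make_row

with psycopg.connect("dbname=test") as conn:  # placeholder DSN
    with conn.cursor(row_factory=null_row) as cur:
        cur.execute("SELECT generate_series(1, 3)")
        # Before this fix, fetchone() mistook the first None row for
        # "no more results" and iteration stopped immediately.
        assert list(cur) == [None, None, None]
```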
diff --git a/psycopg/psycopg/_cursor_base.py b/psycopg/psycopg/_cursor_base.py
index ee7c6462bb6b414f0998eae64b56c1558fa9b990..f314f122923cddd433a7260d9537604e8cf67e13 100644
--- a/psycopg/psycopg/_cursor_base.py
+++ b/psycopg/psycopg/_cursor_base.py
@@ -569,7 +569,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
                 name, pgq.params, param_formats=pgq.formats, result_format=fmt
             )
 
-    def _check_result_for_fetch(self) -> None:
+    def _check_result_for_fetch(self) -> PGresult:
         if self.closed:
             raise e.InterfaceError("the cursor is closed")
 
@@ -577,7 +577,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             raise e.ProgrammingError("no result available")
 
         if (status := res.status) == TUPLES_OK:
-            return
+            return res
         elif status == FATAL_ERROR:
             raise e.error_from_result(res, encoding=self._encoding)
         elif status == PIPELINE_ABORTED:
@@ -611,15 +611,14 @@ class BaseCursor(Generic[ConnectionType, Row]):
             )
 
     def _scroll(self, value: int, mode: str) -> None:
-        self._check_result_for_fetch()
-        assert self.pgresult
+        res = self._check_result_for_fetch()
         if mode == "relative":
             newpos = self._pos + value
         elif mode == "absolute":
             newpos = value
         else:
             raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
-        if not 0 <= newpos < self.pgresult.ntuples:
+        if not 0 <= newpos < res.ntuples:
             raise IndexError("position out of bound")
         self._pos = newpos
 
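With _check_result_for_fetch() now returning the validated PGresult instead of None, callers can drop the "check, then assert self.pgresult" dance. A standalone sketch of the pattern (simplified names, not psycopg's actual classes):

```python
from dataclasses import dataclass

@dataclass
class Result:
    ntuples: int

class Cursor:
    def __init__(self) -> None:
        self.pgresult: Result | None = None
        self._pos = 0

    def _check_result_for_fetch(self) -> Result:
        # Returning the checked object lets type checkers narrow the type
        # at the call site, with no assert against a mutable attribute.
        if self.pgresult is None:
            raise RuntimeError("no result available")
        return self.pgresult

    def _scroll(self, value: int) -> None:
        res = self._check_result_for_fetch()  # no assert needed here
        newpos = self._pos + value
        if not 0 <= newpos < res.ntuples:
            raise IndexError("position out of bound")
        self._pos = newpos
```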
diff --git a/psycopg/psycopg/_py_transformer.py b/psycopg/psycopg/_py_transformer.py
index ffba74e8c2e281f5f6063c4aace5dff8209280bc..a41332b9d57d3f0bf503ba9bb9ca0181920d6494 100644
--- a/psycopg/psycopg/_py_transformer.py
+++ b/psycopg/psycopg/_py_transformer.py
@@ -312,12 +312,14 @@ class Transformer(AdaptContext):
 
         return records
 
-    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None:
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row:
         if not (res := self._pgresult):
-            return None
+            raise e.InterfaceError("result not set")
 
-        if not 0 <= row < self._ntuples:
-            return None
+        if not 0 <= row < self._ntuples:
+            raise e.InterfaceError(
+                f"row must be included between 0 and {self._ntuples}"
+            )
 
         record: list[Any] = [None] * self._nfields
         for col in range(self._nfields):
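The transformer's load_row() now raises for a missing result or an out-of-range index instead of returning None, freeing None to be a legitimate row value. A self-contained sketch of the new contract (not psycopg's actual Transformer):

```python
class InterfaceError(Exception):
    pass

class Transformer:
    def __init__(self, rows):
        self._rows = rows  # None means "no result set yet"

    def load_row(self, row, make_row):
        if self._rows is None:
            raise InterfaceError("result not set")
        if not 0 <= row < len(self._rows):
            raise InterfaceError(
                f"row must be included between 0 and {len(self._rows)}"
            )
        return make_row(self._rows[row])

tx = Transformer([(1,), (2,)])
assert tx.load_row(0, lambda values: None) is None  # None is now a valid row
```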
diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py
index a95c9e069f477c6873a19a4a68148897e40bbc5a..764b9a1d5711ffc750731373eb3eb9501653d726 100644
--- a/psycopg/psycopg/abc.py
+++ b/psycopg/psycopg/abc.py
@@ -243,7 +243,7 @@ class Transformer(Protocol):
 
     def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ...
 
-    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ...
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: ...
 
     def load_sequence(self, record: Sequence[Buffer | None]) -> tuple[Any, ...]: ...
 
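Narrowing the Protocol's return type from Row | None to Row also removes the spurious None from every call site's inferred type. A typing-only sketch of the caller-side effect (illustrative, not psycopg code):

```python
from typing import Callable

# Old annotation: "tuple | None" forced a guard at every call site.
def old_fetch(load_row: Callable[[int], tuple | None]) -> tuple:
    rec = load_row(0)
    assert rec is not None  # needed only because None doubled as a sentinel
    return rec

# New annotation: the result is always a row.
def new_fetch(load_row: Callable[[int], tuple]) -> tuple:
    return load_row(0)
```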
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py
index 56411dac086b4b06c1a85e696a396e730b8aeca9..6b62898ca7bc15de66ce2b57a6aea6461e0ec49a 100644
--- a/psycopg/psycopg/cursor.py
+++ b/psycopg/psycopg/cursor.py
@@ -151,11 +151,9 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                     self._stream_send_gen(query, params, binary=binary, size=size)
                 )
                 first = True
-                while self._conn.wait(self._stream_fetchone_gen(first)):
-                    for pos in range(size):
-                        if (rec := self._tx.load_row(pos, self._make_row)) is None:
-                            break
-                        yield rec
+                while res := self._conn.wait(self._stream_fetchone_gen(first)):
+                    for pos in range(res.ntuples):
+                        yield self._tx.load_row(pos, self._make_row)
                     first = False
             except e._NO_TRACEBACK as ex:
                 raise ex.with_traceback(None)
@@ -186,10 +184,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         :rtype: Row | None, with Row defined by `row_factory`
         """
         self._fetch_pipeline()
-        self._check_result_for_fetch()
-        if (record := self._tx.load_row(self._pos, self._make_row)) is not None:
+        res = self._check_result_for_fetch()
+        if self._pos < res.ntuples:
+            record = self._tx.load_row(self._pos, self._make_row)
             self._pos += 1
-        return record
+            return record
+        return None
 
     def fetchmany(self, size: int = 0) -> list[Row]:
         """
@@ -200,13 +200,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
         self._fetch_pipeline()
-        self._check_result_for_fetch()
-        assert self.pgresult
+        res = self._check_result_for_fetch()
 
         if not size:
             size = self.arraysize
         records = self._tx.load_rows(
-            self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row
+            self._pos, min(self._pos + size, res.ntuples), self._make_row
         )
         self._pos += len(records)
         return records
@@ -218,18 +217,21 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
         self._fetch_pipeline()
-        self._check_result_for_fetch()
-        assert self.pgresult
-        records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
-        self._pos = self.pgresult.ntuples
+        res = self._check_result_for_fetch()
+        records = self._tx.load_rows(self._pos, res.ntuples, self._make_row)
+        self._pos = res.ntuples
         return records
 
     def __iter__(self) -> Self:
         return self
 
     def __next__(self) -> Row:
-        if (rec := self.fetchone()) is not None:
-            return rec
+        self._fetch_pipeline()
+        res = self._check_result_for_fetch()
+        if self._pos < res.ntuples:
+            record = self._tx.load_row(self._pos, self._make_row)
+            self._pos += 1
+            return record
         raise StopIteration("no more records to return")
 
     def scroll(self, value: int, mode: str = "relative") -> None:
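The streaming loop now iterates over each batch's actual res.ntuples rather than probing load_row() until it returned None; this also handles a final batch shorter than size. Usage is unchanged (sketch with a placeholder DSN; size is per the stream() signature shown in this diff):

```python
import psycopg

with psycopg.connect("dbname=test") as conn:  # placeholder DSN
    with conn.cursor() as cur:
        # Rows whose row maker produces None are no longer dropped, and the
        # short final batch (10 rows with size=4) is iterated correctly.
        for row in cur.stream("SELECT generate_series(1, 10)", size=4):
            print(row)
```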
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py
index d2f6e775772a7c3a6172c565a8075a8b250a2332..5fc93419f878ec9e64c58d24a322b6a0d22c8fc3 100644
--- a/psycopg/psycopg/cursor_async.py
+++ b/psycopg/psycopg/cursor_async.py
@@ -151,16 +151,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                     self._stream_send_gen(query, params, binary=binary, size=size)
                 )
                 first = True
-                while await self._conn.wait(self._stream_fetchone_gen(first)):
-                    for pos in range(size):
-                        if (rec := self._tx.load_row(pos, self._make_row)) is None:
-                            break
-                        yield rec
+                while res := await self._conn.wait(self._stream_fetchone_gen(first)):
+                    for pos in range(res.ntuples):
+                        yield self._tx.load_row(pos, self._make_row)
                     first = False
-
             except e._NO_TRACEBACK as ex:
                 raise ex.with_traceback(None)
-
             finally:
                 if self._pgconn.transaction_status == ACTIVE:
                     # Try to cancel the query, then consume the results
@@ -190,10 +186,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         :rtype: Row | None, with Row defined by `row_factory`
         """
         await self._fetch_pipeline()
-        self._check_result_for_fetch()
-        if (record := self._tx.load_row(self._pos, self._make_row)) is not None:
+        res = self._check_result_for_fetch()
+        if self._pos < res.ntuples:
+            record = self._tx.load_row(self._pos, self._make_row)
             self._pos += 1
-        return record
+            return record
+        return None
 
     async def fetchmany(self, size: int = 0) -> list[Row]:
         """
@@ -204,13 +202,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
         await self._fetch_pipeline()
-        self._check_result_for_fetch()
-        assert self.pgresult
+        res = self._check_result_for_fetch()
 
         if not size:
             size = self.arraysize
         records = self._tx.load_rows(
-            self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row
+            self._pos, min(self._pos + size, res.ntuples), self._make_row
         )
         self._pos += len(records)
         return records
@@ -222,18 +219,21 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
         await self._fetch_pipeline()
-        self._check_result_for_fetch()
-        assert self.pgresult
-        records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
-        self._pos = self.pgresult.ntuples
+        res = self._check_result_for_fetch()
+        records = self._tx.load_rows(self._pos, res.ntuples, self._make_row)
+        self._pos = res.ntuples
         return records
 
     def __aiter__(self) -> Self:
         return self
 
     async def __anext__(self) -> Row:
-        if (rec := await self.fetchone()) is not None:
-            return rec
+        await self._fetch_pipeline()
+        res = self._check_result_for_fetch()
+        if self._pos < res.ntuples:
+            record = self._tx.load_row(self._pos, self._make_row)
+            self._pos += 1
+            return record
         raise StopAsyncIteration("no more records to return")
 
     async def scroll(self, value: int, mode: str = "relative") -> None:
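The async cursor mirrors the same changes: __anext__ now inlines the bounds check rather than treating a None row from fetchone() as exhaustion. A usage sketch (placeholder DSN):

```python
import asyncio
import psycopg

async def main() -> None:
    async with await psycopg.AsyncConnection.connect("dbname=test") as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT generate_series(1, 3)")
            async for row in cur:  # no longer stops early on a None row
                print(row)

asyncio.run(main())
```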
diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi
index 498ee7d492d1045165cb68e145dc3f0f59be534e..c17823467194d3348312816bce4afa1c39ebaa2b 100644
--- a/psycopg_c/psycopg_c/_psycopg.pyi
+++ b/psycopg_c/psycopg_c/_psycopg.pyi
@@ -46,7 +46,7 @@ class Transformer(abc.AdaptContext):
     def as_literal(self, obj: Any) -> bytes: ...
     def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
     def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ...
-    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ...
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: ...
     def load_sequence(self, record: Sequence[abc.Buffer | None]) -> tuple[Any, ...]: ...
     def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ...
 
diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx
index 3a5fc9a87611856d1a53b073b71219aad11bfb21..f699fdc6218b9d47f2c6de3dd0145d8cc26cb3fd 100644
--- a/psycopg_c/psycopg_c/_psycopg/transform.pyx
+++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx
@@ -492,12 +492,14 @@ cdef class Transformer:
                 Py_DECREF(<object>brecord)
         return records
 
-    def load_row(self, int row, object make_row) -> Row | None:
+    def load_row(self, int row, object make_row) -> Row:
         if self._pgresult is None:
-            return None
+            raise e.InterfaceError("result not set")
 
         if not 0 <= row < self._ntuples:
-            return None
+            raise e.InterfaceError(
+                f"row must be included between 0 and {self._ntuples}"
+            )
 
         cdef libpq.PGresult *res = self._pgresult._pgresult_ptr
         # cheeky access to the internal PGresult structure
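The C implementation in transform.pyx keeps the same raising behaviour as the pure-Python one. A pytest-style regression check, as one might write it (the conn fixture and the private _tx/_make_row attributes are illustrative assumptions, not a public API):

```python
import pytest
import psycopg

def test_load_row_out_of_range(conn):
    cur = conn.execute("SELECT 1")
    with pytest.raises(psycopg.InterfaceError):
        cur._tx.load_row(5, cur._make_row)  # out-of-range index now raises
```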