fix: fix loading of rows when row maker returns None
author     Martin Baláž <balaz@brightpick.ai>
           Wed, 7 May 2025 15:25:13 +0000 (17:25 +0200)
committer  Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Fri, 9 May 2025 00:16:45 +0000 (02:16 +0200)
Close #1073

psycopg/psycopg/_cursor_base.py
psycopg/psycopg/_py_transformer.py
psycopg/psycopg/abc.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/transform.pyx
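
Background for the change: Transformer.load_row() used to return None both
when the requested row was out of range and when the row maker itself
produced None, so fetch loops that treated None as an end-of-results
sentinel stopped early on legitimate None rows. A minimal reproduction
(sketch: the DSN and the row factory are illustrative, not part of the
commit):

    import psycopg

    # A row factory whose row maker can legitimately return None
    # (SQL NULL maps to Python None).
    def scalar_or_none(cursor):
        def make_row(values):
            return values[0]
        return make_row

    with psycopg.connect("dbname=test") as conn:  # assumed DSN
        cur = conn.cursor(row_factory=scalar_or_none)
        cur.execute("SELECT * FROM (VALUES (1), (NULL), (3)) AS t(x)")
        # Before this commit, iteration stopped at the NULL row: [1]
        # After it, every row is yielded: [1, None, 3]
        print(list(cur))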

psycopg/psycopg/_cursor_base.py
index f6e85f19057e4be64af5d5388a0d70459b0adb51..10ddc1f314f614f51641952c8f933f39a722d754 100644
@@ -565,7 +565,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
                 name, pgq.params, param_formats=pgq.formats, result_format=fmt
             )
 
-    def _check_result_for_fetch(self) -> None:
+    def _check_result_for_fetch(self) -> PGresult:
         if self.closed:
             raise e.InterfaceError("the cursor is closed")
 
@@ -573,7 +573,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             raise e.ProgrammingError("no result available")
 
         if (status := res.status) == TUPLES_OK:
-            return
+            return res
         elif status == FATAL_ERROR:
             raise e.error_from_result(res, encoding=self._encoding)
         elif status == PIPELINE_ABORTED:
@@ -607,15 +607,14 @@ class BaseCursor(Generic[ConnectionType, Row]):
             )
 
     def _scroll(self, value: int, mode: str) -> None:
-        self._check_result_for_fetch()
-        assert self.pgresult
+        res = self._check_result_for_fetch()
         if mode == "relative":
             newpos = self._pos + value
         elif mode == "absolute":
             newpos = value
         else:
             raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
-        if not 0 <= newpos < self.pgresult.ntuples:
+        if not 0 <= newpos < res.ntuples:
             raise IndexError("position out of bound")
         self._pos = newpos
 
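Returning the validated PGresult removes the "assert self.pgresult"
type-narrowing step from every caller. The pattern, reduced to a sketch
(names simplified, not the psycopg API):

    def check(res: PGresult | None) -> PGresult:
        if res is None:
            raise ProgrammingError("no result available")
        return res

    res = check(self.pgresult)  # callers get a non-Optional value back
    res.ntuples                 # no assert needed to satisfy the type checker
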
psycopg/psycopg/_py_transformer.py
index 820c567d677c73aff71251a4276f999bb922fca6..7328c8fe3eaa2fff0151077683d1a49427289d56 100644
@@ -312,12 +312,14 @@ class Transformer(AdaptContext):
 
         return records
 
-    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None:
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row:
         if not (res := self._pgresult):
-            return None
+            raise e.InterfaceError("result not set")
 
         if not 0 <= row < self._ntuples:
-            return None
+            raise e.InterfaceError(
+            raise e.InterfaceError(
+                f"row must be included between 0 and {self._ntuples}"
+            )
 
         record: list[Any] = [None] * self._nfields
         for col in range(self._nfields):
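
With this contract, None is unambiguously a value produced by the row
maker; positional errors surface as exceptions instead. Roughly (an
illustrative call, assuming a transformer loaded with a 3-row result):

    tx.load_row(0, make_row)   # -> Row (which may legitimately be None)
    tx.load_row(5, make_row)   # raises psycopg.InterfaceError (out of range)
    # InterfaceError is also raised when no result has been set on the
    # transformer at all.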
psycopg/psycopg/abc.py
index 3b96397626fa8422e79b72df95e9186aded0fb74..8767df08eb5f06f055ae959961331faba51e6cc5 100644
@@ -243,7 +243,7 @@ class Transformer(Protocol):
 
     def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ...
 
-    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ...
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: ...
 
     def load_sequence(self, record: Sequence[Buffer | None]) -> tuple[Any, ...]: ...
 
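Keeping the abc.Transformer protocol in sync matters for callers'
typing: with the narrowed return type, a checker no longer forces a
None guard at call sites. A mypy-flavoured sketch:

    row = tx.load_row(0, make_row)
    reveal_type(row)  # previously Row | None, now Row
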
psycopg/psycopg/cursor.py
index a91d28eb6d59c05db9fb36ef4dda1581899b146c..b5754930eac01ac66392ed318d6e746b305c05bb 100644
@@ -150,11 +150,9 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                     self._stream_send_gen(query, params, binary=binary, size=size)
                 )
                 first = True
-                while self._conn.wait(self._stream_fetchone_gen(first)):
-                    for pos in range(size):
-                        if (rec := self._tx.load_row(pos, self._make_row)) is None:
-                            break
-                        yield rec
+                while res := self._conn.wait(self._stream_fetchone_gen(first)):
+                    for pos in range(res.ntuples):
+                        yield self._tx.load_row(pos, self._make_row)
                     first = False
             except e._NO_TRACEBACK as ex:
                 raise ex.with_traceback(None)
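
The old streaming loop had two problems: range(size) assumed every
server reply carried exactly size rows (the final chunk can be shorter),
and the None-sentinel break also fired on rows that legitimately loaded
as None. Binding the PGresult returned by wait() gives the loop an exact
bound. Usage is unchanged (sketch, reusing the scalar_or_none factory
and assumed DSN from the first example):

    with psycopg.connect("dbname=test") as conn:
        cur = conn.cursor(row_factory=scalar_or_none)
        query = "SELECT * FROM (VALUES (1), (NULL), (3)) AS t(x)"
        for row in cur.stream(query, size=2):
            print(row)  # 1, None, 3 (the NULL row no longer ends the stream)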
@@ -185,10 +183,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         :rtype: Row | None, with Row defined by `row_factory`
         """
         self._fetch_pipeline()
-        self._check_result_for_fetch()
-        if (record := self._tx.load_row(self._pos, self._make_row)) is not None:
+        res = self._check_result_for_fetch()
+        if self._pos < res.ntuples:
+            record = self._tx.load_row(self._pos, self._make_row)
             self._pos += 1
-        return record
+            return record
+        return None
 
     def fetchmany(self, size: int = 0) -> list[Row]:
         """
@@ -199,13 +199,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
         self._fetch_pipeline()
-        self._check_result_for_fetch()
-        assert self.pgresult
+        res = self._check_result_for_fetch()
 
         if not size:
             size = self.arraysize
         records = self._tx.load_rows(
-            self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row
+            self._pos, min(self._pos + size, res.ntuples), self._make_row
         )
         self._pos += len(records)
         return records
@@ -217,20 +216,17 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
         self._fetch_pipeline()
-        self._check_result_for_fetch()
-        assert self.pgresult
-        records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
-        self._pos = self.pgresult.ntuples
+        res = self._check_result_for_fetch()
+        records = self._tx.load_rows(self._pos, res.ntuples, self._make_row)
+        self._pos = res.ntuples
         return records
 
     def __iter__(self) -> Iterator[Row]:
         self._fetch_pipeline()
-        self._check_result_for_fetch()
+        res = self._check_result_for_fetch()
 
-        def load(pos: int) -> Row | None:
-            return self._tx.load_row(pos, self._make_row)
-
-        while (row := load(self._pos)) is not None:
+        while self._pos < res.ntuples:
+            row = self._tx.load_row(self._pos, self._make_row)
             self._pos += 1
             yield row
 
psycopg/psycopg/cursor_async.py
index 56d613b316104e532ff7f844f6adf397668777eb..b714024ee8ad673e4ffc405503684090dba3fd9d 100644
@@ -150,16 +150,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                     self._stream_send_gen(query, params, binary=binary, size=size)
                 )
                 first = True
-                while await self._conn.wait(self._stream_fetchone_gen(first)):
-                    for pos in range(size):
-                        if (rec := self._tx.load_row(pos, self._make_row)) is None:
-                            break
-                        yield rec
+                while res := await self._conn.wait(self._stream_fetchone_gen(first)):
+                    for pos in range(res.ntuples):
+                        yield self._tx.load_row(pos, self._make_row)
                     first = False
-
             except e._NO_TRACEBACK as ex:
                 raise ex.with_traceback(None)
-
             finally:
                 if self._pgconn.transaction_status == ACTIVE:
                     # Try to cancel the query, then consume the results
@@ -189,10 +185,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         :rtype: Row | None, with Row defined by `row_factory`
         """
         await self._fetch_pipeline()
-        self._check_result_for_fetch()
-        if (record := self._tx.load_row(self._pos, self._make_row)) is not None:
+        res = self._check_result_for_fetch()
+        if self._pos < res.ntuples:
+            record = self._tx.load_row(self._pos, self._make_row)
             self._pos += 1
-        return record
+            return record
+        return None
 
     async def fetchmany(self, size: int = 0) -> list[Row]:
         """
@@ -203,13 +201,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
         await self._fetch_pipeline()
-        self._check_result_for_fetch()
-        assert self.pgresult
+        res = self._check_result_for_fetch()
 
         if not size:
             size = self.arraysize
         records = self._tx.load_rows(
-            self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row
+            self._pos, min(self._pos + size, res.ntuples), self._make_row
         )
         self._pos += len(records)
         return records
@@ -221,20 +218,17 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
         await self._fetch_pipeline()
-        self._check_result_for_fetch()
-        assert self.pgresult
-        records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
-        self._pos = self.pgresult.ntuples
+        res = self._check_result_for_fetch()
+        records = self._tx.load_rows(self._pos, res.ntuples, self._make_row)
+        self._pos = res.ntuples
         return records
 
     async def __aiter__(self) -> AsyncIterator[Row]:
         await self._fetch_pipeline()
-        self._check_result_for_fetch()
-
-        def load(pos: int) -> Row | None:
-            return self._tx.load_row(pos, self._make_row)
+        res = self._check_result_for_fetch()
 
-        while (row := load(self._pos)) is not None:
+        while self._pos < res.ntuples:
+            row = self._tx.load_row(self._pos, self._make_row)
             self._pos += 1
             yield row
 
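The async cursor mirrors the sync fix. A hypothetical end-to-end usage
(assumed DSN; the inline row factory is illustrative):

    import asyncio
    import psycopg

    async def main():
        async with await psycopg.AsyncConnection.connect("dbname=test") as conn:
            cur = conn.cursor(row_factory=lambda cur: (lambda values: values[0]))
            await cur.execute("SELECT * FROM (VALUES (1), (NULL), (3)) AS t(x)")
            async for row in cur:
                print(row)  # 1, None, 3

    asyncio.run(main())
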
psycopg_c/psycopg_c/_psycopg.pyi
index fe19c120f4b1cc7a1247a499cfae56c8482a3fb2..977d2c22258dffcfdb1dae81940d3aa9a3075eed 100644
@@ -46,7 +46,7 @@ class Transformer(abc.AdaptContext):
     def as_literal(self, obj: Any) -> bytes: ...
     def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
     def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ...
-    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ...
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: ...
     def load_sequence(self, record: Sequence[abc.Buffer | None]) -> tuple[Any, ...]: ...
     def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ...
 
psycopg_c/psycopg_c/_psycopg/transform.pyx
index 3a5fc9a87611856d1a53b073b71219aad11bfb21..f699fdc6218b9d47f2c6de3dd0145d8cc26cb3fd 100644
@@ -492,12 +492,14 @@ cdef class Transformer:
                 Py_DECREF(<object>brecord)
         return records
 
-    def load_row(self, int row, object make_row) -> Row | None:
+    def load_row(self, int row, object make_row) -> Row:
         if self._pgresult is None:
-            return None
+            raise e.InterfaceError("result not set")
 
         if not 0 <= row < self._ntuples:
-            return None
+            raise e.InterfaceError(
+                f"row must be included between 0 and {self._ntuples}"
+            )
 
         cdef libpq.PGresult *res = self._pgresult._pgresult_ptr
         # cheeky access to the internal PGresult structure
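
The C implementation gets the same contract, so the pure-Python and
Cython transformers stay in lockstep (note the bounds check remains
0 <= row < self._ntuples in both). A regression test for #1073 could
look like this (sketch; assumes a reachable test database):

    import psycopg

    def test_row_maker_may_return_none():
        with psycopg.connect("dbname=test") as conn:  # assumed DSN
            cur = conn.cursor(row_factory=lambda cur: (lambda values: values[0]))
            cur.execute("SELECT * FROM (VALUES (1), (NULL), (3)) AS t(x)")
            assert list(cur) == [1, None, 3]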