From: Martin Baláž Date: Wed, 7 May 2025 15:25:13 +0000 (+0200) Subject: fix: fix loading of rows when row maker returns None X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f77cfe538298ce1a989bd57e2ef5e512bde276cd;p=thirdparty%2Fpsycopg.git fix: fix loading of rows when row maker returns None Close #1073 --- diff --git a/psycopg/psycopg/_cursor_base.py b/psycopg/psycopg/_cursor_base.py index ee7c6462b..f314f1229 100644 --- a/psycopg/psycopg/_cursor_base.py +++ b/psycopg/psycopg/_cursor_base.py @@ -569,7 +569,7 @@ class BaseCursor(Generic[ConnectionType, Row]): name, pgq.params, param_formats=pgq.formats, result_format=fmt ) - def _check_result_for_fetch(self) -> None: + def _check_result_for_fetch(self) -> PGresult: if self.closed: raise e.InterfaceError("the cursor is closed") @@ -577,7 +577,7 @@ class BaseCursor(Generic[ConnectionType, Row]): raise e.ProgrammingError("no result available") if (status := res.status) == TUPLES_OK: - return + return res elif status == FATAL_ERROR: raise e.error_from_result(res, encoding=self._encoding) elif status == PIPELINE_ABORTED: @@ -611,15 +611,14 @@ class BaseCursor(Generic[ConnectionType, Row]): ) def _scroll(self, value: int, mode: str) -> None: - self._check_result_for_fetch() - assert self.pgresult + res = self._check_result_for_fetch() if mode == "relative": newpos = self._pos + value elif mode == "absolute": newpos = value else: raise ValueError(f"bad mode: {mode}. 
It should be 'relative' or 'absolute'")
-        if not 0 <= newpos < self.pgresult.ntuples:
+        if not 0 <= newpos < res.ntuples:
             raise IndexError("position out of bound")
         self._pos = newpos

diff --git a/psycopg/psycopg/_py_transformer.py b/psycopg/psycopg/_py_transformer.py
index ffba74e8c..a41332b9d 100644
--- a/psycopg/psycopg/_py_transformer.py
+++ b/psycopg/psycopg/_py_transformer.py
@@ -312,12 +312,14 @@ class Transformer(AdaptContext):
 
         return records
 
-    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None:
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row:
         if not (res := self._pgresult):
-            return None
+            raise e.InterfaceError("result not set")
 
-        if not 0 <= row < self._ntuples:
-            return None
+        if not 0 <= row < self._ntuples:
+            raise e.InterfaceError(
+                f"row must be included between 0 and {self._ntuples}"
+            )
 
         record: list[Any] = [None] * self._nfields
         for col in range(self._nfields):
diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py
index a95c9e069..764b9a1d5 100644
--- a/psycopg/psycopg/abc.py
+++ b/psycopg/psycopg/abc.py
@@ -243,7 +243,7 @@ class Transformer(Protocol):
 
     def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ...
 
-    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ...
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: ...
 
     def load_sequence(self, record: Sequence[Buffer | None]) -> tuple[Any, ...]: ...
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 56411dac0..6b62898ca 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -151,11 +151,9 @@ class Cursor(BaseCursor["Connection[Any]", Row]): self._stream_send_gen(query, params, binary=binary, size=size) ) first = True - while self._conn.wait(self._stream_fetchone_gen(first)): - for pos in range(size): - if (rec := self._tx.load_row(pos, self._make_row)) is None: - break - yield rec + while res := self._conn.wait(self._stream_fetchone_gen(first)): + for pos in range(res.ntuples): + yield self._tx.load_row(pos, self._make_row) first = False except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) @@ -186,10 +184,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]): :rtype: Row | None, with Row defined by `row_factory` """ self._fetch_pipeline() - self._check_result_for_fetch() - if (record := self._tx.load_row(self._pos, self._make_row)) is not None: + res = self._check_result_for_fetch() + if self._pos < res.ntuples: + record = self._tx.load_row(self._pos, self._make_row) self._pos += 1 - return record + return record + return None def fetchmany(self, size: int = 0) -> list[Row]: """ @@ -200,13 +200,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]): :rtype: Sequence[Row], with Row defined by `row_factory` """ self._fetch_pipeline() - self._check_result_for_fetch() - assert self.pgresult + res = self._check_result_for_fetch() if not size: size = self.arraysize records = self._tx.load_rows( - self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row + self._pos, min(self._pos + size, res.ntuples), self._make_row ) self._pos += len(records) return records @@ -218,18 +217,21 @@ class Cursor(BaseCursor["Connection[Any]", Row]): :rtype: Sequence[Row], with Row defined by `row_factory` """ self._fetch_pipeline() - self._check_result_for_fetch() - assert self.pgresult - records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) - 
self._pos = self.pgresult.ntuples + res = self._check_result_for_fetch() + records = self._tx.load_rows(self._pos, res.ntuples, self._make_row) + self._pos = res.ntuples return records def __iter__(self) -> Self: return self def __next__(self) -> Row: - if (rec := self.fetchone()) is not None: - return rec + self._fetch_pipeline() + res = self._check_result_for_fetch() + if self._pos < res.ntuples: + record = self._tx.load_row(self._pos, self._make_row) + self._pos += 1 + return record raise StopIteration("no more records to return") def scroll(self, value: int, mode: str = "relative") -> None: diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index d2f6e7757..5fc93419f 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -151,16 +151,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): self._stream_send_gen(query, params, binary=binary, size=size) ) first = True - while await self._conn.wait(self._stream_fetchone_gen(first)): - for pos in range(size): - if (rec := self._tx.load_row(pos, self._make_row)) is None: - break - yield rec + while res := await self._conn.wait(self._stream_fetchone_gen(first)): + for pos in range(res.ntuples): + yield self._tx.load_row(pos, self._make_row) first = False - except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) - finally: if self._pgconn.transaction_status == ACTIVE: # Try to cancel the query, then consume the results @@ -190,10 +186,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): :rtype: Row | None, with Row defined by `row_factory` """ await self._fetch_pipeline() - self._check_result_for_fetch() - if (record := self._tx.load_row(self._pos, self._make_row)) is not None: + res = self._check_result_for_fetch() + if self._pos < res.ntuples: + record = self._tx.load_row(self._pos, self._make_row) self._pos += 1 - return record + return record + return None async def fetchmany(self, size: int = 0) -> list[Row]: """ @@ -204,13 
+202,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): :rtype: Sequence[Row], with Row defined by `row_factory` """ await self._fetch_pipeline() - self._check_result_for_fetch() - assert self.pgresult + res = self._check_result_for_fetch() if not size: size = self.arraysize records = self._tx.load_rows( - self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row + self._pos, min(self._pos + size, res.ntuples), self._make_row ) self._pos += len(records) return records @@ -222,18 +219,21 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): :rtype: Sequence[Row], with Row defined by `row_factory` """ await self._fetch_pipeline() - self._check_result_for_fetch() - assert self.pgresult - records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) - self._pos = self.pgresult.ntuples + res = self._check_result_for_fetch() + records = self._tx.load_rows(self._pos, res.ntuples, self._make_row) + self._pos = res.ntuples return records def __aiter__(self) -> Self: return self async def __anext__(self) -> Row: - if (rec := await self.fetchone()) is not None: - return rec + await self._fetch_pipeline() + res = self._check_result_for_fetch() + if self._pos < res.ntuples: + record = self._tx.load_row(self._pos, self._make_row) + self._pos += 1 + return record raise StopAsyncIteration("no more records to return") async def scroll(self, value: int, mode: str = "relative") -> None: diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index 498ee7d49..c17823467 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -46,7 +46,7 @@ class Transformer(abc.AdaptContext): def as_literal(self, obj: Any) -> bytes: ... def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ... def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ... - def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ... 
+ def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: ... def load_sequence(self, record: Sequence[abc.Buffer | None]) -> tuple[Any, ...]: ... def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ... diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx index 3a5fc9a87..f699fdc62 100644 --- a/psycopg_c/psycopg_c/_psycopg/transform.pyx +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -492,12 +492,14 @@ cdef class Transformer: Py_DECREF(brecord) return records - def load_row(self, int row, object make_row) -> Row | None: + def load_row(self, int row, object make_row) -> Row: if self._pgresult is None: - return None + raise e.InterfaceError("result not set") if not 0 <= row < self._ntuples: - return None + raise e.InterfaceError( + f"row must be included between 0 and {self._ntuples}" + ) cdef libpq.PGresult *res = self._pgresult._pgresult_ptr # cheeky access to the internal PGresult structure