From: Martin Baláž Date: Wed, 7 May 2025 15:25:13 +0000 (+0200) Subject: fix: fix loading of rows when row maker returns None X-Git-Tag: 3.2.8~15^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f786deda1c70460c591af3807ca1c54e82828798;p=thirdparty%2Fpsycopg.git fix: fix loading of rows when row maker returns None Close #1073 --- diff --git a/psycopg/psycopg/_cursor_base.py b/psycopg/psycopg/_cursor_base.py index f6e85f190..10ddc1f31 100644 --- a/psycopg/psycopg/_cursor_base.py +++ b/psycopg/psycopg/_cursor_base.py @@ -565,7 +565,7 @@ class BaseCursor(Generic[ConnectionType, Row]): name, pgq.params, param_formats=pgq.formats, result_format=fmt ) - def _check_result_for_fetch(self) -> None: + def _check_result_for_fetch(self) -> PGresult: if self.closed: raise e.InterfaceError("the cursor is closed") @@ -573,7 +573,7 @@ class BaseCursor(Generic[ConnectionType, Row]): raise e.ProgrammingError("no result available") if (status := res.status) == TUPLES_OK: - return + return res elif status == FATAL_ERROR: raise e.error_from_result(res, encoding=self._encoding) elif status == PIPELINE_ABORTED: @@ -607,15 +607,14 @@ class BaseCursor(Generic[ConnectionType, Row]): ) def _scroll(self, value: int, mode: str) -> None: - self._check_result_for_fetch() - assert self.pgresult + res = self._check_result_for_fetch() if mode == "relative": newpos = self._pos + value elif mode == "absolute": newpos = value else: raise ValueError(f"bad mode: {mode}. 
It should be 'relative' or 'absolute'") - if not 0 <= newpos < self.pgresult.ntuples: + if not 0 <= newpos < res.ntuples: raise IndexError("position out of bound") self._pos = newpos diff --git a/psycopg/psycopg/_py_transformer.py b/psycopg/psycopg/_py_transformer.py index 820c567d6..7328c8fe3 100644 --- a/psycopg/psycopg/_py_transformer.py +++ b/psycopg/psycopg/_py_transformer.py @@ -312,12 +312,14 @@ class Transformer(AdaptContext): return records - def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: + def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: if not (res := self._pgresult): - return None + raise e.InterfaceError("result not set") - if not 0 <= row < self._ntuples: - return None + if not 0 <= row < self._ntuples: + raise e.InterfaceError( + f"row must be included between 0 and {self._ntuples}" + ) record: list[Any] = [None] * self._nfields for col in range(self._nfields): diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index 3b9639762..8767df08e 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ -243,7 +243,7 @@ class Transformer(Protocol): def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ... - def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ... + def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: ... def load_sequence(self, record: Sequence[Buffer | None]) -> tuple[Any, ...]: ... 
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index a91d28eb6..b5754930e 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -150,11 +150,9 @@ class Cursor(BaseCursor["Connection[Any]", Row]): self._stream_send_gen(query, params, binary=binary, size=size) ) first = True - while self._conn.wait(self._stream_fetchone_gen(first)): - for pos in range(size): - if (rec := self._tx.load_row(pos, self._make_row)) is None: - break - yield rec + while res := self._conn.wait(self._stream_fetchone_gen(first)): + for pos in range(res.ntuples): + yield self._tx.load_row(pos, self._make_row) first = False except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) @@ -185,10 +183,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]): :rtype: Row | None, with Row defined by `row_factory` """ self._fetch_pipeline() - self._check_result_for_fetch() - if (record := self._tx.load_row(self._pos, self._make_row)) is not None: + res = self._check_result_for_fetch() + if self._pos < res.ntuples: + record = self._tx.load_row(self._pos, self._make_row) self._pos += 1 - return record + return record + return None def fetchmany(self, size: int = 0) -> list[Row]: """ @@ -199,13 +199,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]): :rtype: Sequence[Row], with Row defined by `row_factory` """ self._fetch_pipeline() - self._check_result_for_fetch() - assert self.pgresult + res = self._check_result_for_fetch() if not size: size = self.arraysize records = self._tx.load_rows( - self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row + self._pos, min(self._pos + size, res.ntuples), self._make_row ) self._pos += len(records) return records @@ -217,20 +216,17 @@ class Cursor(BaseCursor["Connection[Any]", Row]): :rtype: Sequence[Row], with Row defined by `row_factory` """ self._fetch_pipeline() - self._check_result_for_fetch() - assert self.pgresult - records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) - 
self._pos = self.pgresult.ntuples + res = self._check_result_for_fetch() + records = self._tx.load_rows(self._pos, res.ntuples, self._make_row) + self._pos = res.ntuples return records def __iter__(self) -> Iterator[Row]: self._fetch_pipeline() - self._check_result_for_fetch() + res = self._check_result_for_fetch() - def load(pos: int) -> Row | None: - return self._tx.load_row(pos, self._make_row) - - while (row := load(self._pos)) is not None: + while self._pos < res.ntuples: + row = self._tx.load_row(self._pos, self._make_row) self._pos += 1 yield row diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 56d613b31..b714024ee 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -150,16 +150,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): self._stream_send_gen(query, params, binary=binary, size=size) ) first = True - while await self._conn.wait(self._stream_fetchone_gen(first)): - for pos in range(size): - if (rec := self._tx.load_row(pos, self._make_row)) is None: - break - yield rec + while res := await self._conn.wait(self._stream_fetchone_gen(first)): + for pos in range(res.ntuples): + yield self._tx.load_row(pos, self._make_row) first = False - except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) - finally: if self._pgconn.transaction_status == ACTIVE: # Try to cancel the query, then consume the results @@ -189,10 +185,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): :rtype: Row | None, with Row defined by `row_factory` """ await self._fetch_pipeline() - self._check_result_for_fetch() - if (record := self._tx.load_row(self._pos, self._make_row)) is not None: + res = self._check_result_for_fetch() + if self._pos < res.ntuples: + record = self._tx.load_row(self._pos, self._make_row) self._pos += 1 - return record + return record + return None async def fetchmany(self, size: int = 0) -> list[Row]: """ @@ -203,13 +201,12 @@ class 
AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): :rtype: Sequence[Row], with Row defined by `row_factory` """ await self._fetch_pipeline() - self._check_result_for_fetch() - assert self.pgresult + res = self._check_result_for_fetch() if not size: size = self.arraysize records = self._tx.load_rows( - self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row + self._pos, min(self._pos + size, res.ntuples), self._make_row ) self._pos += len(records) return records @@ -221,20 +218,17 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): :rtype: Sequence[Row], with Row defined by `row_factory` """ await self._fetch_pipeline() - self._check_result_for_fetch() - assert self.pgresult - records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) - self._pos = self.pgresult.ntuples + res = self._check_result_for_fetch() + records = self._tx.load_rows(self._pos, res.ntuples, self._make_row) + self._pos = res.ntuples return records async def __aiter__(self) -> AsyncIterator[Row]: await self._fetch_pipeline() - self._check_result_for_fetch() - - def load(pos: int) -> Row | None: - return self._tx.load_row(pos, self._make_row) + res = self._check_result_for_fetch() - while (row := load(self._pos)) is not None: + while self._pos < res.ntuples: + row = self._tx.load_row(self._pos, self._make_row) self._pos += 1 yield row diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index fe19c120f..977d2c222 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -46,7 +46,7 @@ class Transformer(abc.AdaptContext): def as_literal(self, obj: Any) -> bytes: ... def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ... def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ... - def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ... + def load_row(self, row: int, make_row: RowMaker[Row]) -> Row: ... 
def load_sequence(self, record: Sequence[abc.Buffer | None]) -> tuple[Any, ...]: ... def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ... diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx index 3a5fc9a87..f699fdc62 100644 --- a/psycopg_c/psycopg_c/_psycopg/transform.pyx +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -492,12 +492,14 @@ cdef class Transformer: Py_DECREF(brecord) return records - def load_row(self, int row, object make_row) -> Row | None: + def load_row(self, int row, object make_row) -> Row: if self._pgresult is None: - return None + raise e.InterfaceError("result not set") if not 0 <= row < self._ntuples: - return None + raise e.InterfaceError( + f"row must be included between 0 and {self._ntuples}" + ) cdef libpq.PGresult *res = self._pgresult._pgresult_ptr # cheeky access to the internal PGresult structure