From: Daniele Varrazzo Date: Sat, 3 May 2025 02:07:37 +0000 (+0200) Subject: feat: make cursors iterators X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F1064%2Fhead;p=thirdparty%2Fpsycopg.git feat: make cursors iterators This allows to call `next(cursor)`, which is guaranteed to never return None, making type checking simpler. --- diff --git a/docs/advanced/typing.rst b/docs/advanced/typing.rst index 98efb4067..c5b722fbd 100644 --- a/docs/advanced/typing.rst +++ b/docs/advanced/typing.rst @@ -38,6 +38,8 @@ annotations such as `!Connection[Any]` and `!Cursor[Any]`. rec = cur.fetchone() # type is tuple[Any, ...] | None + rec = next(cur) # type is tuple[Any, ...] + recs = cur.fetchall() # type is List[tuple[Any, ...]] @@ -66,6 +68,56 @@ cursors and annotate the returned objects accordingly. See # drec type is dict[str, Any] | None +.. _typing-fetchone: + +The ``fetchone()`` frustration +------------------------------ + +.. versionchanged:: 3.3 + +If you use a static type checker and you are 100% sure that the cursor will +exactly one record, it is frustrating to be told that the returned row might +be `!None`. For example: + +.. code:: python + + import psycopg + from psycopg.rows import scalar_row + + def count_records() -> int: + conn = psycopg.connect() + cur = conn.cursor(row_factory=scalar_row) + cur.execute("SELECT count(*) FROM mytable") + rv: int = cur.fetchone() # mypy error here + return rv + +The :sql:`count(*)` will always return a record with a number, even if the +table is empty (it will just report 0). However, Mypy will report an error +such as *incompatible types in assignment (expression has type "Any | None", +variable has type "int")*. In order to work around the error you will need +to use an `!if`, an `!assert` or some other workaround (like ``(rv,) = +cur.fetchall()`` or some other horrible trick). + +Since Psycopg 3.3, cursors are iterables__, therefore they support the +`next` function. A `!next(cur)` will behave like `!cur.fetchone()`, but it +is guaranteed to return a row (in case there are no rows in the result set it +will not return anything but will raise `!StopIteration`). Therefore the +function above can terminate with: + +.. code:: python + + def count_records() -> int: + ... + rv: int = next(cur) + return rv + +and your static checker will be happy. + +Similarly, in async code, you can use an `!await` `anext`\ `!(cur)` expression. + +.. __: https://docs.python.org/3/glossary.html#term-iterable + + .. _pool-generic: Generic pool types diff --git a/docs/api/cursors.rst b/docs/api/cursors.rst index e1f6d0b02..826b568f0 100644 --- a/docs/api/cursors.rst +++ b/docs/api/cursors.rst @@ -239,13 +239,21 @@ The `!Cursor` class .. note:: - Cursors are iterable objects, so just using the:: + Cursors are iterators, so just using the:: for record in cursor: ... syntax will iterate on the records in the current result set. + .. versionchanged:: 3.3 + + it is now possible to use `!next(cursor)`. Previously, cursors were + iterables__, not iterators__. + + .. __: https://docs.python.org/3/glossary.html#term-iterable + .. __: https://docs.python.org/3/glossary.html#term-iterator + .. autoattribute:: row_factory The property affects the objects returned by the `fetchone()`, diff --git a/docs/news.rst b/docs/news.rst index 29d163bca..17fcaaf25 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -13,6 +13,8 @@ Future releases Psycopg 3.3.0 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^ +- Cursors are now iterators, not only iterables. This means you can call + ``next(cur)`` to fetch the next row (:ticket:`#1064`). - Drop support for Python 3.8 (:ticket:`#976`) and 3.9 (:ticket:`#1056`). diff --git a/psycopg/psycopg/_server_cursor.py b/psycopg/psycopg/_server_cursor.py index 3b27d1fc7..cfddfd97c 100644 --- a/psycopg/psycopg/_server_cursor.py +++ b/psycopg/psycopg/_server_cursor.py @@ -11,7 +11,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, overload from warnings import warn -from collections.abc import Iterable, Iterator +from collections.abc import Iterable from . import errors as e from .abc import Params, Query @@ -136,15 +136,27 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): self._pos += len(recs) return recs - def __iter__(self) -> Iterator[Row]: - while True: + def __iter__(self) -> Self: + return self + + def __next__(self) -> Row: + # Fetch a new page if we never fetched any, or we are at the end of + # a page of size itersize, meaning there is likely a following one. + if ( + self._iter_rows is None + or self._page_pos >= len(self._iter_rows) >= self.itersize + ): with self._conn.lock: - recs = self._conn.wait(self._fetch_gen(self.itersize)) - for rec in recs: - self._pos += 1 - yield rec - if len(recs) < self.itersize: - break + self._iter_rows = self._conn.wait(self._fetch_gen(self.itersize)) + self._page_pos += 0 + + if self._page_pos >= len(self._iter_rows): + raise StopIteration("no more records to return") + + rec = self._iter_rows[self._page_pos] + self._page_pos += 1 + self._pos += 1 + return rec def scroll(self, value: int, mode: str = "relative") -> None: with self._conn.lock: diff --git a/psycopg/psycopg/_server_cursor_async.py b/psycopg/psycopg/_server_cursor_async.py index 37e0e2621..41628fe74 100644 --- a/psycopg/psycopg/_server_cursor_async.py +++ b/psycopg/psycopg/_server_cursor_async.py @@ -8,7 +8,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, overload from warnings import warn -from collections.abc import AsyncIterator, Iterable +from collections.abc import Iterable from . import errors as e from .abc import Params, Query @@ -136,15 +136,26 @@ class AsyncServerCursor( self._pos += len(recs) return recs - async def __aiter__(self) -> AsyncIterator[Row]: - while True: + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> Row: + # Fetch a new page if we never fetched any, or we are at the end of + # a page of size itersize, meaning there is likely a following one. + if self._iter_rows is None or ( + self._page_pos >= len(self._iter_rows) >= self.itersize + ): async with self._conn.lock: - recs = await self._conn.wait(self._fetch_gen(self.itersize)) - for rec in recs: - self._pos += 1 - yield rec - if len(recs) < self.itersize: - break + self._iter_rows = await self._conn.wait(self._fetch_gen(self.itersize)) + self._page_pos += 0 + + if self._page_pos >= len(self._iter_rows): + raise StopAsyncIteration("no more records to return") + + rec = self._iter_rows[self._page_pos] + self._page_pos += 1 + self._pos += 1 + return rec async def scroll(self, value: int, mode: str = "relative") -> None: async with self._conn.lock: diff --git a/psycopg/psycopg/_server_cursor_base.py b/psycopg/psycopg/_server_cursor_base.py index 4d1fba0d5..fef6f9070 100644 --- a/psycopg/psycopg/_server_cursor_base.py +++ b/psycopg/psycopg/_server_cursor_base.py @@ -28,7 +28,9 @@ INTRANS = pq.TransactionStatus.INTRANS class ServerCursorMixin(BaseCursor[ConnectionType, Row]): """Mixin to add ServerCursor behaviour and implementation a BaseCursor.""" - __slots__ = "_name _scrollable _withhold _described itersize _format".split() + __slots__ = """_name _scrollable _withhold _described itersize _format + _iter_rows _page_pos + """.split() def __init__(self, name: str, scrollable: bool | None, withhold: bool): self._name = name @@ -38,6 +40,10 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]): self.itersize: int = DEFAULT_ITERSIZE self._format = TEXT + # Hold the state during iteration: a fetched page and position within it + self._iter_rows: list[Row] | None = None + self._page_pos = 0 + def __repr__(self) -> str: # Insert the name as the second word parts = super().__repr__().split(None, 1) @@ -91,6 +97,7 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]): yield from self._close_gen() self._described = False + self._iter_rows = None yield from self._start_query(query) pgq = self._convert_query(query, params) self._execute_send(pgq, force_extended=True) diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 25449e384..56411dac0 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -224,16 +224,13 @@ class Cursor(BaseCursor["Connection[Any]", Row]): self._pos = self.pgresult.ntuples return records - def __iter__(self) -> Iterator[Row]: - self._fetch_pipeline() - self._check_result_for_fetch() - - def load(pos: int) -> Row | None: - return self._tx.load_row(pos, self._make_row) + def __iter__(self) -> Self: + return self - while (row := load(self._pos)) is not None: - self._pos += 1 - yield row + def __next__(self) -> Row: + if (rec := self.fetchone()) is not None: + return rec + raise StopIteration("no more records to return") def scroll(self, value: int, mode: str = "relative") -> None: """ diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 0601088d9..d2f6e7757 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -228,16 +228,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): self._pos = self.pgresult.ntuples return records - async def __aiter__(self) -> AsyncIterator[Row]: - await self._fetch_pipeline() - self._check_result_for_fetch() - - def load(pos: int) -> Row | None: - return self._tx.load_row(pos, self._make_row) + def __aiter__(self) -> Self: + return self - while (row := load(self._pos)) is not None: - self._pos += 1 - yield row + async def __anext__(self) -> Row: + if (rec := await self.fetchone()) is not None: + return rec + raise StopAsyncIteration("no more records to return") async def scroll(self, value: int, mode: str = "relative") -> None: """ diff --git a/tests/test_cursor_common.py b/tests/test_cursor_common.py index e6237148b..89f91e677 100644 --- a/tests/test_cursor_common.py +++ b/tests/test_cursor_common.py @@ -162,6 +162,14 @@ def test_execute_sql(conn): assert cur.fetchone() == ("hello",) +def test_next(conn): + cur = conn.cursor() + cur.execute("select 1") + assert next(cur) == (1,) + with pytest.raises(StopIteration): + next(cur) + + def test_query_parse_cache_size(conn): cur = conn.cursor() cls = type(cur) diff --git a/tests/test_cursor_common_async.py b/tests/test_cursor_common_async.py index ae07fb67a..ac8ca65fb 100644 --- a/tests/test_cursor_common_async.py +++ b/tests/test_cursor_common_async.py @@ -160,6 +160,14 @@ async def test_execute_sql(aconn): assert (await cur.fetchone()) == ("hello",) +async def test_next(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert await anext(cur) == (1,) + with pytest.raises(StopAsyncIteration): + await anext(cur) + + async def test_query_parse_cache_size(aconn): cur = aconn.cursor() cls = type(cur) diff --git a/tests/test_cursor_server.py b/tests/test_cursor_server.py index cc2d4cb5e..fd9192da0 100644 --- a/tests/test_cursor_server.py +++ b/tests/test_cursor_server.py @@ -431,6 +431,7 @@ def test_iter(conn): def test_iter_rownumber(conn): with conn.cursor("foo") as cur: + cur.itersize = 2 cur.execute(ph(cur, "select generate_series(1, %s) as bar"), (3,)) for row in cur: assert cur.rownumber == row[0] @@ -450,6 +451,14 @@ def test_itersize(conn, commands): assert "fetch forward 2" in cmd.lower() +def test_next(conn): + with conn.cursor() as cur: + cur.execute("select 1") + assert next(cur) == (1,) + with pytest.raises(StopIteration): + next(cur) + + def test_cant_scroll_by_default(conn): cur = conn.cursor("tmp") assert cur.scrollable is None diff --git a/tests/test_cursor_server_async.py b/tests/test_cursor_server_async.py index 98ae9a278..b94590ae4 100644 --- a/tests/test_cursor_server_async.py +++ b/tests/test_cursor_server_async.py @@ -437,6 +437,7 @@ async def test_iter(aconn): async def test_iter_rownumber(aconn): async with aconn.cursor("foo") as cur: + cur.itersize = 2 await cur.execute(ph(cur, "select generate_series(1, %s) as bar"), (3,)) async for row in cur: assert cur.rownumber == row[0] @@ -456,6 +457,14 @@ async def test_itersize(aconn, acommands): assert "fetch forward 2" in cmd.lower() +async def test_next(aconn): + async with aconn.cursor() as cur: + await cur.execute("select 1") + assert await anext(cur) == (1,) + with pytest.raises(StopAsyncIteration): + await anext(cur) + + async def test_cant_scroll_by_default(aconn): cur = aconn.cursor("tmp") assert cur.scrollable is None diff --git a/tests/test_typing.py b/tests/test_typing.py index 24894961c..28a3d9fd9 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -248,14 +248,17 @@ obj = {curs} ("many", "list[{type}]"), ("all", "list[{type}]"), ("iter", "{type}"), + ("next", "{type}"), ], ) def test_fetch_type(conn_class, server_side, factory, type, fetch, typemod, mypy): if "Async" in conn_class: async_ = "async " await_ = "await " + next_ = "anext" else: async_ = await_ = "" + next_ = "next" curs = f"conn.cursor({factory})" if server_side: @@ -273,6 +276,8 @@ curs = {curs} stmts += f"obj = {await_} curs.fetchall()" elif fetch == "iter": stmts += f"{async_}for obj in curs: pass" + elif fetch == "next": + stmts += f"obj = {await_} {next_}(curs)" else: pytest.fail(f"unexpected fetch: {fetch}") diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 24d8918cc..4fa525802 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -296,9 +296,11 @@ class RenameAsyncToSync(ast.NodeTransformer): # type: ignore "AsyncServerCursor": "ServerCursor", "AsyncTransaction": "Transaction", "AsyncWriter": "Writer", + "StopAsyncIteration": "StopIteration", "__aenter__": "__enter__", "__aexit__": "__exit__", "__aiter__": "__iter__", + "__anext__": "__next__", "_copy_async": "_copy", "_server_cursor_async": "_server_cursor", "aclose": "close", @@ -363,6 +365,7 @@ class RenameAsyncToSync(ast.NodeTransformer): # type: ignore def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: self._fix_docstring(node.body) + node.name = self.names_map.get(node.name, node.name) if node.decorator_list: self._fix_decorator(node.decorator_list) self.generic_visit(node)