]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: make cursors iterators 1064/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 3 May 2025 02:07:37 +0000 (04:07 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 4 May 2025 21:32:02 +0000 (23:32 +0200)
This allows to call `next(cursor)`, which is guaranteed to never return
None, making type checking simpler.

14 files changed:
docs/advanced/typing.rst
docs/api/cursors.rst
docs/news.rst
psycopg/psycopg/_server_cursor.py
psycopg/psycopg/_server_cursor_async.py
psycopg/psycopg/_server_cursor_base.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tests/test_cursor_common.py
tests/test_cursor_common_async.py
tests/test_cursor_server.py
tests/test_cursor_server_async.py
tests/test_typing.py
tools/async_to_sync.py

index 98efb4067c00406398ed7aa7f232ba7cdfe60940..c5b722fbdde310917c461ecdddfa0098e846b872 100644 (file)
@@ -38,6 +38,8 @@ annotations such as `!Connection[Any]` and `!Cursor[Any]`.
 
    rec = cur.fetchone()     # type is tuple[Any, ...] | None
 
+   rec = next(cur)          # type is tuple[Any, ...]
+
    recs = cur.fetchall()    # type is List[tuple[Any, ...]]
 
 
@@ -66,6 +68,56 @@ cursors and annotate the returned objects accordingly. See
    # drec type is dict[str, Any] | None
 
 
+.. _typing-fetchone:
+
+The ``fetchone()`` frustration
+------------------------------
+
+.. versionchanged:: 3.3
+
+If you use a static type checker and you are 100% sure that the cursor will
+exactly one record, it is frustrating to be told that the returned row might
+be `!None`. For example:
+
+.. code:: python
+
+    import psycopg
+    from psycopg.rows import scalar_row
+
+    def count_records() -> int:
+        conn = psycopg.connect()
+        cur = conn.cursor(row_factory=scalar_row)
+        cur.execute("SELECT count(*) FROM mytable")
+        rv: int = cur.fetchone()  # mypy error here
+        return rv
+
+The :sql:`count(*)` will always return a record with a number, even if the
+table is empty (it will just report 0). However, Mypy will report an error
+such as *incompatible types in assignment (expression has type "Any | None",
+variable has type "int")*. In order to work around the error you will need
+to use an `!if`, an `!assert` or some other workaround (like ``(rv,) =
+cur.fetchall()`` or some other horrible trick).
+
+Since Psycopg 3.3, cursors are iterables__, therefore they support the
+`next` function. A `!next(cur)` will behave like `!cur.fetchone()`, but it
+is guaranteed to return a row (in case there are no rows in the result set it
+will not return anything but will raise `!StopIteration`). Therefore the
+function above can terminate with:
+
+.. code:: python
+
+    def count_records() -> int:
+        ...
+        rv: int = next(cur)
+        return rv
+
+and your static checker will be happy.
+
+Similarly, in async code, you can use an `!await` `anext`\ `!(cur)` expression.
+
+.. __: https://docs.python.org/3/glossary.html#term-iterable
+
+
 .. _pool-generic:
 
 Generic pool types
index e1f6d0b0254bf9bbeaae6b8108c7d467ee305c12..826b568f0c9b75aae5d8fed75743479f022892ba 100644 (file)
@@ -239,13 +239,21 @@ The `!Cursor` class
 
     .. note::
 
-        Cursors are iterable objects, so just using the::
+        Cursors are iterators, so just using the::
 
             for record in cursor:
                 ...
 
         syntax will iterate on the records in the current result set.
 
+        .. versionchanged:: 3.3
+
+            it is now possible to use `!next(cursor)`. Previously, cursors were
+            iterables__, not iterators__.
+
+            .. __: https://docs.python.org/3/glossary.html#term-iterable
+            .. __: https://docs.python.org/3/glossary.html#term-iterator
+
     .. autoattribute:: row_factory
 
         The property affects the objects returned by the `fetchone()`,
index 29d163bcad83add2cd703432ff54c619fcc24688..17fcaaf25d22539f51bfed8f3307c7b8497e9af2 100644 (file)
@@ -13,6 +13,8 @@ Future releases
 Psycopg 3.3.0 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+- Cursors are now iterators, not only iterables. This means you can call
+  ``next(cur)`` to fetch the next row (:ticket:`#1064`).
 - Drop support for Python 3.8 (:ticket:`#976`) and 3.9 (:ticket:`#1056`).
 
 
index 3b27d1fc736b556625021b2e118dea47536a0daf..cfddfd97c0eb786c4e873450861b3d9879cfda6b 100644 (file)
@@ -11,7 +11,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, overload
 from warnings import warn
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable
 
 from . import errors as e
 from .abc import Params, Query
@@ -136,15 +136,27 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
         self._pos += len(recs)
         return recs
 
-    def __iter__(self) -> Iterator[Row]:
-        while True:
+    def __iter__(self) -> Self:
+        return self
+
+    def __next__(self) -> Row:
+        # Fetch a new page if we never fetched any, or we are at the end of
+        # a page of size itersize, meaning there is likely a following one.
+        if (
+            self._iter_rows is None
+            or self._page_pos >= len(self._iter_rows) >= self.itersize
+        ):
             with self._conn.lock:
-                recs = self._conn.wait(self._fetch_gen(self.itersize))
-            for rec in recs:
-                self._pos += 1
-                yield rec
-            if len(recs) < self.itersize:
-                break
+                self._iter_rows = self._conn.wait(self._fetch_gen(self.itersize))
+                self._page_pos += 0
+
+        if self._page_pos >= len(self._iter_rows):
+            raise StopIteration("no more records to return")
+
+        rec = self._iter_rows[self._page_pos]
+        self._page_pos += 1
+        self._pos += 1
+        return rec
 
     def scroll(self, value: int, mode: str = "relative") -> None:
         with self._conn.lock:
index 37e0e2621e0abff08d9ba310bf07fec9d1f7bd51..41628fe74cbf7fb015160bf277e40a587dffa26a 100644 (file)
@@ -8,7 +8,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, overload
 from warnings import warn
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import Iterable
 
 from . import errors as e
 from .abc import Params, Query
@@ -136,15 +136,26 @@ class AsyncServerCursor(
         self._pos += len(recs)
         return recs
 
-    async def __aiter__(self) -> AsyncIterator[Row]:
-        while True:
+    def __aiter__(self) -> Self:
+        return self
+
+    async def __anext__(self) -> Row:
+        # Fetch a new page if we never fetched any, or we are at the end of
+        # a page of size itersize, meaning there is likely a following one.
+        if self._iter_rows is None or (
+            self._page_pos >= len(self._iter_rows) >= self.itersize
+        ):
             async with self._conn.lock:
-                recs = await self._conn.wait(self._fetch_gen(self.itersize))
-            for rec in recs:
-                self._pos += 1
-                yield rec
-            if len(recs) < self.itersize:
-                break
+                self._iter_rows = await self._conn.wait(self._fetch_gen(self.itersize))
+                self._page_pos += 0
+
+        if self._page_pos >= len(self._iter_rows):
+            raise StopAsyncIteration("no more records to return")
+
+        rec = self._iter_rows[self._page_pos]
+        self._page_pos += 1
+        self._pos += 1
+        return rec
 
     async def scroll(self, value: int, mode: str = "relative") -> None:
         async with self._conn.lock:
index 4d1fba0d50a585426a567170eb7017c40d663358..fef6f9070edf5c3e3edd49794e9aa6633dff6f2a 100644 (file)
@@ -28,7 +28,9 @@ INTRANS = pq.TransactionStatus.INTRANS
 class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
     """Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
 
-    __slots__ = "_name _scrollable _withhold _described itersize _format".split()
+    __slots__ = """_name _scrollable _withhold _described itersize _format
+        _iter_rows _page_pos
+    """.split()
 
     def __init__(self, name: str, scrollable: bool | None, withhold: bool):
         self._name = name
@@ -38,6 +40,10 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
         self.itersize: int = DEFAULT_ITERSIZE
         self._format = TEXT
 
+        # Hold the state during iteration: a fetched page and position within it
+        self._iter_rows: list[Row] | None = None
+        self._page_pos = 0
+
     def __repr__(self) -> str:
         # Insert the name as the second word
         parts = super().__repr__().split(None, 1)
@@ -91,6 +97,7 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
             yield from self._close_gen()
             self._described = False
 
+        self._iter_rows = None
         yield from self._start_query(query)
         pgq = self._convert_query(query, params)
         self._execute_send(pgq, force_extended=True)
index 25449e38442755c6c69edfb48f7fd7e85d0a971e..56411dac086b4b06c1a85e696a396e730b8aeca9 100644 (file)
@@ -224,16 +224,13 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         self._pos = self.pgresult.ntuples
         return records
 
-    def __iter__(self) -> Iterator[Row]:
-        self._fetch_pipeline()
-        self._check_result_for_fetch()
-
-        def load(pos: int) -> Row | None:
-            return self._tx.load_row(pos, self._make_row)
+    def __iter__(self) -> Self:
+        return self
 
-        while (row := load(self._pos)) is not None:
-            self._pos += 1
-            yield row
+    def __next__(self) -> Row:
+        if (rec := self.fetchone()) is not None:
+            return rec
+        raise StopIteration("no more records to return")
 
     def scroll(self, value: int, mode: str = "relative") -> None:
         """
index 0601088d9baae7f75e6ffb9c2ccb33edd365dac8..d2f6e775772a7c3a6172c565a8075a8b250a2332 100644 (file)
@@ -228,16 +228,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         self._pos = self.pgresult.ntuples
         return records
 
-    async def __aiter__(self) -> AsyncIterator[Row]:
-        await self._fetch_pipeline()
-        self._check_result_for_fetch()
-
-        def load(pos: int) -> Row | None:
-            return self._tx.load_row(pos, self._make_row)
+    def __aiter__(self) -> Self:
+        return self
 
-        while (row := load(self._pos)) is not None:
-            self._pos += 1
-            yield row
+    async def __anext__(self) -> Row:
+        if (rec := await self.fetchone()) is not None:
+            return rec
+        raise StopAsyncIteration("no more records to return")
 
     async def scroll(self, value: int, mode: str = "relative") -> None:
         """
index e6237148b4e75a146db89a57911747de20a25c61..89f91e677cca10ba7fb9f0220322ff42e11031f2 100644 (file)
@@ -162,6 +162,14 @@ def test_execute_sql(conn):
     assert cur.fetchone() == ("hello",)
 
 
+def test_next(conn):
+    cur = conn.cursor()
+    cur.execute("select 1")
+    assert next(cur) == (1,)
+    with pytest.raises(StopIteration):
+        next(cur)
+
+
 def test_query_parse_cache_size(conn):
     cur = conn.cursor()
     cls = type(cur)
index ae07fb67adf42b5e7d7e42813060670511e21228..ac8ca65fb90f1ac184e8157f4bb105c9247ec0b4 100644 (file)
@@ -160,6 +160,14 @@ async def test_execute_sql(aconn):
     assert (await cur.fetchone()) == ("hello",)
 
 
+async def test_next(aconn):
+    cur = aconn.cursor()
+    await cur.execute("select 1")
+    assert await anext(cur) == (1,)
+    with pytest.raises(StopAsyncIteration):
+        await anext(cur)
+
+
 async def test_query_parse_cache_size(aconn):
     cur = aconn.cursor()
     cls = type(cur)
index cc2d4cb5e4738b5314ec2346fba1b60187cc370a..fd9192da04e6fcbb771d86cabd87b4700a64803f 100644 (file)
@@ -431,6 +431,7 @@ def test_iter(conn):
 
 def test_iter_rownumber(conn):
     with conn.cursor("foo") as cur:
+        cur.itersize = 2
         cur.execute(ph(cur, "select generate_series(1, %s) as bar"), (3,))
         for row in cur:
             assert cur.rownumber == row[0]
@@ -450,6 +451,14 @@ def test_itersize(conn, commands):
             assert "fetch forward 2" in cmd.lower()
 
 
+def test_next(conn):
+    with conn.cursor() as cur:
+        cur.execute("select 1")
+        assert next(cur) == (1,)
+        with pytest.raises(StopIteration):
+            next(cur)
+
+
 def test_cant_scroll_by_default(conn):
     cur = conn.cursor("tmp")
     assert cur.scrollable is None
index 98ae9a27857bea3cbbdab6f81f31d537a9987ff8..b94590ae417a37d957990ea5fcad2494124958be 100644 (file)
@@ -437,6 +437,7 @@ async def test_iter(aconn):
 
 async def test_iter_rownumber(aconn):
     async with aconn.cursor("foo") as cur:
+        cur.itersize = 2
         await cur.execute(ph(cur, "select generate_series(1, %s) as bar"), (3,))
         async for row in cur:
             assert cur.rownumber == row[0]
@@ -456,6 +457,14 @@ async def test_itersize(aconn, acommands):
             assert "fetch forward 2" in cmd.lower()
 
 
+async def test_next(aconn):
+    async with aconn.cursor() as cur:
+        await cur.execute("select 1")
+        assert await anext(cur) == (1,)
+        with pytest.raises(StopAsyncIteration):
+            await anext(cur)
+
+
 async def test_cant_scroll_by_default(aconn):
     cur = aconn.cursor("tmp")
     assert cur.scrollable is None
index 24894961c3a37ff0f8c74b05f677f2bb381b2f19..28a3d9fd92e20e2249ffd13d6ed23f34b94a0d47 100644 (file)
@@ -248,14 +248,17 @@ obj = {curs}
         ("many", "list[{type}]"),
         ("all", "list[{type}]"),
         ("iter", "{type}"),
+        ("next", "{type}"),
     ],
 )
 def test_fetch_type(conn_class, server_side, factory, type, fetch, typemod, mypy):
     if "Async" in conn_class:
         async_ = "async "
         await_ = "await "
+        next_ = "anext"
     else:
         async_ = await_ = ""
+        next_ = "next"
 
     curs = f"conn.cursor({factory})"
     if server_side:
@@ -273,6 +276,8 @@ curs = {curs}
         stmts += f"obj = {await_} curs.fetchall()"
     elif fetch == "iter":
         stmts += f"{async_}for obj in curs: pass"
+    elif fetch == "next":
+        stmts += f"obj = {await_} {next_}(curs)"
     else:
         pytest.fail(f"unexpected fetch: {fetch}")
 
index 24d8918cc3b62895036547f6ad971f6e07e79b69..4fa5258023eca979325da7ad74b6a3dede5a011f 100755 (executable)
@@ -296,9 +296,11 @@ class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
         "AsyncServerCursor": "ServerCursor",
         "AsyncTransaction": "Transaction",
         "AsyncWriter": "Writer",
+        "StopAsyncIteration": "StopIteration",
         "__aenter__": "__enter__",
         "__aexit__": "__exit__",
         "__aiter__": "__iter__",
+        "__anext__": "__next__",
         "_copy_async": "_copy",
         "_server_cursor_async": "_server_cursor",
         "aclose": "close",
@@ -363,6 +365,7 @@ class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
 
     def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
         self._fix_docstring(node.body)
+        node.name = self.names_map.get(node.name, node.name)
         if node.decorator_list:
             self._fix_decorator(node.decorator_list)
         self.generic_visit(node)