rec = cur.fetchone() # type is tuple[Any, ...] | None
+ rec = next(cur) # type is tuple[Any, ...]
+
recs = cur.fetchall() # type is List[tuple[Any, ...]]
# drec type is dict[str, Any] | None
+.. _typing-fetchone:
+
+The ``fetchone()`` frustration
+------------------------------
+
+.. versionchanged:: 3.3
+
+If you use a static type checker and you are 100% sure that the cursor will
+exactly one record, it is frustrating to be told that the returned row might
+be `!None`. For example:
+
+.. code:: python
+
+ import psycopg
+ from psycopg.rows import scalar_row
+
+ def count_records() -> int:
+ conn = psycopg.connect()
+ cur = conn.cursor(row_factory=scalar_row)
+ cur.execute("SELECT count(*) FROM mytable")
+ rv: int = cur.fetchone() # mypy error here
+ return rv
+
+The :sql:`count(*)` will always return a record with a number, even if the
+table is empty (it will just report 0). However, Mypy will report an error
+such as *incompatible types in assignment (expression has type "Any | None",
+variable has type "int")*. In order to work around the error you will need
+to use an `!if`, an `!assert` or some other workaround (like ``(rv,) =
+cur.fetchall()`` or some other horrible trick).
+
+Since Psycopg 3.3, cursors are iterables__, therefore they support the
+`next` function. A `!next(cur)` will behave like `!cur.fetchone()`, but it
+is guaranteed to return a row (in case there are no rows in the result set it
+will not return anything but will raise `!StopIteration`). Therefore the
+function above can terminate with:
+
+.. code:: python
+
+ def count_records() -> int:
+ ...
+ rv: int = next(cur)
+ return rv
+
+and your static checker will be happy.
+
+Similarly, in async code, you can use an `!await` `anext`\ `!(cur)` expression.
+
+.. __: https://docs.python.org/3/glossary.html#term-iterable
+
+
.. _pool-generic:
Generic pool types
.. note::
- Cursors are iterable objects, so just using the::
+ Cursors are iterators, so just using the::
for record in cursor:
...
syntax will iterate on the records in the current result set.
+ .. versionchanged:: 3.3
+
+ it is now possible to use `!next(cursor)`. Previously, cursors were
+ iterables__, not iterators__.
+
+ .. __: https://docs.python.org/3/glossary.html#term-iterable
+ .. __: https://docs.python.org/3/glossary.html#term-iterator
+
.. autoattribute:: row_factory
The property affects the objects returned by the `fetchone()`,
Psycopg 3.3.0 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^
+- Cursors are now iterators, not only iterables. This means you can call
+ ``next(cur)`` to fetch the next row (:ticket:`#1064`).
- Drop support for Python 3.8 (:ticket:`#976`) and 3.9 (:ticket:`#1056`).
from typing import TYPE_CHECKING, Any, overload
from warnings import warn
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable
from . import errors as e
from .abc import Params, Query
self._pos += len(recs)
return recs
- def __iter__(self) -> Iterator[Row]:
- while True:
+ def __iter__(self) -> Self:
+ return self
+
+ def __next__(self) -> Row:
+ # Fetch a new page if we never fetched any, or we are at the end of
+ # a page of size itersize, meaning there is likely a following one.
+ if (
+ self._iter_rows is None
+ or self._page_pos >= len(self._iter_rows) >= self.itersize
+ ):
with self._conn.lock:
- recs = self._conn.wait(self._fetch_gen(self.itersize))
- for rec in recs:
- self._pos += 1
- yield rec
- if len(recs) < self.itersize:
- break
+ self._iter_rows = self._conn.wait(self._fetch_gen(self.itersize))
+ self._page_pos += 0
+
+ if self._page_pos >= len(self._iter_rows):
+ raise StopIteration("no more records to return")
+
+ rec = self._iter_rows[self._page_pos]
+ self._page_pos += 1
+ self._pos += 1
+ return rec
def scroll(self, value: int, mode: str = "relative") -> None:
with self._conn.lock:
from typing import TYPE_CHECKING, Any, overload
from warnings import warn
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import Iterable
from . import errors as e
from .abc import Params, Query
self._pos += len(recs)
return recs
- async def __aiter__(self) -> AsyncIterator[Row]:
- while True:
+ def __aiter__(self) -> Self:
+ return self
+
+ async def __anext__(self) -> Row:
+ # Fetch a new page if we never fetched any, or we are at the end of
+ # a page of size itersize, meaning there is likely a following one.
+ if self._iter_rows is None or (
+ self._page_pos >= len(self._iter_rows) >= self.itersize
+ ):
async with self._conn.lock:
- recs = await self._conn.wait(self._fetch_gen(self.itersize))
- for rec in recs:
- self._pos += 1
- yield rec
- if len(recs) < self.itersize:
- break
+ self._iter_rows = await self._conn.wait(self._fetch_gen(self.itersize))
+ self._page_pos += 0
+
+ if self._page_pos >= len(self._iter_rows):
+ raise StopAsyncIteration("no more records to return")
+
+ rec = self._iter_rows[self._page_pos]
+ self._page_pos += 1
+ self._pos += 1
+ return rec
async def scroll(self, value: int, mode: str = "relative") -> None:
async with self._conn.lock:
class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
"""Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
- __slots__ = "_name _scrollable _withhold _described itersize _format".split()
+ __slots__ = """_name _scrollable _withhold _described itersize _format
+ _iter_rows _page_pos
+ """.split()
def __init__(self, name: str, scrollable: bool | None, withhold: bool):
self._name = name
self.itersize: int = DEFAULT_ITERSIZE
self._format = TEXT
+ # Hold the state during iteration: a fetched page and position within it
+ self._iter_rows: list[Row] | None = None
+ self._page_pos = 0
+
def __repr__(self) -> str:
# Insert the name as the second word
parts = super().__repr__().split(None, 1)
yield from self._close_gen()
self._described = False
+ self._iter_rows = None
yield from self._start_query(query)
pgq = self._convert_query(query, params)
self._execute_send(pgq, force_extended=True)
self._pos = self.pgresult.ntuples
return records
- def __iter__(self) -> Iterator[Row]:
- self._fetch_pipeline()
- self._check_result_for_fetch()
-
- def load(pos: int) -> Row | None:
- return self._tx.load_row(pos, self._make_row)
+ def __iter__(self) -> Self:
+ return self
- while (row := load(self._pos)) is not None:
- self._pos += 1
- yield row
+ def __next__(self) -> Row:
+ if (rec := self.fetchone()) is not None:
+ return rec
+ raise StopIteration("no more records to return")
def scroll(self, value: int, mode: str = "relative") -> None:
"""
self._pos = self.pgresult.ntuples
return records
- async def __aiter__(self) -> AsyncIterator[Row]:
- await self._fetch_pipeline()
- self._check_result_for_fetch()
-
- def load(pos: int) -> Row | None:
- return self._tx.load_row(pos, self._make_row)
+ def __aiter__(self) -> Self:
+ return self
- while (row := load(self._pos)) is not None:
- self._pos += 1
- yield row
+ async def __anext__(self) -> Row:
+ if (rec := await self.fetchone()) is not None:
+ return rec
+ raise StopAsyncIteration("no more records to return")
async def scroll(self, value: int, mode: str = "relative") -> None:
"""
assert cur.fetchone() == ("hello",)
+def test_next(conn):
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert next(cur) == (1,)
+ with pytest.raises(StopIteration):
+ next(cur)
+
+
def test_query_parse_cache_size(conn):
cur = conn.cursor()
cls = type(cur)
assert (await cur.fetchone()) == ("hello",)
+async def test_next(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await anext(cur) == (1,)
+ with pytest.raises(StopAsyncIteration):
+ await anext(cur)
+
+
async def test_query_parse_cache_size(aconn):
cur = aconn.cursor()
cls = type(cur)
def test_iter_rownumber(conn):
with conn.cursor("foo") as cur:
+ cur.itersize = 2
cur.execute(ph(cur, "select generate_series(1, %s) as bar"), (3,))
for row in cur:
assert cur.rownumber == row[0]
assert "fetch forward 2" in cmd.lower()
+def test_next(conn):
+ with conn.cursor() as cur:
+ cur.execute("select 1")
+ assert next(cur) == (1,)
+ with pytest.raises(StopIteration):
+ next(cur)
+
+
def test_cant_scroll_by_default(conn):
cur = conn.cursor("tmp")
assert cur.scrollable is None
async def test_iter_rownumber(aconn):
async with aconn.cursor("foo") as cur:
+ cur.itersize = 2
await cur.execute(ph(cur, "select generate_series(1, %s) as bar"), (3,))
async for row in cur:
assert cur.rownumber == row[0]
assert "fetch forward 2" in cmd.lower()
+async def test_next(aconn):
+ async with aconn.cursor() as cur:
+ await cur.execute("select 1")
+ assert await anext(cur) == (1,)
+ with pytest.raises(StopAsyncIteration):
+ await anext(cur)
+
+
async def test_cant_scroll_by_default(aconn):
cur = aconn.cursor("tmp")
assert cur.scrollable is None
("many", "list[{type}]"),
("all", "list[{type}]"),
("iter", "{type}"),
+ ("next", "{type}"),
],
)
def test_fetch_type(conn_class, server_side, factory, type, fetch, typemod, mypy):
if "Async" in conn_class:
async_ = "async "
await_ = "await "
+ next_ = "anext"
else:
async_ = await_ = ""
+ next_ = "next"
curs = f"conn.cursor({factory})"
if server_side:
stmts += f"obj = {await_} curs.fetchall()"
elif fetch == "iter":
stmts += f"{async_}for obj in curs: pass"
+ elif fetch == "next":
+ stmts += f"obj = {await_} {next_}(curs)"
else:
pytest.fail(f"unexpected fetch: {fetch}")
"AsyncServerCursor": "ServerCursor",
"AsyncTransaction": "Transaction",
"AsyncWriter": "Writer",
+ "StopAsyncIteration": "StopIteration",
"__aenter__": "__enter__",
"__aexit__": "__exit__",
"__aiter__": "__iter__",
+ "__anext__": "__next__",
"_copy_async": "_copy",
"_server_cursor_async": "_server_cursor",
"aclose": "close",
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
self._fix_docstring(node.body)
+ node.name = self.names_map.get(node.name, node.name)
if node.decorator_list:
self._fix_decorator(node.decorator_list)
self.generic_visit(node)