From 32df7f84ce71807251e105419cd5b804c186f049 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Fri, 30 Oct 2020 03:21:29 +0100 Subject: [PATCH] Cursor is iterable --- psycopg3/psycopg3/cursor.py | 35 ++++++++++++++++++----------------- tests/test_cursor.py | 21 +++++++++++++++++++++ tests/test_cursor_async.py | 25 +++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index b14670a9c..bc83bde0a 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -5,7 +5,8 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team from types import TracebackType -from typing import Any, Callable, List, Optional, Sequence, Type, TYPE_CHECKING +from typing import Any, AsyncIterator, Callable, Iterator, List, Optional +from typing import Sequence, Type, TYPE_CHECKING from operator import attrgetter from . import errors as e @@ -438,21 +439,19 @@ class Cursor(BaseCursor): return rv def fetchall(self) -> List[Sequence[Any]]: + return list(self) + + def __iter__(self) -> Iterator[Sequence[Any]]: self._check_result() - rv: List[Sequence[Any]] = [] - pos = self._pos load = self._transformer.load_row while 1: - row = load(pos) + row = load(self._pos) if row is None: break - pos += 1 - rv.append(row) - - self._pos = pos - return rv + self._pos += 1 + yield row def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy: with self.connection.lock: @@ -561,21 +560,23 @@ class AsyncCursor(BaseCursor): return rv async def fetchall(self) -> List[Sequence[Any]]: + res = [] + async for rec in self: + res.append(rec) + + return res + + async def __aiter__(self) -> AsyncIterator[Sequence[Any]]: self._check_result() - rv: List[Sequence[Any]] = [] - pos = self._pos load = self._transformer.load_row while 1: - row = load(pos) + row = load(self._pos) if row is None: break - pos += 1 - rv.append(row) - - self._pos = pos - return rv + self._pos += 1 + yield row async def copy( self, statement: Query, vars: Optional[Params] = None diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 4a943c0f1..fc6779e3c 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -202,6 +202,27 @@ def test_rowcount(conn): assert cur.rowcount == -1 +def test_iter(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + assert list(cur) == [(1,), (2,), (3,)] + + +def test_iter_stop(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + for rec in cur: + assert rec == (1,) + break + + for rec in cur: + assert rec == (2,) + break + + assert cur.fetchone() == (3,) + assert list(cur) == [] + + class TestColumn: def test_description_attribs(self, conn): curs = conn.cursor() diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 3462b093b..f0c028317 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -202,3 +202,28 @@ async def test_rowcount(aconn): await cur.close() assert cur.rowcount == -1 + + +async def test_iter(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + res = [] + async for rec in cur: + res.append(rec) + assert res == [(1,), (2,), (3,)] + + +async def test_iter_stop(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + async for rec in cur: + assert rec == (1,) + break + + async for rec in cur: + assert rec == (2,) + break + + assert (await cur.fetchone()) == (3,) + async for rec in cur: + assert False -- 2.47.2