# Copyright (C) 2020 The Psycopg Team
from types import TracebackType
-from typing import Any, Callable, List, Optional, Sequence, Type, TYPE_CHECKING
+from typing import Any, AsyncIterator, Callable, Iterator, List, Optional
+from typing import Sequence, Type, TYPE_CHECKING
from operator import attrgetter
from . import errors as e
return rv
def fetchall(self) -> List[Sequence[Any]]:
+ return list(self)
+
+ def __iter__(self) -> Iterator[Sequence[Any]]:
self._check_result()
- rv: List[Sequence[Any]] = []
- pos = self._pos
load = self._transformer.load_row
while 1:
- row = load(pos)
+ row = load(self._pos)
if row is None:
break
- pos += 1
- rv.append(row)
-
- self._pos = pos
- return rv
+ self._pos += 1
+ yield row
def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy:
with self.connection.lock:
return rv
async def fetchall(self) -> List[Sequence[Any]]:
+ res = []
+ async for rec in self:
+ res.append(rec)
+
+ return res
+
+ async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
self._check_result()
- rv: List[Sequence[Any]] = []
- pos = self._pos
load = self._transformer.load_row
while 1:
- row = load(pos)
+ row = load(self._pos)
if row is None:
break
- pos += 1
- rv.append(row)
-
- self._pos = pos
- return rv
+ self._pos += 1
+ yield row
async def copy(
self, statement: Query, vars: Optional[Params] = None
assert cur.rowcount == -1
+def test_iter(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ assert list(cur) == [(1,), (2,), (3,)]
+
+
+def test_iter_stop(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ for rec in cur:
+ assert rec == (1,)
+ break
+
+ for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert cur.fetchone() == (3,)
+ assert list(cur) == []
+
+
class TestColumn:
def test_description_attribs(self, conn):
curs = conn.cursor()
await cur.close()
assert cur.rowcount == -1
+
+
+async def test_iter(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ res = []
+ async for rec in cur:
+ res.append(rec)
+ assert res == [(1,), (2,), (3,)]
+
+
+async def test_iter_stop(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ async for rec in cur:
+ assert rec == (1,)
+ break
+
+ async for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert (await cur.fetchone()) == (3,)
+ async for rec in cur:
+ assert False