]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Cursor is iterable
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 30 Oct 2020 02:21:29 +0000 (03:21 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 30 Oct 2020 02:21:29 +0000 (03:21 +0100)
psycopg3/psycopg3/cursor.py
tests/test_cursor.py
tests/test_cursor_async.py

index b14670a9ceedf6d7d9fb8cb330dfb54ea76e9c53..bc83bde0ab7b447d79275c911177c4e25bd296ba 100644 (file)
@@ -5,7 +5,8 @@ psycopg3 cursor objects
 # Copyright (C) 2020 The Psycopg Team
 
 from types import TracebackType
-from typing import Any, Callable, List, Optional, Sequence, Type, TYPE_CHECKING
+from typing import Any, AsyncIterator, Callable, Iterator, List, Optional
+from typing import Sequence, Type, TYPE_CHECKING
 from operator import attrgetter
 
 from . import errors as e
@@ -438,21 +439,19 @@ class Cursor(BaseCursor):
         return rv
 
     def fetchall(self) -> List[Sequence[Any]]:
+        return list(self)
+
+    def __iter__(self) -> Iterator[Sequence[Any]]:
         self._check_result()
 
-        rv: List[Sequence[Any]] = []
-        pos = self._pos
         load = self._transformer.load_row
 
         while 1:
-            row = load(pos)
+            row = load(self._pos)
             if row is None:
                 break
-            pos += 1
-            rv.append(row)
-
-        self._pos = pos
-        return rv
+            self._pos += 1
+            yield row
 
     def copy(self, statement: Query, vars: Optional[Params] = None) -> Copy:
         with self.connection.lock:
@@ -561,21 +560,23 @@ class AsyncCursor(BaseCursor):
         return rv
 
     async def fetchall(self) -> List[Sequence[Any]]:
+        res = []
+        async for rec in self:
+            res.append(rec)
+
+        return res
+
+    async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
         self._check_result()
 
-        rv: List[Sequence[Any]] = []
-        pos = self._pos
         load = self._transformer.load_row
 
         while 1:
-            row = load(pos)
+            row = load(self._pos)
             if row is None:
                 break
-            pos += 1
-            rv.append(row)
-
-        self._pos = pos
-        return rv
+            self._pos += 1
+            yield row
 
     async def copy(
         self, statement: Query, vars: Optional[Params] = None
index 4a943c0f1d3e7b6f39dbccd478177204a4a8a4b3..fc6779e3cbcdb0a707da59bd4596fb9f615aff9e 100644 (file)
@@ -202,6 +202,27 @@ def test_rowcount(conn):
     assert cur.rowcount == -1
 
 
+def test_iter(conn):
+    cur = conn.cursor()
+    cur.execute("select generate_series(1, 3)")
+    assert list(cur) == [(1,), (2,), (3,)]
+
+
+def test_iter_stop(conn):
+    cur = conn.cursor()
+    cur.execute("select generate_series(1, 3)")
+    for rec in cur:
+        assert rec == (1,)
+        break
+
+    for rec in cur:
+        assert rec == (2,)
+        break
+
+    assert cur.fetchone() == (3,)
+    assert list(cur) == []
+
+
 class TestColumn:
     def test_description_attribs(self, conn):
         curs = conn.cursor()
index 3462b093bcb5fdf343c45797fa3fec3656a202b0..f0c028317e84fd9e01e5146bbd150a1d4e24456f 100644 (file)
@@ -202,3 +202,28 @@ async def test_rowcount(aconn):
 
     await cur.close()
     assert cur.rowcount == -1
+
+
+async def test_iter(aconn):
+    cur = aconn.cursor()
+    await cur.execute("select generate_series(1, 3)")
+    res = []
+    async for rec in cur:
+        res.append(rec)
+    assert res == [(1,), (2,), (3,)]
+
+
+async def test_iter_stop(aconn):
+    cur = aconn.cursor()
+    await cur.execute("select generate_series(1, 3)")
+    async for rec in cur:
+        assert rec == (1,)
+        break
+
+    async for rec in cur:
+        assert rec == (2,)
+        break
+
+    assert (await cur.fetchone()) == (3,)
+    async for rec in cur:
+        assert False