]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add fetch methods to named cursors
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 20:04:00 +0000 (21:04 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Feb 2021 20:09:20 +0000 (21:09 +0100)
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/named_cursor.py
psycopg3/psycopg3/transaction.py
tests/fix_db.py
tests/test_named_cursor.py
tests/test_named_cursor_async.py
tests/test_transaction.py
tests/test_transaction_async.py

index 7d097a60135e8d0f724a39b36a180543712d1593..c2a5a8436715481177318db99fdb5a6fdf7e8dbd 100644 (file)
@@ -321,12 +321,12 @@ class BaseConnection(AdaptContext):
         conn._autocommit = autocommit
         return conn
 
-    def _exec_command(self, command: Query) -> PQGen[None]:
+    def _exec_command(self, command: Query) -> PQGen["PGresult"]:
         """
         Generator to send a command and receive the result to the backend.
 
-        Only used to implement internal commands such as commit, returning
-        no result. The cursor can do more complex stuff.
+        Only used to implement internal commands such as "commit", with eventual
+        arguments bound client-side. The cursor can do more complex stuff.
         """
         if self.pgconn.status != ConnStatus.OK:
             if self.pgconn.status == ConnStatus.BAD:
@@ -343,7 +343,7 @@ class BaseConnection(AdaptContext):
 
         self.pgconn.send_query(command)
         result = (yield from execute(self.pgconn))[-1]
-        if result.status != ExecStatus.COMMAND_OK:
+        if result.status not in (ExecStatus.COMMAND_OK, ExecStatus.TUPLES_OK):
             if result.status == ExecStatus.FATAL_ERROR:
                 raise e.error_from_result(
                     result, encoding=self.client_encoding
@@ -353,6 +353,7 @@ class BaseConnection(AdaptContext):
                     f"unexpected result {ExecStatus(result.status).name}"
                     f" from command {command.decode('utf8')!r}"
                 )
+        return result
 
     def _start_query(self) -> PQGen[None]:
         """Generator to start a transaction if necessary."""
index 54e10e890e3bf339517ef381089c93159d483444..109b2bb236feb5cc127bb6281cc275c2e4d10d37 100644 (file)
@@ -7,7 +7,8 @@ psycopg3 named cursor objects (server-side cursors)
 import weakref
 import warnings
 from types import TracebackType
-from typing import Any, Generic, Optional, Type, TYPE_CHECKING
+from typing import Any, AsyncIterator, Generic, List, Iterator, Optional
+from typing import Sequence, Type, Tuple, TYPE_CHECKING
 
 from . import sql
 from .pq import Format
@@ -18,6 +19,8 @@ if TYPE_CHECKING:
     from .connection import BaseConnection  # noqa: F401
     from .connection import Connection, AsyncConnection  # noqa: F401
 
+DEFAULT_ITERSIZE = 100
+
 
 class NamedCursorHelper(Generic[ConnectionType]):
     __slots__ = ("name", "_wcur")
@@ -46,17 +49,18 @@ class NamedCursorHelper(Generic[ConnectionType]):
     ) -> PQGen[None]:
         """Generator implementing `NamedCursor.execute()`."""
         cur = self._cur
+        conn = cur._conn
         yield from cur._start_query(query)
         pgq = cur._convert_query(query, params)
         cur._execute_send(pgq)
-        results = yield from execute(cur._conn.pgconn)
+        results = yield from execute(conn.pgconn)
         cur._execute_results(results)
 
         # The above result is an COMMAND_OK. Get the cursor result shape
-        cur._conn.pgconn.send_describe_portal(
-            self.name.encode(cur._conn.client_encoding)
+        conn.pgconn.send_describe_portal(
+            self.name.encode(conn.client_encoding)
         )
-        results = yield from execute(cur._conn.pgconn)
+        results = yield from execute(conn.pgconn)
         cur._execute_results(results)
 
     def _close_gen(self) -> PQGen[None]:
@@ -64,6 +68,24 @@ class NamedCursorHelper(Generic[ConnectionType]):
         query = sql.SQL("close {}").format(sql.Identifier(self.name))
         yield from cur._conn._exec_command(query)
 
+    def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Tuple[Any, ...]]]:
+        if num is not None:
+            howmuch: sql.Composable = sql.Literal(num)
+        else:
+            howmuch = sql.SQL("all")
+
+        cur = self._cur
+        query = sql.SQL("fetch forward {} from {}").format(
+            howmuch, sql.Identifier(self.name)
+        )
+        res = yield from cur._conn._exec_command(query)
+
+        # TODO: loaders don't need to be refreshed
+        cur.pgresult = res
+        nrows = res.ntuples
+        cur._pos += nrows
+        return cur._tx.load_rows(0, nrows)
+
     def _make_declare_statement(
         self, query: Query, scrollable: bool, hold: bool
     ) -> sql.Composable:
@@ -85,7 +107,7 @@ class NamedCursorHelper(Generic[ConnectionType]):
 
 class NamedCursor(BaseCursor["Connection"]):
     __module__ = "psycopg3"
-    __slots__ = ("_helper",)
+    __slots__ = ("_helper", "itersize")
 
     def __init__(
         self,
@@ -96,6 +118,7 @@ class NamedCursor(BaseCursor["Connection"]):
     ):
         super().__init__(connection, format=format)
         self._helper = NamedCursorHelper(name, self)
+        self.itersize = DEFAULT_ITERSIZE
 
     def __del__(self) -> None:
         if not self._closed:
@@ -146,10 +169,36 @@ class NamedCursor(BaseCursor["Connection"]):
             self._conn.wait(self._helper._declare_gen(query, params))
         return self
 
+    def fetchone(self) -> Optional[Sequence[Any]]:
+        with self._conn.lock:
+            recs = self._conn.wait(self._helper._fetch_gen(1))
+        return recs[0] if recs else None
+
+    def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+        if not size:
+            size = self.arraysize
+        with self._conn.lock:
+            recs = self._conn.wait(self._helper._fetch_gen(size))
+        return recs
+
+    def fetchall(self) -> Sequence[Sequence[Any]]:
+        with self._conn.lock:
+            recs = self._conn.wait(self._helper._fetch_gen(None))
+        return recs
+
+    def __iter__(self) -> Iterator[Sequence[Any]]:
+        while True:
+            with self._conn.lock:
+                recs = self._conn.wait(self._helper._fetch_gen(self.itersize))
+            for rec in recs:
+                yield rec
+            if len(recs) < self.itersize:
+                break
+
 
 class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
     __module__ = "psycopg3"
-    __slots__ = ("_helper",)
+    __slots__ = ("_helper", "itersize")
 
     def __init__(
         self,
@@ -160,6 +209,7 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
     ):
         super().__init__(connection, format=format)
         self._helper = NamedCursorHelper(name, self)
+        self.itersize = DEFAULT_ITERSIZE
 
     def __del__(self) -> None:
         if not self._closed:
@@ -209,3 +259,31 @@ class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
         async with self._conn.lock:
             await self._conn.wait(self._helper._declare_gen(query, params))
         return self
+
+    async def fetchone(self) -> Optional[Sequence[Any]]:
+        async with self._conn.lock:
+            recs = await self._conn.wait(self._helper._fetch_gen(1))
+        return recs[0] if recs else None
+
+    async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+        if not size:
+            size = self.arraysize
+        async with self._conn.lock:
+            recs = await self._conn.wait(self._helper._fetch_gen(size))
+        return recs
+
+    async def fetchall(self) -> Sequence[Sequence[Any]]:
+        async with self._conn.lock:
+            recs = await self._conn.wait(self._helper._fetch_gen(None))
+        return recs
+
+    async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
+        while True:
+            async with self._conn.lock:
+                recs = await self._conn.wait(
+                    self._helper._fetch_gen(self.itersize)
+                )
+            for rec in recs:
+                yield rec
+            if len(recs) < self.itersize:
+                break
index a07ce6f73f69fb189c7f3ed58b0ec60a15d05f9f..650a4f1ef38bcd72ca4eea914e32257699a6214b 100644 (file)
@@ -13,6 +13,7 @@ from . import pq
 from . import sql
 from .pq import TransactionStatus
 from .proto import ConnectionType, PQGen
+from .pq.proto import PGresult
 
 if TYPE_CHECKING:
     from .connection import Connection, AsyncConnection  # noqa: F401
@@ -80,7 +81,7 @@ class BaseTransaction(Generic[ConnectionType]):
         sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
         return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
 
-    def _enter_gen(self) -> PQGen[None]:
+    def _enter_gen(self) -> PQGen[PGresult]:
         if self._entered:
             raise TypeError("transaction blocks can be used only once")
         self._entered = True
@@ -126,7 +127,7 @@ class BaseTransaction(Generic[ConnectionType]):
         else:
             return (yield from self._rollback_gen(exc_val))
 
-    def _commit_gen(self) -> PQGen[None]:
+    def _commit_gen(self) -> PQGen[PGresult]:
         assert self._conn._savepoints[-1] == self._savepoint_name
         self._conn._savepoints.pop()
         self._exited = True
index 039bc1be0194acc51b9bc6c786b4298760b213fa..1cab75592dcdbb0b16999fa722363ef1f2932801 100644 (file)
@@ -65,3 +65,45 @@ def svcconn(dsn):
     conn = Connection.connect(dsn, autocommit=True)
     yield conn
     conn.close()
+
+
+@pytest.fixture
+def commands(conn, monkeypatch):
+    """The list of commands issued internally by the test connection."""
+    yield patch_exec(conn, monkeypatch)
+
+
+@pytest.fixture
+def acommands(aconn, monkeypatch):
+    """The list of commands issued internally by the test async connection."""
+    yield patch_exec(aconn, monkeypatch)
+
+
+def patch_exec(conn, monkeypatch):
+    """Helper to implement the commands fixture both sync and async."""
+    from psycopg3 import sql
+
+    _orig_exec_command = conn._exec_command
+    L = ListPopAll()
+
+    def _exec_command(command):
+        cmdcopy = command
+        if isinstance(cmdcopy, bytes):
+            cmdcopy = cmdcopy.decode(conn.client_encoding)
+        elif isinstance(cmdcopy, sql.Composable):
+            cmdcopy = cmdcopy.as_string(conn)
+
+        L.insert(0, cmdcopy)
+        return _orig_exec_command(command)
+
+    monkeypatch.setattr(conn, "_exec_command", _exec_command)
+    return L
+
+
+class ListPopAll(list):
+    """A list, with a popall() method."""
+
+    def popall(self):
+        out = self[:]
+        del self[:]
+        return out
index b9421a0cb2dc30567e5d73f70bbad060975fe4de..5dd89cbae48cbd204d5be15b2a7030e9ac7faf2e 100644 (file)
@@ -38,3 +38,77 @@ def test_warn_close(conn, recwarn):
     cur.execute("select generate_series(1, 10) as bar")
     del cur
     assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+def test_fetchone(conn):
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, %s) as bar", (2,))
+        assert cur.fetchone() == (1,)
+        assert cur.fetchone() == (2,)
+        assert cur.fetchone() is None
+
+
+def test_fetchmany(conn):
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, %s) as bar", (5,))
+        assert cur.fetchmany(3) == [(1,), (2,), (3,)]
+        assert cur.fetchone() == (4,)
+        assert cur.fetchmany(3) == [(5,)]
+        assert cur.fetchmany(3) == []
+
+
+def test_fetchall(conn):
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert cur.fetchall() == [(1,), (2,), (3,)]
+        assert cur.fetchall() == []
+
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert cur.fetchone() == (1,)
+        assert cur.fetchall() == [(2,), (3,)]
+        assert cur.fetchall() == []
+
+
+def test_rownumber(conn):
+    cur = conn.cursor("foo")
+    assert cur.rownumber is None
+
+    cur.execute("select 1 from generate_series(1, 42)")
+    assert cur.rownumber == 0
+
+    cur.fetchone()
+    assert cur.rownumber == 1
+    cur.fetchone()
+    assert cur.rownumber == 2
+    cur.fetchmany(10)
+    assert cur.rownumber == 12
+    cur.fetchall()
+    assert cur.rownumber == 42
+
+
+def test_iter(conn):
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, %s) as bar", (3,))
+        recs = list(cur)
+    assert recs == [(1,), (2,), (3,)]
+
+    with conn.cursor("foo") as cur:
+        cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert cur.fetchone() == (1,)
+        recs = list(cur)
+    assert recs == [(2,), (3,)]
+
+
+def test_itersize(conn, commands):
+    with conn.cursor("foo") as cur:
+        assert cur.itersize == 100
+        cur.itersize = 2
+        cur.execute("select generate_series(1, %s) as bar", (3,))
+        commands.popall()  # flush begin and other noise
+
+        list(cur)
+        cmds = commands.popall()
+        assert len(cmds) == 2
+        for cmd in cmds:
+            assert ("fetch forward 2") in cmd.lower()
index 0ceeee11436ea54ec58336fa6004083a0eda4d17..d56951a736d4a84859264de5eace517b77c944a9 100644 (file)
@@ -43,3 +43,82 @@ async def test_warn_close(aconn, recwarn):
     await cur.execute("select generate_series(1, 10) as bar")
     del cur
     assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+async def test_fetchone(aconn):
+    async with await aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, %s) as bar", (2,))
+        assert await cur.fetchone() == (1,)
+        assert await cur.fetchone() == (2,)
+        assert await cur.fetchone() is None
+
+
+async def test_fetchmany(aconn):
+    async with await aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, %s) as bar", (5,))
+        assert await cur.fetchmany(3) == [(1,), (2,), (3,)]
+        assert await cur.fetchone() == (4,)
+        assert await cur.fetchmany(3) == [(5,)]
+        assert await cur.fetchmany(3) == []
+
+
+async def test_fetchall(aconn):
+    async with await aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert await cur.fetchall() == [(1,), (2,), (3,)]
+        assert await cur.fetchall() == []
+
+    async with await aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert await cur.fetchone() == (1,)
+        assert await cur.fetchall() == [(2,), (3,)]
+        assert await cur.fetchall() == []
+
+
+async def test_rownumber(aconn):
+    cur = await aconn.cursor("foo")
+    assert cur.rownumber is None
+
+    await cur.execute("select 1 from generate_series(1, 42)")
+    assert cur.rownumber == 0
+
+    await cur.fetchone()
+    assert cur.rownumber == 1
+    await cur.fetchone()
+    assert cur.rownumber == 2
+    await cur.fetchmany(10)
+    assert cur.rownumber == 12
+    await cur.fetchall()
+    assert cur.rownumber == 42
+
+
+async def test_iter(aconn):
+    async with await aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, %s) as bar", (3,))
+        recs = []
+        async for rec in cur:
+            recs.append(rec)
+    assert recs == [(1,), (2,), (3,)]
+
+    async with await aconn.cursor("foo") as cur:
+        await cur.execute("select generate_series(1, %s) as bar", (3,))
+        assert await cur.fetchone() == (1,)
+        recs = []
+        async for rec in cur:
+            recs.append(rec)
+    assert recs == [(2,), (3,)]
+
+
+async def test_itersize(aconn, acommands):
+    async with await aconn.cursor("foo") as cur:
+        assert cur.itersize == 100
+        cur.itersize = 2
+        await cur.execute("select generate_series(1, %s) as bar", (3,))
+        acommands.popall()  # flush begin and other noise
+
+        async for rec in cur:
+            pass
+        cmds = acommands.popall()
+        assert len(cmds) == 2
+        for cmd in cmds:
+            assert ("fetch forward 2") in cmd.lower()
index 5f48edb92788d8f69f233bfc8e6293bc8e1721d4..48795531ede6764378fac11b9cf92ba7c0fedae1 100644 (file)
@@ -43,37 +43,6 @@ def inserted(conn):
         return f()
 
 
-class ListPopAll(list):
-    """A list, with a popall() method."""
-
-    def popall(self):
-        out = self[:]
-        del self[:]
-        return out
-
-
-@pytest.fixture
-def commands(conn, monkeypatch):
-    """The list of commands issued internally by the test connection."""
-    yield patch_exec(conn, monkeypatch)
-
-
-def patch_exec(conn, monkeypatch):
-    """Helper to implement the commands fixture both sync and async."""
-    _orig_exec_command = conn._exec_command
-    L = ListPopAll()
-
-    def _exec_command(command):
-        if isinstance(command, bytes):
-            command = command.decode(conn.client_encoding)
-
-        L.insert(0, command)
-        return _orig_exec_command(command)
-
-    monkeypatch.setattr(conn, "_exec_command", _exec_command)
-    return L
-
-
 def in_transaction(conn):
     if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE:
         return False
index 2bfab93018dd0a5401579886db4a3a592c83b764..c955039c9ae8822f227908c227307acb6a34f9c3 100644 (file)
@@ -3,18 +3,12 @@ import pytest
 from psycopg3 import ProgrammingError, Rollback
 
 from .test_transaction import in_transaction, insert_row, inserted
-from .test_transaction import ExpectedException, patch_exec
+from .test_transaction import ExpectedException
 from .test_transaction import create_test_table  # noqa  # autouse fixture
 
 pytestmark = pytest.mark.asyncio
 
 
-@pytest.fixture
-def commands(aconn, monkeypatch):
-    """The list of commands issued internally by the test connection."""
-    yield patch_exec(aconn, monkeypatch)
-
-
 async def test_basic(aconn):
     """Basic use of transaction() to BEGIN and COMMIT a transaction."""
     assert not in_transaction(aconn)
@@ -310,7 +304,7 @@ async def test_named_savepoint_escapes_savepoint_name(aconn):
         pass
 
 
-async def test_named_savepoints_successful_exit(aconn, commands):
+async def test_named_savepoints_successful_exit(aconn, acommands):
     """
     Entering a transaction context will do one of these these things:
     1. Begin an outer transaction (if one isn't already in progress)
@@ -320,6 +314,8 @@ async def test_named_savepoints_successful_exit(aconn, commands):
 
     ...and exiting the context successfully will "commit" the same.
     """
+    commands = acommands
+
     # Case 1
     # Using Transaction explicitly becase conn.transaction() enters the contetx
     async with aconn.transaction() as tx:
@@ -363,12 +359,14 @@ async def test_named_savepoints_successful_exit(aconn, commands):
     assert commands.popall() == ["commit"]
 
 
-async def test_named_savepoints_exception_exit(aconn, commands):
+async def test_named_savepoints_exception_exit(aconn, acommands):
     """
     Same as the previous test but checks that when exiting the context with an
     exception, whatever transaction and/or savepoint was started on enter will
     be rolled-back as appropriate.
     """
+    commands = acommands
+
     # Case 1
     with pytest.raises(ExpectedException):
         async with aconn.transaction() as tx: