]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added Cursor.stream() to support fetch in single-row mode
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Jan 2021 03:05:12 +0000 (04:05 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Jan 2021 03:12:51 +0000 (04:12 +0100)
13 files changed:
docs/cursor.rst
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/generators.py
psycopg3/psycopg3/pq/_pq_ctypes.py
psycopg3/psycopg3/pq/_pq_ctypes.pyi
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3/psycopg3/pq/proto.py
psycopg3_c/psycopg3_c/pq/libpq.pxd
psycopg3_c/psycopg3_c/pq/pgconn.pyx
tests/pq/test_async.py
tests/pq/test_pgconn.py
tests/test_cursor.py
tests/test_cursor_async.py

index fa32973fbe76c430d2c301e046283f492e0487c6..0b307a083fac09f1904edf15045e99ebd2dce8ac 100644 (file)
@@ -85,6 +85,21 @@ The `!Cursor` class
 
         See :ref:`copy` for information about :sql:`COPY`.
 
+    .. automethod:: stream(query, params=None) -> Iterable[Sequence[Any]]
+
+        This command is similar to execute + iter; however it supports endless
+        data streams. The feature is not available in PostgreSQL, but some
+        implementations exist: Materialize `TAIL`__ and CockroachDB
+        `CHANGEFEED`__ for instance.
+
+        The feature, and the API supporting it, are still experimental.
+        Beware... ðŸ‘€
+
+        .. __: https://materialize.com/docs/sql/tail/#main
+        .. __: https://www.cockroachlabs.com/docs/stable/changefeed-for.html
+
+        The parameters are the same of `execute()`.
+
     .. attribute:: format
 
         The format of the data returned by the queries. It can be selected
index 3dff227ecfb18940982be2669be61531d4d218a5..abc5d2cd4d51040401fb8ef26131458c0aba1d5d 100644 (file)
@@ -7,12 +7,13 @@ psycopg3 cursor objects
 import sys
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
-from typing import Optional, Sequence, Type, TYPE_CHECKING
+from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING
 from contextlib import contextmanager
 
 from . import pq
 from . import adapt
 from . import errors as e
+from . import generators
 
 from .pq import ExecStatus, Format
 from .copy import Copy, AsyncCopy
@@ -39,8 +40,6 @@ if pq.__impl__ == "c":
     execute = _psycopg3.execute
 
 else:
-    from . import generators
-
     execute = generators.execute
 
 
@@ -245,6 +244,44 @@ class BaseCursor(Generic[ConnectionType]):
 
         self._execute_results(results)
 
+    def _stream_send_gen(
+        self, query: Query, params: Optional[Params] = None
+    ) -> PQGen[None]:
+        """Generator to send the query for `Cursor.stream()`."""
+        yield from self._start_query(query)
+        pgq = self._convert_query(query, params)
+        self._execute_send(pgq, no_pqexec=True)
+        self._conn.pgconn.set_single_row_mode()
+        self._last_query = query
+
+    def _stream_fetchone_gen(self) -> PQGen[Optional["PGresult"]]:
+        yield from generators.send(self._conn.pgconn)
+        res = yield from generators.fetch(self._conn.pgconn)
+        if res is None:
+            return None
+
+        elif res.status == ExecStatus.SINGLE_TUPLE:
+            self.pgresult = res  # will set it on the transformer too
+            # TODO: the transformer may do excessive work here: create a
+            # path that doesn't clear the loaders every time.
+            return res
+
+        elif res.status in (ExecStatus.TUPLES_OK, ExecStatus.COMMAND_OK):
+            # End of single row results
+            status = res.status
+            while res:
+                res = yield from generators.fetch(self._conn.pgconn)
+            if status != ExecStatus.TUPLES_OK:
+                raise e.ProgrammingError(
+                    "the operation in stream() didn't produce a result"
+                )
+            return None
+
+        else:
+            # Errors, unexpected values
+            self._raise_from_results([res])
+            return None  # TODO: shouldn't be needed
+
     def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
         """Generator to start the processing of a query.
 
@@ -323,7 +360,7 @@ class BaseCursor(Generic[ConnectionType]):
 
         for res in results:
             if res.status not in self._status_ok:
-                return self._raise_from_results(results)
+                self._raise_from_results(results)
 
         self._results = list(results)
         self.pgresult = results[0]
@@ -336,7 +373,7 @@ class BaseCursor(Generic[ConnectionType]):
 
         return
 
-    def _raise_from_results(self, results: Sequence["PGresult"]) -> None:
+    def _raise_from_results(self, results: Sequence["PGresult"]) -> NoReturn:
         statuses = {res.status for res in results}
         badstats = statuses.difference(self._status_ok)
         if results[-1].status == ExecStatus.FATAL_ERROR:
@@ -345,7 +382,7 @@ class BaseCursor(Generic[ConnectionType]):
             )
         elif statuses.intersection(self._status_copy):
             raise e.ProgrammingError(
-                "COPY cannot be used with execute(); use copy() insead"
+                "COPY cannot be used with this method; use copy() insead"
             )
         else:
             raise e.InternalError(
@@ -439,6 +476,19 @@ class Cursor(BaseCursor["Connection"]):
         with self._conn.lock:
             self._conn.wait(self._executemany_gen(query, params_seq))
 
+    def stream(
+        self, query: Query, params: Optional[Params] = None
+    ) -> Iterator[Sequence[Any]]:
+        """
+        Iterate row-by-row on a result from the database.
+        """
+        with self._conn.lock:
+            self._conn.wait(self._stream_send_gen(query, params))
+            while self._conn.wait(self._stream_fetchone_gen()):
+                rec = self._tx.load_row(0)
+                assert rec is not None
+                yield rec
+
     def fetchone(self) -> Optional[Sequence[Any]]:
         """
         Return the next record from the current recordset.
@@ -539,6 +589,16 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         async with self._conn.lock:
             await self._conn.wait(self._executemany_gen(query, params_seq))
 
+    async def stream(
+        self, query: Query, params: Optional[Params] = None
+    ) -> AsyncIterator[Sequence[Any]]:
+        async with self._conn.lock:
+            await self._conn.wait(self._stream_send_gen(query, params))
+            while await self._conn.wait(self._stream_fetchone_gen()):
+                rec = self._tx.load_row(0)
+                assert rec is not None
+                yield rec
+
     async def fetchone(self) -> Optional[Sequence[Any]]:
         self._check_result()
         rv = self._tx.load_row(self._pos)
index 899e1f40ec5a6d2c7554572479af1009e4b2ac2a..95b2994d45619fac6e974691d6053a0365440325 100644 (file)
@@ -71,7 +71,7 @@ def execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
     or error).
     """
     yield from send(pgconn)
-    rv = yield from _fetch(pgconn)
+    rv = yield from fetch_many(pgconn)
     return rv
 
 
@@ -83,7 +83,7 @@ def send(pgconn: PGconn) -> PQGen[None]:
     similar. Flush the query and then return the result using nonblocking
     functions.
 
-    After this generator has finished you may want to cycle using `_fetch()`
+    After this generator has finished you may want to cycle using `fetch()`
     to retrieve the results available.
     """
     while 1:
@@ -94,12 +94,12 @@ def send(pgconn: PGconn) -> PQGen[None]:
         ready = yield Wait.RW
         if ready & Ready.R:
             # This call may read notifies: they will be saved in the
-            # PGconn buffer and passed to Python later, in `_fetch()`.
+            # PGconn buffer and passed to Python later, in `fetch()`.
             pgconn.consume_input()
         continue
 
 
-def _fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
+def fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
     """
     Generator retrieving results from the database without blocking.
 
@@ -108,28 +108,13 @@ def _fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
 
     Return the list of results returned by the database (whether success
     or error).
-
-    Note that this generator doesn't yield the socket number, which must have
-    been already sent in the sending part of the cycle.
     """
     results: List[PGresult] = []
     while 1:
-        pgconn.consume_input()
-        if pgconn.is_busy():
-            yield Wait.R
-            continue
-
-        # Consume notifies
-        while 1:
-            n = pgconn.notifies()
-            if n is None:
-                break
-            if pgconn.notify_handler:
-                pgconn.notify_handler(n)
-
-        res = pgconn.get_result()
-        if res is None:
+        res = yield from fetch(pgconn)
+        if not res:
             break
+
         results.append(res)
         if res.status in _copy_statuses:
             # After entering copy mode the libpq will create a phony result
@@ -139,6 +124,32 @@ def _fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
     return results
 
 
+def fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
+    """
+    Generator retrieving a single result from the database without blocking.
+
+    The query must have already been sent to the server, so pgconn.flush() has
+    already returned 0.
+
+    Return a result from the database (whether success or error).
+    """
+    while 1:
+        pgconn.consume_input()
+        if not pgconn.is_busy():
+            break
+        yield Wait.R
+
+    # Consume notifies
+    while 1:
+        n = pgconn.notifies()
+        if not n:
+            break
+        if pgconn.notify_handler:
+            pgconn.notify_handler(n)
+
+    return pgconn.get_result()
+
+
 _copy_statuses = (
     ExecStatus.COPY_IN,
     ExecStatus.COPY_OUT,
@@ -176,7 +187,7 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
         return data
 
     # Retrieve the final result of copy
-    (result,) = yield from _fetch(pgconn)
+    (result,) = yield from fetch_many(pgconn)
     if result.status != ExecStatus.COMMAND_OK:
         encoding = py_codecs.get(
             pgconn.parameter_status(b"client_encoding") or "", "utf-8"
@@ -205,7 +216,7 @@ def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
             break
 
     # Retrieve the final result of copy
-    (result,) = yield from _fetch(pgconn)
+    (result,) = yield from fetch_many(pgconn)
     if result.status != ExecStatus.COMMAND_OK:
         encoding = py_codecs.get(
             pgconn.parameter_status(b"client_encoding") or "", "utf-8"
index ed66fd5284f9e1e3db4f578530b735c814c10c7c..1f9846b97ce16c5e4379515425b9384b63ea1389 100644 (file)
@@ -499,6 +499,12 @@ PQflush.argtypes = [PGconn_ptr]
 PQflush.restype = c_int
 
 
+# 33.5. Retrieving Query Results Row-by-Row
+PQsetSingleRowMode = pq.PQsetSingleRowMode
+PQsetSingleRowMode.argtypes = [PGconn_ptr]
+PQsetSingleRowMode.restype = c_int
+
+
 # 33.6. Canceling Queries in Progress
 
 PQgetCancel = pq.PQgetCancel
index 1bc67c9daf8a07c654467e2b1ea31583ee57eee0..39485f91b2f96028f51c7344c76f6ea9d249e48a 100644 (file)
@@ -180,6 +180,7 @@ def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ...
 def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ...
 def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
 def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ...
 def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
 def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
 def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ...
index 127d82439eb75525d2dcafb0671e7b6eb04661cf..75d90155232240f9ab0c18b37f1416f0893e73e3 100644 (file)
@@ -500,6 +500,10 @@ class PGconn:
             raise PQerror(f"flushing failed: {error_message(self)}")
         return rv
 
+    def set_single_row_mode(self) -> None:
+        if not impl.PQsetSingleRowMode(self.pgconn_ptr):
+            raise PQerror("setting single row mode failed")
+
     def get_cancel(self) -> "PGcancel":
         """
         Create an object with the information needed to cancel a command.
index d2df18618a94a3c37f45b190886279bd3f7e42ff..a626519592f7c429c231a3f9692aee80f8894b8c 100644 (file)
@@ -214,6 +214,9 @@ class PGconn(Protocol):
     def flush(self) -> int:
         ...
 
+    def set_single_row_mode(self) -> None:
+        ...
+
     def get_cancel(self) -> "PGcancel":
         ...
 
index e572a0e8675a4306cd308c99ee0e21436e5ea2f8..dd51fd6f6a989fdc72b6e762ff77e732e37b8f56 100644 (file)
@@ -230,6 +230,9 @@ cdef extern from "libpq-fe.h":
     int PQisnonblocking(const PGconn *conn)
     int PQflush(PGconn *conn)
 
+    # 33.5. Retrieving Query Results Row-by-Row
+    int PQsetSingleRowMode(PGconn *conn)
+
     # 33.6. Canceling Queries in Progress
     PGcancel *PQgetCancel(PGconn *conn)
     void PQfreeCancel(PGcancel *cancel)
index dcc015a6bb37a752bb1580b4ef2ffa40867afb09..a979fcdebe6ccdc0913af016c09f973e8d234e8b 100644 (file)
@@ -409,6 +409,10 @@ cdef class PGconn:
             raise PQerror(f"flushing failed: {error_message(self)}")
         return rv
 
+    def set_single_row_mode(self) -> None:
+        if not libpq.PQsetSingleRowMode(self.pgconn_ptr):
+            raise PQerror("setting single row mode failed")
+
     def get_cancel(self) -> PGcancel:
         cdef libpq.PGcancel *ptr = libpq.PQgetCancel(self.pgconn_ptr)
         if not ptr:
index 820f6e85d799d116c13677cfe6bcd3b4311c3a9a..0e648b373df5504cd93b436c7ffade996482bcbf 100644 (file)
@@ -82,6 +82,28 @@ def test_send_query_compact_test(pgconn):
         pgconn.send_query(b"select 1")
 
 
+def test_single_row_mode(pgconn):
+    pgconn.send_query(b"select generate_series(1,2)")
+    pgconn.set_single_row_mode()
+
+    results = execute_wait(pgconn)
+    assert len(results) == 3
+
+    res = results[0]
+    assert res.status == pq.ExecStatus.SINGLE_TUPLE
+    assert res.ntuples == 1
+    assert res.get_value(0, 0) == b"1"
+
+    res = results[1]
+    assert res.status == pq.ExecStatus.SINGLE_TUPLE
+    assert res.ntuples == 1
+    assert res.get_value(0, 0) == b"2"
+
+    res = results[2]
+    assert res.status == pq.ExecStatus.TUPLES_OK
+    assert res.ntuples == 0
+
+
 def test_send_query_params(pgconn):
     pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
     (res,) = execute_wait(pgconn)
index 7eabfa5e13680a267bcd85b187a045600ee8283b..e122cd0922feef35341dc220f45bf5ff9714977f 100644 (file)
@@ -356,6 +356,14 @@ def test_ssl_in_use(pgconn):
         pgconn.ssl_in_use
 
 
+def test_set_single_row_mode(pgconn):
+    with pytest.raises(pq.PQerror):
+        pgconn.set_single_row_mode()
+
+    pgconn.send_query(b"select 1")
+    pgconn.set_single_row_mode()
+
+
 def test_cancel(pgconn):
     cancel = pgconn.get_cancel()
     cancel.cancel()
index 5b9a3e57bbe09522d2922fad97b1606571f9ac53..b2d905b1bf5b39a1dd8a6e5283820bf73aa71b21 100644 (file)
@@ -1,10 +1,12 @@
 import gc
 import pickle
 import weakref
+import datetime as dt
 
 import pytest
 
 import psycopg3
+from psycopg3 import sql
 from psycopg3.oids import builtins
 from psycopg3.adapt import Format
 
@@ -296,6 +298,46 @@ def test_query_params_executemany(conn):
     # assert cur.params == [b"x"]
 
 
+def test_stream(conn):
+    cur = conn.cursor()
+    recs = []
+    for rec in cur.stream(
+        "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+        [2],
+    ):
+        recs.append(rec)
+
+    assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+def test_stream_sql(conn):
+    cur = conn.cursor()
+    recs = list(
+        cur.stream(
+            sql.SQL(
+                "select i, '2021-01-01'::date + i from generate_series(1, {}) as i"
+            ).format(2)
+        )
+    )
+
+    assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+@pytest.mark.parametrize(
+    "query",
+    [
+        "create table test_stream_badq ()",
+        "copy (select 1) to stdout",
+        "wat?",
+    ],
+)
+def test_stream_badquery(conn, query):
+    cur = conn.cursor()
+    with pytest.raises(psycopg3.ProgrammingError):
+        for rec in cur.stream(query):
+            pass
+
+
 class TestColumn:
     def test_description_attribs(self, conn):
         curs = conn.cursor()
index e73b9ea08683bf5557c251a0dfd9f5e913a28854..6285aa5b57157ebc9375d0e0741f465a1e747ed7 100644 (file)
@@ -1,8 +1,10 @@
 import gc
 import pytest
 import weakref
+import datetime as dt
 
 import psycopg3
+from psycopg3 import sql
 from psycopg3.adapt import Format
 
 pytestmark = pytest.mark.asyncio
@@ -301,6 +303,46 @@ async def test_query_params_executemany(aconn):
     # assert cur.params == [b"x"]
 
 
+async def test_stream(aconn):
+    cur = await aconn.cursor()
+    recs = []
+    async for rec in cur.stream(
+        "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+        [2],
+    ):
+        recs.append(rec)
+
+    assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+async def test_stream_sql(aconn):
+    cur = await aconn.cursor()
+    recs = []
+    async for rec in cur.stream(
+        sql.SQL(
+            "select i, '2021-01-01'::date + i from generate_series(1, {}) as i"
+        ).format(2)
+    ):
+        recs.append(rec)
+
+    assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+@pytest.mark.parametrize(
+    "query",
+    [
+        "create table test_stream_badq ()",
+        "copy (select 1) to stdout",
+        "wat?",
+    ],
+)
+async def test_stream_badquery(aconn, query):
+    cur = await aconn.cursor()
+    with pytest.raises(psycopg3.ProgrammingError):
+        async for rec in cur.stream(query):
+            pass
+
+
 async def test_str(aconn):
     cur = await aconn.cursor()
     assert "[IDLE]" in str(cur)