feat: add a size parameter to Cursor.stream() (794/head)
author    Denis Laxalde <denis.laxalde@dalibo.com>
          Wed, 17 Apr 2024 09:38:43 +0000 (11:38 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Thu, 13 Jun 2024 14:05:36 +0000 (16:05 +0200)
When the value is greater than 1, this triggers retrieval of results from
the server in chunks of that size, leveraging the "chunked rows mode"
introduced in libpq 17.

A new has_stream_chunked() capability is added.
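
For context, a minimal usage sketch of the feature this commit introduces
(the DSN, the chunk size of 100, and the query are placeholders; it assumes
the module-level `psycopg.capabilities` object that the new method is added
to):

    import psycopg

    # Chunked retrieval requires libpq 17; fall back to row-by-row otherwise.
    size = 100 if psycopg.capabilities.has_stream_chunked() else 1

    with psycopg.connect("dbname=test") as conn:  # placeholder DSN
        with conn.cursor() as cur:
            # Rows are still yielded one at a time; with size > 1 they are
            # fetched from the server in chunks of `size`.
            for row in cur.stream("select generate_series(1, 1000)", size=size):
                print(row)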

docs/api/cursors.rst
docs/api/objects.rst
docs/news.rst
psycopg/psycopg/_capabilities.py
psycopg/psycopg/_cursor_base.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/rows.py
tests/test_capabilities.py
tests/test_cursor_common.py
tests/test_cursor_common_async.py

diff --git a/docs/api/cursors.rst b/docs/api/cursors.rst
index 9abb3f989c3c19aa447f7b2cea58902f53ae2248..1d0ae445d87a8d3301b940d0f507a70e2fa544ae 100644
@@ -180,7 +180,14 @@ The `!Cursor` class
         .. __: https://materialize.com/docs/sql/subscribe/
         .. __: https://www.cockroachlabs.com/docs/stable/changefeed-for.html
 
-        The parameters are the same of `execute()`.
+        The parameters are the same as in `execute()`, except for `size`,
+        which can be used to retrieve results in chunks instead of row-by-row.
+
+        .. note::
+
+            The `size` parameter is only available from libpq 17; the
+            `~Capabilities.has_stream_chunked` capability can be used to
+            check whether it is supported.
 
         .. warning::
 
diff --git a/docs/api/objects.rst b/docs/api/objects.rst
index 33b080835ed7e9b93accc67cd11d14e0cfac83db..a94f08b5ba1abb73fb98d971fa89cf99f209d781 100644
@@ -157,6 +157,7 @@ Libpq capabilities information
             The `!cancel_safe()` method is implemented anyway, but it will use
             the legacy :pq:`PQcancel` implementation.
 
+    .. automethod:: has_stream_chunked
     .. automethod:: has_pgbouncer_prepared
 
         .. seealso:: :ref:`pgbouncer`
diff --git a/docs/news.rst b/docs/news.rst
index 1797feef2ca18c79648aef7cb180065e62b51a13..87649946dd93b2ef1b13756981b2ceb79c9ac0d9 100644
@@ -43,6 +43,8 @@ Psycopg 3.2 (unreleased)
   termination (:ticket:`#754`).
 - Add support for libpq function to retrieve results in chunks introduced in
   libpq v17 (:ticket:`#793`).
+- Add a `size` parameter to `~Cursor.stream()` to enable retrieving results
+  in chunks instead of row-by-row (:ticket:`#794`).
 - Add support for libpq function to change role passwords introduced in
   libpq v17 (:ticket:`#818`).
 
diff --git a/psycopg/psycopg/_capabilities.py b/psycopg/psycopg/_capabilities.py
index 491b8c79dc12f27d3a9214a004b396cbf7dc41cf..2af756330458072041f377c85a153d423c048faa 100644
@@ -54,6 +54,16 @@ class Capabilities:
         """
         return self._has_feature("Connection.cancel_safe()", 170000, check=check)
 
+    def has_stream_chunked(self, check: bool = False) -> bool:
+        """Check if `Cursor.stream()` can handle a `size` parameter value
+        greater than 1 to retrieve results by chunks.
+
+        The feature requires libpq 17.0 and greater.
+        """
+        return self._has_feature(
+            "Cursor.stream() with 'size' parameter greater than 1", 170000, check=check
+        )
+
     def has_pgbouncer_prepared(self, check: bool = False) -> bool:
         """Check if prepared statements in PgBouncer are supported.
 
diff --git a/psycopg/psycopg/_cursor_base.py b/psycopg/psycopg/_cursor_base.py
index 9448ec505d3a2cb1ddb5bb34b680c0a1c0d8ca44..833666bbaf98fb937474889a74c9ab2e4d68f9db 100644
@@ -15,6 +15,7 @@ from . import adapt
 from . import errors as e
 from .abc import ConnectionType, Query, Params, PQGen
 from .rows import Row, RowMaker
+from ._capabilities import capabilities
 from ._column import Column
 from .pq.misc import connection_summary
 from ._queries import PostgresQuery, PostgresClientQuery
@@ -36,6 +37,7 @@ COPY_IN = pq.ExecStatus.COPY_IN
 COPY_BOTH = pq.ExecStatus.COPY_BOTH
 FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
 SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+TUPLES_CHUNK = pq.ExecStatus.TUPLES_CHUNK
 PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
 
 ACTIVE = pq.TransactionStatus.ACTIVE
@@ -116,7 +118,10 @@ class BaseCursor(Generic[ConnectionType, Row]):
         # the query said we got tuples (mostly to handle the super useful
         # query "SELECT ;"
         if res and (
-            res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
+            res.nfields
+            or res.status == TUPLES_OK
+            or res.status == SINGLE_TUPLE
+            or res.status == TUPLES_CHUNK
         ):
             return [Column(self, i) for i in range(res.nfields)]
         else:
@@ -314,12 +319,19 @@ class BaseCursor(Generic[ConnectionType, Row]):
         params: Params | None = None,
         *,
         binary: bool | None = None,
+        size: int,
     ) -> PQGen[None]:
         """Generator to send the query for `Cursor.stream()`."""
         yield from self._start_query(query)
         pgq = self._convert_query(query, params)
         self._execute_send(pgq, binary=binary, force_extended=True)
-        self._pgconn.set_single_row_mode()
+        if size < 1:
+            raise ValueError("size must be >= 1")
+        elif size == 1:
+            self._pgconn.set_single_row_mode()
+        else:
+            capabilities.has_stream_chunked(check=True)
+            self._pgconn.set_chunked_rows_mode(size)
         self._last_query = query
         yield from send(self._pgconn)
 
@@ -329,7 +341,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             return None
 
         status = res.status
-        if status == SINGLE_TUPLE:
+        if status == SINGLE_TUPLE or status == TUPLES_CHUNK:
             self.pgresult = res
             self._tx.set_pgresult(res, set_loaders=first)
             if first:
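
For context on the dispatch above: single-row mode (libpq's
PQsetSingleRowMode) yields one SINGLE_TUPLE result per row, while the
chunked rows mode added in libpq 17 (PQsetChunkedRowsMode) yields
TUPLES_CHUNK results holding up to `size` rows each, which is why the fetch
path now accepts both statuses. A condensed sketch of the mode selection
(the helper name `set_stream_mode` is hypothetical; either call requires a
query to have been dispatched already):

    from psycopg import pq

    def set_stream_mode(pgconn: pq.PGconn, size: int) -> None:
        # Mirror of the dispatch in _stream_send_gen (sketch).
        if size == 1:
            pgconn.set_single_row_mode()        # PQsetSingleRowMode: SINGLE_TUPLE results
        else:
            pgconn.set_chunked_rows_mode(size)  # PQsetChunkedRowsMode (libpq 17):
                                                # TUPLES_CHUNK results, up to `size` rows
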
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py
index 6d1ddf01928832237702d533297bbaf9c5cc4512..0415ff319e61ce235579e5dce0fe636bed83a5bb 100644
@@ -128,22 +128,35 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
             raise ex.with_traceback(None)
 
     def stream(
-        self, query: Query, params: Params | None = None, *, binary: bool | None = None
+        self,
+        query: Query,
+        params: Params | None = None,
+        *,
+        binary: bool | None = None,
+        size: int = 1,
     ) -> Iterator[Row]:
         """
         Iterate row-by-row on a result from the database.
+
+        :param size: if greater than 1, results will be retrieved from the
+            server in chunks of this size (but still yielded row-by-row);
+            this is only available from libpq version 17.
         """
         if self._pgconn.pipeline_status:
             raise e.ProgrammingError("stream() cannot be used in pipeline mode")
 
         with self._conn.lock:
             try:
-                self._conn.wait(self._stream_send_gen(query, params, binary=binary))
+                self._conn.wait(
+                    self._stream_send_gen(query, params, binary=binary, size=size)
+                )
                 first = True
                 while self._conn.wait(self._stream_fetchone_gen(first)):
-                    # We know that, if we got a result, it has a single row.
-                    rec: Row = self._tx.load_row(0, self._make_row)  # type: ignore
-                    yield rec
+                    for pos in range(size):
+                        rec = self._tx.load_row(pos, self._make_row)
+                        if rec is None:
+                            break
+                        yield rec
                     first = False
             except e._NO_TRACEBACK as ex:
                 raise ex.with_traceback(None)
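
One detail worth noting in the new inner loop: the server's final chunk may
contain fewer than `size` rows, so the loop relies on `load_row` returning
`None` past the last row of a chunk rather than assuming every result holds
exactly `size` rows. A toy model of that draining pattern (the chunk list is
a stand-in for successive PGresults):

    size = 2
    chunks = [[(1,), (2,)], [(3,), (4,)], [(5,)]]   # final chunk is short
    for chunk in chunks:                            # one PGresult per iteration
        for pos in range(size):
            # stand-in for self._tx.load_row(pos, self._make_row)
            rec = chunk[pos] if pos < len(chunk) else None
            if rec is None:
                break                               # short final chunk: stop early
            print(rec)
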
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py
index b708d5d6c6960284831bc3124a2c9639fbd61aa6..7fddd628051de9187d7d4eb685a2433bdd77464c 100644
@@ -132,10 +132,19 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
             raise ex.with_traceback(None)
 
     async def stream(
-        self, query: Query, params: Params | None = None, *, binary: bool | None = None
+        self,
+        query: Query,
+        params: Params | None = None,
+        *,
+        binary: bool | None = None,
+        size: int = 1,
     ) -> AsyncIterator[Row]:
         """
         Iterate row-by-row on a result from the database.
+
+        :param size: if greater than 1, results will be retrieved from the
+            server in chunks of this size (but still yielded row-by-row);
+            this is only available from libpq version 17.
         """
         if self._pgconn.pipeline_status:
             raise e.ProgrammingError("stream() cannot be used in pipeline mode")
@@ -143,13 +152,15 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         async with self._conn.lock:
             try:
                 await self._conn.wait(
-                    self._stream_send_gen(query, params, binary=binary)
+                    self._stream_send_gen(query, params, binary=binary, size=size)
                 )
                 first = True
                 while await self._conn.wait(self._stream_fetchone_gen(first)):
-                    # We know that, if we got a result, it has a single row.
-                    rec: Row = self._tx.load_row(0, self._make_row)  # type: ignore
-                    yield rec
+                    for pos in range(size):
+                        rec = self._tx.load_row(pos, self._make_row)
+                        if rec is None:
+                            break
+                        yield rec
                     first = False
 
             except e._NO_TRACEBACK as ex:
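
And the async counterpart in use, for symmetry (a sketch; the DSN, query,
and chunk size are placeholders):

    import asyncio
    import psycopg

    async def main() -> None:
        async with await psycopg.AsyncConnection.connect("dbname=test") as conn:
            async with conn.cursor() as cur:
                # Same semantics as the sync version: fetched in chunks of
                # 50, yielded row-by-row.
                async for row in cur.stream(
                    "select generate_series(1, 500)", size=50
                ):
                    print(row)

    asyncio.run(main())
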
diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py
index cb67f7f0781bb56bedbbce86025369926d3cf69c..db6f5c86aba9c826fc8234785ec720b527096631 100644
@@ -25,6 +25,7 @@ if TYPE_CHECKING:
 COMMAND_OK = pq.ExecStatus.COMMAND_OK
 TUPLES_OK = pq.ExecStatus.TUPLES_OK
 SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+TUPLES_CHUNK = pq.ExecStatus.TUPLES_CHUNK
 
 T = TypeVar("T", covariant=True)
 
@@ -265,6 +266,7 @@ def _get_nfields(res: PGresult) -> int | None:
     if (
         res.status == TUPLES_OK
         or res.status == SINGLE_TUPLE
+        or res.status == TUPLES_CHUNK
         # "describe" in named cursors
         or (res.status == COMMAND_OK and nfields)
     ):
diff --git a/tests/test_capabilities.py b/tests/test_capabilities.py
index 2d27e62ec92fc97227590962036557f1a5e2d42f..6f2c8fce92501cf8e7d3996ac7ad768e0c4f834d 100644
@@ -16,6 +16,7 @@ caps = [
     ("has_pipeline", "Connection.pipeline()", 14),
     ("has_set_trace_flags", "PGconn.set_trace_flags()", 14),
     ("has_cancel_safe", "Connection.cancel_safe()", 17),
+    ("has_stream_chunked", "Cursor.stream() with 'size' parameter greater than 1", 17),
     ("has_pgbouncer_prepared", "PgBouncer prepared statements compatibility", 17),
 ]
 
diff --git a/tests/test_cursor_common.py b/tests/test_cursor_common.py
index 2fdf1f8eda6ecd583f801613c298ae76b3cecf21..7b3d9d5235599d67bb394d87470c419467812493 100644
@@ -693,6 +693,35 @@ def test_stream_no_row(conn):
     assert recs == []
 
 
+def test_stream_chunked_invalid_size(conn):
+    cur = conn.cursor()
+    with pytest.raises(ValueError, match="size must be >= 1"):
+        next(cur.stream("select 1", size=0))
+
+
+@pytest.mark.libpq("< 17")
+def test_stream_chunked_not_supported(conn):
+    cur = conn.cursor()
+    with pytest.raises(psycopg.NotSupportedError):
+        next(cur.stream("select generate_series(1, 4)", size=2))
+
+
+@pytest.mark.libpq(">= 17")
+def test_stream_chunked(conn):
+    cur = conn.cursor()
+    recs = list(cur.stream("select generate_series(1, 5) as a", size=2))
+    assert recs == [(1,), (2,), (3,), (4,), (5,)]
+
+
+@pytest.mark.libpq(">= 17")
+def test_stream_chunked_row_factory(conn):
+    cur = conn.cursor(row_factory=rows.scalar_row)
+    it = cur.stream("select generate_series(1, 5) as a", size=2)
+    for i in range(1, 6):
+        assert next(it) == i
+        assert [c.name for c in cur.description] == ["a"]
+
+
 @pytest.mark.crdb_skip("no col query")
 def test_stream_no_col(conn):
     cur = conn.cursor()
diff --git a/tests/test_cursor_common_async.py b/tests/test_cursor_common_async.py
index 0e05a92497e76785673f29586cc7763452d86e4f..6421357e76a28a3d58701d3b85f6ecd18967cc8c 100644
@@ -695,6 +695,35 @@ async def test_stream_no_row(aconn):
     assert recs == []
 
 
+async def test_stream_chunked_invalid_size(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(ValueError, match=r"size must be >= 1"):
+        await anext(cur.stream("select 1", size=0))
+
+
+@pytest.mark.libpq("< 17")
+async def test_stream_chunked_not_supported(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(psycopg.NotSupportedError):
+        await anext(cur.stream("select generate_series(1, 4)", size=2))
+
+
+@pytest.mark.libpq(">= 17")
+async def test_stream_chunked(aconn):
+    cur = aconn.cursor()
+    recs = await alist(cur.stream("select generate_series(1, 5) as a", size=2))
+    assert recs == [(1,), (2,), (3,), (4,), (5,)]
+
+
+@pytest.mark.libpq(">= 17")
+async def test_stream_chunked_row_factory(aconn):
+    cur = aconn.cursor(row_factory=rows.scalar_row)
+    it = cur.stream("select generate_series(1, 5) as a", size=2)
+    for i in range(1, 6):
+        assert await anext(it) == i
+        assert [c.name for c in cur.description] == ["a"]
+
+
 @pytest.mark.crdb_skip("no col query")
 async def test_stream_no_col(aconn):
     cur = aconn.cursor()