Add support for pipeline mode in execute()/fetch*()
author     Denis Laxalde <denis.laxalde@dalibo.com>
           Mon, 11 Oct 2021 15:16:39 +0000 (17:16 +0200)
committer  Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Sat, 2 Apr 2022 23:17:57 +0000 (01:17 +0200)
When pipeline mode is activated on the connection, the active pipeline
handles a queue of commands to send and a queue of results to process.

The command queue simply contains Callable[[], None] items, built from
partial applications of pgconn.send_*() methods and the like.

Each entry in the queue of results to process is either None, when the
respective command returns no tuples, or a tuple with the respective
cursor and the query information needed to maintain the automatic
prepared statements cache.
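
For illustration, a minimal sketch of the pairing between the two queues
(assuming an active BasePipeline `pipeline`, its `pgconn`, and a `cursor`
in scope; the literal query is hypothetical):

    from functools import partial

    # A command is a zero-argument callable wrapping a pgconn.send_*() call.
    pipeline.command_queue.append(partial(pgconn.send_query, b"select 1"))

    # Its paired result entry is either None (internal command, results only
    # checked for errors) or (cursor, prepinfo), where prepinfo is None or a
    # (key, prep, name) triple used to update the prepared statements cache.
    pipeline.result_queue.append((cursor, None))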

Everywhere the execute() generator would run in non-pipeline mode, we now
enqueue items in the pipeline queues instead. We then run
pipeline_communicate(), through the _communicate_gen() method of
BasePipeline, in BaseCursor._execute(many)_gen().
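
Schematically, the pipeline branch of this path reduces to the following
(a simplified sketch, not the literal code; `query` and `params` stand in
for the converted query):

    # Enqueue the send and the pending-result entry for this cursor...
    self._conn._pipeline.command_queue.append(
        partial(self._pgconn.send_query_params, query, params)
    )
    self._conn._pipeline.result_queue.append((self, None))
    # ...then drive pipeline_communicate() to send queued commands and
    # process whatever results are already available.
    yield from self._conn._pipeline._communicate_gen()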

Since pipeline_communicate() may not fetch all results, we need a
dedicated, forced fetch step upon calls to cursor.fetch*(); this is done
by Cursor._fetch_pipeline(), invoked from the fetch*() methods. It calls
PQsendFlushRequest() in order to avoid blocking on PQgetResult().
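
From the user's perspective the forced fetch is transparent; e.g. with
the sync API (assuming `conn` is an open connection):

    with conn.pipeline():
        cur = conn.execute("select 1")
        # _fetch_pipeline() issues PQsendFlushRequest() and fetches the
        # pending results, so fetchone() does not block on PQgetResult().
        assert cur.fetchone() == (1,)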

On exit from pipeline mode, we unconditionally emit a PQpipelineSync()
call in order to restore the connection to a usable state in case of
error, and we force a results fetch after sending any pending commands
(e.g. commands not emitted through an execute() call).
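
The exit path thus reduces to (simplified from Connection.pipeline()
below):

    pipeline.sync()  # enqueue a PQpipelineSync()
    try:
        # send any pending commands, including the sync
        self.wait(pipeline._communicate_gen())
    finally:
        # fetch all remaining results; no flush request is needed
        # since a sync was just emitted
        self.wait(pipeline._fetch_gen(flush=False))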

The pipeline-demo.py test script is updated to include examples using
the high-level API. These only work with the 'python' implementation of
the libpq bindings because we monkeypatch the pgconn attribute of the
connection.

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/server_cursor.py
tests/scripts/pipeline-demo.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 2fa4ecf0ac52feaf51edb900f6719a2cfd631c19..c3b1c44c14700bcd620e82a72f5beb4a365bb66d 100644 (file)
@@ -20,7 +20,7 @@ from . import errors as e
 from . import waiting
 from . import postgres
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
-from .abc import AdaptContext, ConnectionType, Params, Query, RV
+from .abc import AdaptContext, Command, ConnectionType, Params, Query, RV
 from .abc import PQGen, PQGenConn
 from .sql import Composable, SQL
 from ._tpc import Xid
@@ -28,17 +28,18 @@ from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
-from ._compat import TypeAlias
+from ._compat import Deque, TypeAlias
 from ._cmodule import _psycopg
 from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
 from .generators import notifies
 from ._encodings import pgconn_encoding
-from ._preparing import PrepareManager
+from ._preparing import Key, Prepare, PrepareManager
 from .transaction import Transaction
 from .server_cursor import ServerCursor
 
 if TYPE_CHECKING:
     from .pq.abc import PGconn, PGresult
+    from .cursor import BaseCursor
     from psycopg_pool.base import BasePool
 
 logger = logging.getLogger("psycopg")
@@ -53,12 +54,18 @@ CursorRow = TypeVar("CursorRow")
 if _psycopg:
     connect = _psycopg.connect
     execute = _psycopg.execute
+    fetch_many = _psycopg.fetch_many
+    pipeline_communicate = _psycopg.pipeline_communicate
+    send = _psycopg.send
 
 else:
     from . import generators
 
     connect = generators.connect
     execute = generators.execute
+    fetch_many = generators.fetch_many
+    pipeline_communicate = generators.pipeline_communicate
+    send = generators.send
 
 
 class Notify(NamedTuple):
@@ -79,21 +86,103 @@ Notify.__module__ = "psycopg"
 NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
 NotifyHandler: TypeAlias = Callable[[Notify], None]
 
+PipelinePendingResult = Union[
+    None,
+    Tuple[
+        "BaseCursor[Any, Any]",
+        Optional[Tuple[Key, Prepare, bytes]],
+    ],
+]
+
 
 class BasePipeline:
     def __init__(self, pgconn: "PGconn") -> None:
         self.pgconn = pgconn
+        self.command_queue = Deque[Command]()
+        self.result_queue = Deque[PipelinePendingResult]()
 
     @property
     def status(self) -> pq.PipelineStatus:
         return pq.PipelineStatus(self.pgconn.pipeline_status)
 
+    def sync(self) -> None:
+        """Enqueue a PQpipelineSync() command."""
+        self.command_queue.append(self.pgconn.pipeline_sync)
+        self.result_queue.append(None)
+
     def _enter(self) -> None:
         self.pgconn.enter_pipeline_mode()
 
     def _exit(self) -> None:
         self.pgconn.exit_pipeline_mode()
 
+    def _communicate_gen(self) -> PQGen[None]:
+        """Communicate with the pipeline to send commands and possibly fetch
+        results, which are then processed.
+        """
+        fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
+        to_process = [(self.result_queue.popleft(), results) for results in fetched]
+        for queued, results in to_process:
+            self._process_results(queued, results)
+
+    def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
+        """Fetch available results from the connection and process them with
+        the pipeline's queued items.
+
+        If 'flush' is True, a PQsendFlushRequest() is issued in order to make
+        sure results can be fetched. Otherwise, the caller may emit a
+        PQpipelineSync() call to ensure the output buffer gets flushed before
+        fetching.
+        """
+        if not self.result_queue:
+            return
+
+        if flush:
+            self.pgconn.send_flush_request()
+            yield from send(self.pgconn)
+
+        to_process = []
+        while self.result_queue:
+            results = yield from fetch_many(self.pgconn)
+            if not results:
+                # No more results to fetch, but there may still be pending
+                # commands.
+                break
+            queued = self.result_queue.popleft()
+            to_process.append((queued, results))
+
+        for queued, results in to_process:
+            self._process_results(queued, results)
+
+    def _process_results(
+        self, queued: PipelinePendingResult, results: List["PGresult"]
+    ) -> None:
+        """Process a result set fetched from the current pipeline.
+
+        This matches 'results' with its respective element in the pipeline
+        queue. For commands (None values in the pipeline queue), results are
+        checked directly. For prepared statement creation requests, the cache
+        is updated. Otherwise, results are attached to their respective cursor.
+        """
+        if queued is None:
+            (result,) = results
+            if result.status == ExecStatus.FATAL_ERROR:
+                raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+            elif result.status == ExecStatus.PIPELINE_ABORTED:
+                raise e.OperationalError("pipeline aborted")
+        else:
+            cursor, prepinfo = queued
+            cursor._check_results(results)
+            if not cursor._results:
+                cursor._results = results
+                cursor._set_current_result(0)
+            else:
+                cursor._results.extend(results)
+            if prepinfo:
+                key, prep, name = prepinfo
+                # Update the prepare state of the query.
+                cursor._conn._prepared.validate(key, prep, name, results)
+
 
 class Pipeline(BasePipeline):
     """Handler for connection in pipeline mode."""
@@ -444,7 +533,7 @@ class BaseConnection(Generic[Row]):
 
     def _exec_command(
         self, command: Query, result_format: Format = Format.TEXT
-    ) -> PQGen["PGresult"]:
+    ) -> PQGen[Optional["PGresult"]]:
         """
         Generator to send a command and receive the result to the backend.
 
@@ -458,6 +547,20 @@ class BaseConnection(Generic[Row]):
         elif isinstance(command, Composable):
             command = command.as_bytes(self)
 
+        if self._pipeline:
+            if result_format == Format.TEXT:
+                cmd = partial(self.pgconn.send_query, command)
+            else:
+                cmd = partial(
+                    self.pgconn.send_query_params,
+                    command,
+                    None,
+                    result_format=result_format,
+                )
+            self._pipeline.command_queue.append(cmd)
+            self._pipeline.result_queue.append(None)
+            return None
+
         if result_format == Format.TEXT:
             self.pgconn.send_query(command)
         else:
@@ -895,7 +998,19 @@ class Connection(BaseConnection[Row]):
         pipeline = self._pipeline = Pipeline(self.pgconn)
         try:
             with pipeline:
-                yield
+                try:
+                    yield
+                finally:
+                    with self.lock:
+                        pipeline.sync()
+                        try:
+                            # Send any pending commands (e.g. COMMIT or Sync);
+                            # while processing results, we might get errors...
+                            self.wait(pipeline._communicate_gen())
+                        finally:
+                            # then fetch all remaining results but without forcing
+                            # flush since we emitted a sync just before.
+                            self.wait(pipeline._fetch_gen(flush=False))
         finally:
             assert pipeline.status == pq.PipelineStatus.OFF, pipeline.status
             self._pipeline = None
index daa4f157cdac29002421d2e00eefb5a8941c52d3..5208c8b245c1457f0db4a38ab5b955e15ae0f272 100644 (file)
@@ -319,7 +319,19 @@ class AsyncConnection(BaseConnection[Row]):
         pipeline = self._pipeline = AsyncPipeline(self.pgconn)
         try:
             async with pipeline:
-                yield
+                try:
+                    yield
+                finally:
+                    async with self.lock:
+                        pipeline.sync()
+                        try:
+                            # Send any pending commands (e.g. COMMIT or Sync);
+                            # while processing results, we might get errors...
+                            await self.wait(pipeline._communicate_gen())
+                        finally:
+                            # then fetch all remaining results but without forcing
+                            # flush since we emitted a sync just before.
+                            await self.wait(pipeline._fetch_gen(flush=False))
         finally:
             assert pipeline.status == PipelineStatus.OFF, pipeline.status
             self._pipeline = None
index 7709454d1c4a7bf067d7d115ca2891fcfaa5f3e0..40c57241b9a4bfd39100adfc196aaa50be6ce649 100644 (file)
@@ -4,6 +4,7 @@ psycopg cursor objects
 
 # Copyright (C) 2020 The Psycopg Team
 
+from functools import partial
 from types import TracebackType
 from typing import Any, Generic, Iterable, Iterator, List
 from typing import Optional, NoReturn, Sequence, Type, TypeVar, TYPE_CHECKING
@@ -194,9 +195,14 @@ class BaseCursor(Generic[ConnectionType, Row]):
         results = yield from self._maybe_prepare_gen(
             pgq, prepare=prepare, binary=binary
         )
-        self._check_results(results)
-        self._results = results
-        self._set_current_result(0)
+        if self._conn._pipeline:
+            yield from self._conn._pipeline._communicate_gen()
+        else:
+            assert results is not None
+            self._check_results(results)
+            self._results = results
+            self._set_current_result(0)
+
         self._last_query = query
 
         for cmd in self._conn._prepared.get_maintenance_commands():
@@ -218,20 +224,28 @@ class BaseCursor(Generic[ConnectionType, Row]):
                 pgq.dump(params)
 
             results = yield from self._maybe_prepare_gen(pgq, prepare=True)
-            self._check_results(results)
-            if returning and results[0].status == ExecStatus.TUPLES_OK:
-                self._results.extend(results)
 
-            for res in results:
-                nrows += res.command_tuples or 0
+            if self._conn._pipeline:
+                yield from self._conn._pipeline._communicate_gen()
+            else:
+                assert results is not None
+                self._check_results(results)
+                if returning and results[0].status == ExecStatus.TUPLES_OK:
+                    self._results.extend(results)
 
-        if self._results:
-            self._set_current_result(0)
+                for res in results:
+                    nrows += res.command_tuples or 0
+
+        if not self._conn._pipeline:
+            if self._results:
+                self._set_current_result(0)
+
+            # Override rowcount for the first result. Calls to nextset() will
+            # change it to the value of that result only, but we hope nobody
+            # will notice.
+            # You haven't read this comment.
+            self._rowcount = nrows
 
-        # Override rowcount for the first result. Calls to nextset() will change
-        # it to the value of that result only, but we hope nobody will notice.
-        # You haven't read this comment.
-        self._rowcount = nrows
         self._last_query = query
 
         for cmd in self._conn._prepared.get_maintenance_commands():
@@ -243,7 +257,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> PQGen[List["PGresult"]]:
+    ) -> PQGen[Optional[List["PGresult"]]]:
         # Check if the query is prepared or needs preparing
         prep, name = self._conn._prepared.get(pgq, prepare)
         if prep is Prepare.NO:
@@ -253,19 +267,28 @@ class BaseCursor(Generic[ConnectionType, Row]):
             # If the query is not already prepared, prepare it.
             if prep is Prepare.SHOULD:
                 self._send_prepare(name, pgq)
-                (result,) = yield from execute(self._pgconn)
-                if result.status == ExecStatus.FATAL_ERROR:
-                    raise e.error_from_result(result, encoding=self._encoding)
+                if not self._conn._pipeline:
+                    (result,) = yield from execute(self._pgconn)
+                    if result.status == ExecStatus.FATAL_ERROR:
+                        raise e.error_from_result(result, encoding=self._encoding)
             # Then execute it.
             self._send_query_prepared(name, pgq, binary=binary)
 
-        # run the query
-        results = yield from execute(self._pgconn)
-
         # Update the prepare state of the query.
         # If an operation requires to flush our prepared statements cache,
         # it will be added to the maintenance commands to execute later.
         key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
+
+        if self._conn._pipeline:
+            queued = None
+            if key is not None:
+                queued = (key, prep, name)
+            self._conn._pipeline.result_queue.append((self, queued))
+            return None
+
+        # run the query
+        results = yield from execute(self._pgconn)
+
         if key is not None:
             self._conn._prepared.validate(key, prep, name, results)
 
@@ -363,17 +386,34 @@ class BaseCursor(Generic[ConnectionType, Row]):
 
         self._query = query
         if query.params or no_pqexec or fmt == Format.BINARY:
-            self._pgconn.send_query_params(
-                query.query,
-                query.params,
-                param_formats=query.formats,
-                param_types=query.types,
-                result_format=fmt,
-            )
+            if self._conn._pipeline:
+                self._conn._pipeline.command_queue.append(
+                    partial(
+                        self._pgconn.send_query_params,
+                        query.query,
+                        query.params,
+                        param_formats=query.formats,
+                        param_types=query.types,
+                        result_format=fmt,
+                    )
+                )
+            else:
+                self._pgconn.send_query_params(
+                    query.query,
+                    query.params,
+                    param_formats=query.formats,
+                    param_types=query.types,
+                    result_format=fmt,
+                )
         else:
             # if we don't have to, let's use exec_ as it can run more than
             # one query in one go
-            self._pgconn.send_query(query.query)
+            if self._conn._pipeline:
+                self._conn._pipeline.command_queue.append(
+                    partial(self._pgconn.send_query, query.query)
+                )
+            else:
+                self._pgconn.send_query(query.query)
 
     def _convert_query(
         self, query: Query, params: Optional[Params] = None
@@ -442,7 +482,18 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._rowcount = nrows if nrows is not None else -1
 
     def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
-        self._pgconn.send_prepare(name, query.query, param_types=query.types)
+        if self._conn._pipeline:
+            self._conn._pipeline.command_queue.append(
+                partial(
+                    self._pgconn.send_prepare,
+                    name,
+                    query.query,
+                    param_types=query.types,
+                )
+            )
+            self._conn._pipeline.result_queue.append(None)
+        else:
+            self._pgconn.send_prepare(name, query.query, param_types=query.types)
 
     def _send_query_prepared(
         self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
@@ -452,9 +503,20 @@ class BaseCursor(Generic[ConnectionType, Row]):
         else:
             fmt = Format.BINARY if binary else Format.TEXT
 
-        self._pgconn.send_query_prepared(
-            name, pgq.params, param_formats=pgq.formats, result_format=fmt
-        )
+        if self._conn._pipeline:
+            self._conn._pipeline.command_queue.append(
+                partial(
+                    self._pgconn.send_query_prepared,
+                    name,
+                    pgq.params,
+                    param_formats=pgq.formats,
+                    result_format=fmt,
+                )
+            )
+        else:
+            self._pgconn.send_query_prepared(
+                name, pgq.params, param_formats=pgq.formats, result_format=fmt
+            )
 
     def _check_result_for_fetch(self) -> None:
         if self.closed:
@@ -610,6 +672,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
 
         :rtype: Optional[Row], with Row defined by `row_factory`
         """
+        self._fetch_pipeline()
         self._check_result_for_fetch()
         record = self._tx.load_row(self._pos, self._make_row)
         if record is not None:
@@ -624,6 +687,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
 
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
+        self._fetch_pipeline()
         self._check_result_for_fetch()
         assert self.pgresult
 
@@ -643,6 +707,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
 
         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
+        self._fetch_pipeline()
         self._check_result_for_fetch()
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
@@ -650,6 +715,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         return records
 
     def __iter__(self) -> Iterator[Row]:
+        self._fetch_pipeline()
         self._check_result_for_fetch()
 
         def load(pos: int) -> Optional[Row]:
@@ -673,6 +739,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         Raise `!IndexError` in case a scroll operation would leave the result
         set. In this case the position will not change.
         """
+        self._fetch_pipeline()
         self._scroll(value, mode)
 
     @contextmanager
@@ -687,3 +754,9 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
 
         with Copy(self) as copy:
             yield copy
+
+    def _fetch_pipeline(self) -> None:
+        if not self.pgresult and self._conn._pipeline:
+            with self._conn.lock:
+                self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
+            assert self.pgresult
index b1211e5aacdff827b1a9c72e34a035dde4b0aba8..75131430de20073ea78f54c7251911ac597ffb8c 100644 (file)
@@ -113,6 +113,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                 first = False
 
     async def fetchone(self) -> Optional[Row]:
+        await self._fetch_pipeline()
         self._check_result_for_fetch()
         rv = self._tx.load_row(self._pos, self._make_row)
         if rv is not None:
@@ -120,6 +121,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return rv
 
     async def fetchmany(self, size: int = 0) -> List[Row]:
+        await self._fetch_pipeline()
         self._check_result_for_fetch()
         assert self.pgresult
 
@@ -134,6 +136,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return records
 
     async def fetchall(self) -> List[Row]:
+        await self._fetch_pipeline()
         self._check_result_for_fetch()
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
@@ -141,6 +144,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return records
 
     async def __aiter__(self) -> AsyncIterator[Row]:
+        await self._fetch_pipeline()
         self._check_result_for_fetch()
 
         def load(pos: int) -> Optional[Row]:
@@ -166,3 +170,9 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
 
         async with AsyncCopy(self) as copy:
             yield copy
+
+    async def _fetch_pipeline(self) -> None:
+        if not self.pgresult and self._conn._pipeline:
+            async with self._conn.lock:
+                await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
+            assert self.pgresult
index 8439ecea182f17c32afc5b7430865269f51af485..9561297f1351562cbe60b952048da8ffa251b77b 100644 (file)
@@ -108,6 +108,8 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
                 "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
             ).format(sql.Literal(self.name))
             res = yield from cur._conn._exec_command(query)
+            # res is None in pipeline mode, which is not supported here.
+            assert res is not None
             if res.ntuples == 0:
                 return
 
@@ -129,6 +131,8 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
             sql.Identifier(self.name),
         )
         res = yield from cur._conn._exec_command(query, result_format=self._format)
+        # res is None in pipeline mode, which is not supported here.
+        assert res is not None
 
         cur.pgresult = res
         cur._tx.set_pgresult(res, set_loaders=False)
index b8068476bfa5e0785419bc8bb5c01f09ea227611..c0636587d80d0fcb5c0e86d9e69b881d84a15c07 100644 (file)
@@ -203,6 +203,52 @@ async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> N
                         raise e.error_from_result(r)
 
 
+def pipeline_demo(rows_to_send: int, logger: logging.Logger) -> None:
+    """Pipeline demo using sync API."""
+    conn = Connection.connect()
+    conn.autocommit = True
+    conn.pgconn = LoggingPGconn(conn.pgconn, logger)  # type: ignore[assignment]
+    with conn.pipeline():
+        with conn.transaction():
+            conn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+            conn.execute(
+                "CREATE UNLOGGED TABLE pq_pipeline_demo("
+                " id serial primary key,"
+                " itemno integer,"
+                " int8filler int8"
+                ")"
+            )
+            for r in range(rows_to_send, 0, -1):
+                conn.execute(
+                    "INSERT INTO pq_pipeline_demo(itemno, int8filler)"
+                    " VALUES (%s, %s)",
+                    (r, 1 << 62),
+                )
+
+
+async def pipeline_demo_async(rows_to_send: int, logger: logging.Logger) -> None:
+    """Pipeline demo using async API."""
+    aconn = await AsyncConnection.connect()
+    await aconn.set_autocommit(True)
+    aconn.pgconn = LoggingPGconn(aconn.pgconn, logger)  # type: ignore[assignment]
+    async with aconn.pipeline():
+        async with aconn.transaction():
+            await aconn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+            await aconn.execute(
+                "CREATE UNLOGGED TABLE pq_pipeline_demo("
+                " id serial primary key,"
+                " itemno integer,"
+                " int8filler int8"
+                ")"
+            )
+            for r in range(rows_to_send, 0, -1):
+                await aconn.execute(
+                    "INSERT INTO pq_pipeline_demo(itemno, int8filler)"
+                    " VALUES (%s, %s)",
+                    (r, 1 << 62),
+                )
+
+
 def main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -213,6 +259,9 @@ def main() -> None:
         type=int,
         help="number of rows to insert",
     )
+    parser.add_argument(
+        "--pq", action="store_true", help="use low-level psycopg.pq API"
+    )
     parser.add_argument(
         "--async", dest="async_", action="store_true", help="use async API"
     )
@@ -228,10 +277,20 @@ def main() -> None:
     else:
         logger.addHandler(logging.StreamHandler())
         pipeline_logger.addHandler(logging.StreamHandler())
-    if args.async_:
-        asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger))
+    if args.pq:
+        if args.async_:
+            asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger))
+        else:
+            pipeline_demo_pq(args.nrows, pipeline_logger)
     else:
-        pipeline_demo_pq(args.nrows, pipeline_logger)
+        if pq.__impl__ != "python":
+            parser.error(
+                "only supported for Python implementation (set PSYCOPG_IMPL=python)"
+            )
+        if args.async_:
+            asyncio.run(pipeline_demo_async(args.nrows, pipeline_logger))
+        else:
+            pipeline_demo(args.nrows, pipeline_logger)
 
 
 if __name__ == "__main__":
index 31662dd76c259895467b26592e621db6c3562208..99f4046634f28a1980b2a1f0f6bead8549b2cee7 100644 (file)
@@ -1,8 +1,15 @@
+import concurrent.futures
+
 import pytest
 
 import psycopg
 from psycopg import pq
-from psycopg.errors import ProgrammingError
+from psycopg.errors import (
+    OperationalError,
+    ProgrammingError,
+    UndefinedColumn,
+    UndefinedTable,
+)
 
 pytestmark = pytest.mark.libpq(">= 14")
 
@@ -29,3 +36,220 @@ def test_server_cursor(conn):
     with conn.cursor(name="pipeline") as cur, conn.pipeline():
         with pytest.raises(psycopg.NotSupportedError):
             cur.execute("select 1")
+
+
+def test_cannot_insert_multiple_commands(conn):
+    with pytest.raises(psycopg.errors.SyntaxError) as cm:
+        with conn.pipeline():
+            conn.execute("select 1; select 2")
+    assert cm.value.sqlstate == "42601"
+
+
+def test_pipeline_processed_at_exit(conn):
+    with conn.cursor() as cur:
+        with conn.pipeline():
+            cur.execute("select 1")
+
+            # PQsendQuery[BEGIN], PQsendQuery
+            assert len(conn._pipeline.result_queue) == 2
+
+        assert cur.fetchone() == (1,)
+
+
+def test_pipeline_errors_processed_at_exit(conn):
+    conn.autocommit = True
+    with pytest.raises((OperationalError, UndefinedTable)):
+        with conn.pipeline():
+            conn.execute("select * from nosuchtable")
+            conn.execute("create table voila ()")
+    cur = conn.execute(
+        "select count(*) from pg_tables where tablename = %s", ("voila",)
+    )
+    (count,) = cur.fetchone()
+    assert count == 0
+
+
+def test_pipeline(conn):
+    with conn.pipeline():
+        c1 = conn.cursor()
+        c2 = conn.cursor()
+        c1.execute("select 1")
+        c2.execute("select 2")
+
+        # PQsendQuery[BEGIN], PQsendQuery(2)
+        assert len(conn._pipeline.result_queue) == 3
+
+        (r1,) = c1.fetchone()
+        assert r1 == 1
+
+    (r2,) = c2.fetchone()
+    assert r2 == 2
+
+
+def test_autocommit(conn):
+    conn.autocommit = True
+    with conn.pipeline(), conn.cursor() as c:
+        c.execute("select 1")
+
+        (r,) = c.fetchone()
+        assert r == 1
+
+
+def test_pipeline_aborted(conn):
+    conn.autocommit = True
+    with conn.pipeline():
+        c1 = conn.execute("select 1")
+        with pytest.raises(UndefinedTable):
+            conn.execute("select * from doesnotexist").fetchone()
+        with pytest.raises(OperationalError, match="pipeline aborted"):
+            conn.execute("select 'aborted'").fetchone()
+        # Sync restores the connection to a usable state.
+        conn._pipeline.sync()
+        c2 = conn.execute("select 2")
+
+    (r,) = c1.fetchone()
+    assert r == 1
+
+    (r,) = c2.fetchone()
+    assert r == 2
+
+
+def test_pipeline_commit_aborted(conn):
+    with pytest.raises((UndefinedColumn, OperationalError)):
+        with conn.pipeline():
+            conn.execute("select error")
+            conn.execute("create table voila ()")
+            conn.commit()
+
+
+def test_executemany(conn):
+    conn.autocommit = True
+    conn.execute("drop table if exists execmanypipeline")
+    conn.execute(
+        "create unlogged table execmanypipeline ("
+        " id serial primary key, num integer)"
+    )
+    with conn.pipeline(), conn.cursor() as cur:
+        cur.executemany(
+            "insert into execmanypipeline(num) values (%s) returning id",
+            [(10,), (20,)],
+        )
+        assert cur.fetchone() == (1,)
+        assert cur.nextset()
+        assert cur.fetchone() == (2,)
+        assert cur.nextset() is None
+
+
+def test_prepared(conn):
+    conn.autocommit = True
+    with conn.pipeline():
+        c1 = conn.execute("select %s::int", [10], prepare=True)
+        c2 = conn.execute("select count(*) from pg_prepared_statements")
+
+        (r,) = c1.fetchone()
+        assert r == 10
+
+        (r,) = c2.fetchone()
+        assert r == 1
+
+
+def test_auto_prepare(conn):
+    conn.autocommit = True
+    conn.prepared_threshold = 5
+    with conn.pipeline():
+        cursors = [
+            conn.execute("select count(*) from pg_prepared_statements")
+            for i in range(10)
+        ]
+
+        assert len(conn._prepared._names) == 1
+
+    res = [c.fetchone()[0] for c in cursors]
+    assert res == [0] * 5 + [1] * 5
+
+
+def test_transaction(conn):
+    with conn.pipeline():
+        with conn.transaction():
+            cur = conn.execute("select 'tx'")
+
+        (r,) = cur.fetchone()
+        assert r == "tx"
+
+        with conn.transaction():
+            cur = conn.execute("select 'rb'")
+            raise psycopg.Rollback()
+
+        (r,) = cur.fetchone()
+        assert r == "rb"
+
+
+def test_transaction_nested(conn):
+    with conn.pipeline():
+        with conn.transaction():
+            outer = conn.execute("select 'outer'")
+            with pytest.raises(ZeroDivisionError):
+                with conn.transaction():
+                    inner = conn.execute("select 'inner'")
+                    1 / 0
+
+        (r,) = outer.fetchone()
+        assert r == "outer"
+        (r,) = inner.fetchone()
+        assert r == "inner"
+
+
+def test_outer_transaction(conn):
+    with conn.transaction():
+        with conn.pipeline():
+            conn.execute("drop table if exists outertx")
+            conn.execute("create table outertx as (select 1)")
+            cur = conn.execute("select * from outertx")
+    (r,) = cur.fetchone()
+    assert r == 1
+    cur = conn.execute("select count(*) from pg_tables where tablename = 'outertx'")
+    assert cur.fetchone()[0] == 1
+
+
+def test_outer_transaction_error(conn):
+    with conn.transaction():
+        with pytest.raises((UndefinedColumn, OperationalError)):
+            with conn.pipeline():
+                conn.execute("select error")
+                conn.execute("create table voila ()")
+
+
+def test_concurrency(conn):
+    with conn.transaction():
+        conn.execute("drop table if exists pipeline_concurrency")
+        conn.execute(
+            "create unlogged table pipeline_concurrency ("
+            " id serial primary key,"
+            " value integer"
+            ")"
+        )
+        conn.execute("drop table if exists accessed")
+        conn.execute("create unlogged table accessed as (select now() as value)")
+
+    def update(value):
+        cur = conn.execute(
+            "insert into pipeline_concurrency(value) values (%s) returning id",
+            (value,),
+        )
+        conn.execute("update accessed set value = now()")
+        return cur
+
+    conn.autocommit = True
+
+    (before,) = conn.execute("select value from accessed").fetchone()
+
+    values = range(1, 10)
+    with conn.pipeline():
+        with concurrent.futures.ThreadPoolExecutor() as e:
+            cursors = e.map(update, values, timeout=len(values))
+            assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
+
+    (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone()
+    assert s == sum(values)
+    (after,) = conn.execute("select value from accessed").fetchone()
+    assert after > before
index 3aa6b2e59c7a88809f758dd8fc83c1608160140f..5c3dedc9490bf36ea8421273f428f629880c83b0 100644 (file)
@@ -1,8 +1,15 @@
+import asyncio
+
 import pytest
 
 import psycopg
 from psycopg import pq
-from psycopg.errors import ProgrammingError
+from psycopg.errors import (
+    OperationalError,
+    ProgrammingError,
+    UndefinedColumn,
+    UndefinedTable,
+)
 
 pytestmark = [
     pytest.mark.libpq(">= 14"),
@@ -32,3 +39,226 @@ async def test_server_cursor(aconn):
     async with aconn.cursor(name="pipeline") as cur, aconn.pipeline():
         with pytest.raises(psycopg.NotSupportedError):
             await cur.execute("select 1")
+
+
+async def test_cannot_insert_multiple_commands(aconn):
+    with pytest.raises(psycopg.errors.SyntaxError) as cm:
+        async with aconn.pipeline():
+            await aconn.execute("select 1; select 2")
+    assert cm.value.sqlstate == "42601"
+
+
+async def test_pipeline_processed_at_exit(aconn):
+    async with aconn.cursor() as cur:
+        async with aconn.pipeline():
+            await cur.execute("select 1")
+
+            # PQsendQuery[BEGIN], PQsendQuery
+            assert len(aconn._pipeline.result_queue) == 2
+
+        assert await cur.fetchone() == (1,)
+
+
+async def test_pipeline_errors_processed_at_exit(aconn):
+    await aconn.set_autocommit(True)
+    with pytest.raises((OperationalError, UndefinedTable)):
+        async with aconn.pipeline():
+            await aconn.execute("select * from nosuchtable")
+            await aconn.execute("create table voila ()")
+    cur = await aconn.execute(
+        "select count(*) from pg_tables where tablename = %s", ("voila",)
+    )
+    (count,) = await cur.fetchone()
+    assert count == 0
+
+
+async def test_pipeline(aconn):
+    async with aconn.pipeline():
+        c1 = aconn.cursor()
+        c2 = aconn.cursor()
+        await c1.execute("select 1")
+        await c2.execute("select 2")
+
+        # PQsendQuery[BEGIN], PQsendQuery(2)
+        assert len(aconn._pipeline.result_queue) == 3
+
+        (r1,) = await c1.fetchone()
+        assert r1 == 1
+
+    (r2,) = await c2.fetchone()
+    assert r2 == 2
+
+
+async def test_autocommit(aconn):
+    await aconn.set_autocommit(True)
+    async with aconn.pipeline(), aconn.cursor() as c:
+        await c.execute("select 1")
+
+        (r,) = await c.fetchone()
+        assert r == 1
+
+
+async def test_pipeline_aborted(aconn):
+    await aconn.set_autocommit(True)
+    async with aconn.pipeline():
+        c1 = await aconn.execute("select 1")
+        with pytest.raises(UndefinedTable):
+            await (await aconn.execute("select * from doesnotexist")).fetchone()
+        with pytest.raises(OperationalError, match="pipeline aborted"):
+            await (await aconn.execute("select 'aborted'")).fetchone()
+        # Sync restores the connection to a usable state.
+        aconn._pipeline.sync()
+        c2 = await aconn.execute("select 2")
+
+    (r,) = await c1.fetchone()
+    assert r == 1
+
+    (r,) = await c2.fetchone()
+    assert r == 2
+
+
+async def test_pipeline_commit_aborted(aconn):
+    with pytest.raises((UndefinedColumn, OperationalError)):
+        async with aconn.pipeline():
+            await aconn.execute("select error")
+            await aconn.execute("create table voila ()")
+            await aconn.commit()
+
+
+async def test_executemany(aconn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("drop table if exists execmanypipeline")
+    await aconn.execute(
+        "create unlogged table execmanypipeline ("
+        " id serial primary key, num integer)"
+    )
+    async with aconn.pipeline(), aconn.cursor() as cur:
+        await cur.executemany(
+            "insert into execmanypipeline(num) values (%s) returning id",
+            [(10,), (20,)],
+        )
+        assert (await cur.fetchone()) == (1,)
+        assert cur.nextset()
+        assert (await cur.fetchone()) == (2,)
+        assert cur.nextset() is None
+
+
+async def test_prepared(aconn):
+    await aconn.set_autocommit(True)
+    async with aconn.pipeline():
+        c1 = await aconn.execute("select %s::int", [10], prepare=True)
+        c2 = await aconn.execute("select count(*) from pg_prepared_statements")
+
+        (r,) = await c1.fetchone()
+        assert r == 10
+
+        (r,) = await c2.fetchone()
+        assert r == 1
+
+
+async def test_auto_prepare(aconn):
+    aconn.prepared_threshold = 5
+    async with aconn.pipeline():
+        cursors = [
+            await aconn.execute("select count(*) from pg_prepared_statements")
+            for i in range(10)
+        ]
+
+        assert len(aconn._prepared._names) == 1
+
+    res = [(await c.fetchone())[0] for c in cursors]
+    assert res == [0] * 5 + [1] * 5
+
+
+async def test_transaction(aconn):
+    async with aconn.pipeline():
+        async with aconn.transaction():
+            cur = await aconn.execute("select 'tx'")
+
+        (r,) = await cur.fetchone()
+        assert r == "tx"
+
+        async with aconn.transaction():
+            cur = await aconn.execute("select 'rb'")
+            raise psycopg.Rollback()
+
+        (r,) = await cur.fetchone()
+        assert r == "rb"
+
+
+async def test_transaction_nested(aconn):
+    async with aconn.pipeline():
+        async with aconn.transaction():
+            outer = await aconn.execute("select 'outer'")
+            with pytest.raises(ZeroDivisionError):
+                async with aconn.transaction():
+                    inner = await aconn.execute("select 'inner'")
+                    1 / 0
+
+        (r,) = await outer.fetchone()
+        assert r == "outer"
+        (r,) = await inner.fetchone()
+        assert r == "inner"
+
+
+async def test_outer_transaction(aconn):
+    async with aconn.transaction():
+        async with aconn.pipeline():
+            await aconn.execute("drop table if exists outertx")
+            await aconn.execute("create table outertx as (select 1)")
+            cur = await aconn.execute("select * from outertx")
+    (r,) = await cur.fetchone()
+    assert r == 1
+    cur = await aconn.execute(
+        "select count(*) from pg_tables where tablename = 'outertx'"
+    )
+    assert (await cur.fetchone())[0] == 1
+
+
+async def test_outer_transaction_error(aconn):
+    async with aconn.transaction():
+        with pytest.raises((UndefinedColumn, OperationalError)):
+            async with aconn.pipeline():
+                await aconn.execute("select error")
+                await aconn.execute("create table voila ()")
+
+
+async def test_concurrency(aconn):
+    async with aconn.transaction():
+        await aconn.execute("drop table if exists pipeline_concurrency")
+        await aconn.execute(
+            "create unlogged table pipeline_concurrency ("
+            " id serial primary key,"
+            " value integer"
+            ")"
+        )
+        await aconn.execute("drop table if exists accessed")
+        await aconn.execute("create unlogged table accessed as (select now() as value)")
+
+    async def update(value):
+        cur = await aconn.execute(
+            "insert into pipeline_concurrency(value) values (%s) returning id",
+            (value,),
+        )
+        await aconn.execute("update accessed set value = now()")
+        return cur
+
+    await aconn.set_autocommit(True)
+
+    (before,) = await (await aconn.execute("select value from accessed")).fetchone()
+
+    values = range(1, 10)
+    async with aconn.pipeline():
+        cursors = await asyncio.wait_for(
+            asyncio.gather(*[update(value) for value in values]),
+            timeout=len(values),
+        )
+
+    assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
+
+    (s,) = await (
+        await aconn.execute("select sum(value) from pipeline_concurrency")
+    ).fetchone()
+    assert s == sum(values)
+    (after,) = await (await aconn.execute("select value from accessed")).fetchone()
+    assert after > before