From: Denis Laxalde
Date: Mon, 11 Oct 2021 15:16:39 +0000 (+0200)
Subject: Add support for pipeline mode in execute()/fetch*()
X-Git-Tag: 3.1~146^2~13
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d92c9119067ecc11a5c7f70f58aeceb7cf186c3f;p=thirdparty%2Fpsycopg.git

Add support for pipeline mode in execute()/fetch*()

When activated on the connection, the active pipeline handles a queue of
commands to send and a queue of results to process. The command queue simply
contains Callable[[], None] items, built from partial applications of the
pgconn.send_*() methods and the like. The queue of results to process contains
either None, when the respective command returns no tuples, or a tuple with
the respective cursor and the query information needed to maintain automatic
prepared statements.

Everywhere we run the execute() generator in non-pipeline mode, we now enqueue
items in the pipeline queues. Then we run pipeline_communicate(), through the
_communicate_gen() method of BasePipeline, in BaseCursor._execute(many)_gen().

Since pipeline_communicate() may not fetch all results, we need a dedicated
(forced) fetch step upon calls to cursor.fetch*(); this is done by
Cursor._fetch_pipeline(), called in the fetch*() methods. It calls
PQsendFlushRequest() in order to avoid blocking on PQgetResult().

At exit of pipeline mode, we unconditionally emit a PQpipelineSync() call in
order to restore the connection to a usable state in case of error, and we
force a results fetch after sending any pending commands (e.g. commands not
emitted through an execute() call).

The pipeline-demo.py test script is updated to include examples using the
high-level API. These only work with the 'python' implementation of the libpq
bindings, because we monkey-patch the pgconn attribute of the connection.
---

diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
index 2fa4ecf0a..c3b1c44c1 100644
--- a/psycopg/psycopg/connection.py
+++ b/psycopg/psycopg/connection.py
@@ -20,7 +20,7 @@ from . import errors as e
 from . import waiting
 from . import postgres
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
-from .abc import AdaptContext, ConnectionType, Params, Query, RV
+from .abc import AdaptContext, Command, ConnectionType, Params, Query, RV
 from .abc import PQGen, PQGenConn
 from .sql import Composable, SQL
 from ._tpc import Xid
@@ -28,17 +28,18 @@ from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
-from ._compat import TypeAlias
+from ._compat import Deque, TypeAlias
 from ._cmodule import _psycopg
 from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
 from .generators import notifies
 from ._encodings import pgconn_encoding
-from ._preparing import PrepareManager
+from ._preparing import Key, Prepare, PrepareManager
 from .transaction import Transaction
 from .server_cursor import ServerCursor

 if TYPE_CHECKING:
     from .pq.abc import PGconn, PGresult
+    from .cursor import BaseCursor
     from psycopg_pool.base import BasePool

 logger = logging.getLogger("psycopg")
@@ -53,12 +54,18 @@ CursorRow = TypeVar("CursorRow")
 if _psycopg:
     connect = _psycopg.connect
     execute = _psycopg.execute
+    fetch_many = _psycopg.fetch_many
+    pipeline_communicate = _psycopg.pipeline_communicate
+    send = _psycopg.send

 else:
    from . import generators

    connect = generators.connect
    execute = generators.execute
+    fetch_many = generators.fetch_many
+    pipeline_communicate = generators.pipeline_communicate
+    send = generators.send


 class Notify(NamedTuple):
@@ -79,21 +86,103 @@ Notify.__module__ = "psycopg"
 NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
 NotifyHandler: TypeAlias = Callable[[Notify], None]

+PipelinePendingResult = Union[
+    None,
+    Tuple[
+        "BaseCursor[Any, Any]",
+        Optional[Tuple[Key, Prepare, bytes]],
+    ],
+]
+

 class BasePipeline:
     def __init__(self, pgconn: "PGconn") -> None:
         self.pgconn = pgconn
+        self.command_queue = Deque[Command]()
+        self.result_queue = Deque[PipelinePendingResult]()

     @property
     def status(self) -> pq.PipelineStatus:
         return pq.PipelineStatus(self.pgconn.pipeline_status)

+    def sync(self) -> None:
+        """Enqueue a PQpipelineSync() command."""
+        self.command_queue.append(self.pgconn.pipeline_sync)
+        self.result_queue.append(None)
+
     def _enter(self) -> None:
         self.pgconn.enter_pipeline_mode()

     def _exit(self) -> None:
         self.pgconn.exit_pipeline_mode()

+    def _communicate_gen(self) -> PQGen[None]:
+        """Communicate with pipeline to send commands and possibly fetch
+        results, which are then processed.
+        """
+        fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
+        to_process = [(self.result_queue.popleft(), results) for results in fetched]
+        for queued, results in to_process:
+            self._process_results(queued, results)
+
+    def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
+        """Fetch available results from the connection and process them with
+        pipeline queued items.
+
+        If 'flush' is True, a PQsendFlushRequest() is issued in order to make
+        sure results can be fetched. Otherwise, the caller may emit a
+        PQpipelineSync() call to ensure the output buffer gets flushed before
+        fetching.
+        """
+        if not self.result_queue:
+            return
+
+        if flush:
+            self.pgconn.send_flush_request()
+            yield from send(self.pgconn)
+
+        to_process = []
+        while self.result_queue:
+            results = yield from fetch_many(self.pgconn)
+            if not results:
+                # No more results to fetch, but there may still be pending
+                # commands.
+                break
+            queued = self.result_queue.popleft()
+            to_process.append((queued, results))
+
+        for queued, results in to_process:
+            self._process_results(queued, results)
+
+    def _process_results(
+        self, queued: PipelinePendingResult, results: List["PGresult"]
+    ) -> None:
+        """Process a result set fetched from the current pipeline.
+
+        This matches 'results' with its respective element in the pipeline
+        queue. For commands (None value in the pipeline queue), results are
+        checked directly. For prepared-statement creation requests, the cache
+        is updated. Otherwise, results are attached to their respective
+        cursor.
+        """
+        if queued is None:
+            (result,) = results
+            if result.status == ExecStatus.FATAL_ERROR:
+                raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+            elif result.status == ExecStatus.PIPELINE_ABORTED:
+                raise e.OperationalError("pipeline aborted")
+        else:
+            cursor, prepinfo = queued
+            cursor._check_results(results)
+            if not cursor._results:
+                cursor._results = results
+                cursor._set_current_result(0)
+            else:
+                cursor._results.extend(results)
+            if prepinfo:
+                key, prep, name = prepinfo
+                # Update the prepare state of the query.
+                cursor._conn._prepared.validate(key, prep, name, results)
+

 class Pipeline(BasePipeline):
     """Handler for connection in pipeline mode."""
@@ -444,7 +533,7 @@ class BaseConnection(Generic[Row]):

     def _exec_command(
         self, command: Query, result_format: Format = Format.TEXT
-    ) -> PQGen["PGresult"]:
+    ) -> PQGen[Optional["PGresult"]]:
         """
         Generator to send a command and receive the result to the backend.
@@ -458,6 +547,20 @@ class BaseConnection(Generic[Row]):
         elif isinstance(command, Composable):
             command = command.as_bytes(self)

+        if self._pipeline:
+            if result_format == Format.TEXT:
+                cmd = partial(self.pgconn.send_query, command)
+            else:
+                cmd = partial(
+                    self.pgconn.send_query_params,
+                    command,
+                    None,
+                    result_format=result_format,
+                )
+            self._pipeline.command_queue.append(cmd)
+            self._pipeline.result_queue.append(None)
+            return None
+
         if result_format == Format.TEXT:
             self.pgconn.send_query(command)
         else:
@@ -895,7 +998,19 @@ class Connection(BaseConnection[Row]):
         pipeline = self._pipeline = Pipeline(self.pgconn)
         try:
             with pipeline:
-                yield
+                try:
+                    yield
+                finally:
+                    with self.lock:
+                        pipeline.sync()
+                        try:
+                            # Send any pending commands (e.g. COMMIT or Sync);
+                            # while processing results, we might get errors...
+                            self.wait(pipeline._communicate_gen())
+                        finally:
+                            # then fetch all remaining results, but without forcing a
+                            # flush since we emitted a sync just before.
+                            self.wait(pipeline._fetch_gen(flush=False))
         finally:
             assert pipeline.status == pq.PipelineStatus.OFF, pipeline.status
             self._pipeline = None
diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py
index daa4f157c..5208c8b24 100644
--- a/psycopg/psycopg/connection_async.py
+++ b/psycopg/psycopg/connection_async.py
@@ -319,7 +319,19 @@ class AsyncConnection(BaseConnection[Row]):
         pipeline = self._pipeline = AsyncPipeline(self.pgconn)
         try:
             async with pipeline:
-                yield
+                try:
+                    yield
+                finally:
+                    async with self.lock:
+                        pipeline.sync()
+                        try:
+                            # Send any pending commands (e.g. COMMIT or Sync);
+                            # while processing results, we might get errors...
+                            await self.wait(pipeline._communicate_gen())
+                        finally:
+                            # then fetch all remaining results, but without forcing a
+                            # flush since we emitted a sync just before.
+                            await self.wait(pipeline._fetch_gen(flush=False))
         finally:
             assert pipeline.status == PipelineStatus.OFF, pipeline.status
             self._pipeline = None
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py
index 7709454d1..40c57241b 100644
--- a/psycopg/psycopg/cursor.py
+++ b/psycopg/psycopg/cursor.py
@@ -4,6 +4,7 @@ psycopg cursor objects

 # Copyright (C) 2020 The Psycopg Team

+from functools import partial
 from types import TracebackType
 from typing import Any, Generic, Iterable, Iterator, List
 from typing import Optional, NoReturn, Sequence, Type, TypeVar, TYPE_CHECKING
@@ -194,9 +195,14 @@ class BaseCursor(Generic[ConnectionType, Row]):
         results = yield from self._maybe_prepare_gen(
             pgq, prepare=prepare, binary=binary
         )
-        self._check_results(results)
-        self._results = results
-        self._set_current_result(0)
+        if self._conn._pipeline:
+            yield from self._conn._pipeline._communicate_gen()
+        else:
+            assert results is not None
+            self._check_results(results)
+            self._results = results
+            self._set_current_result(0)
+
         self._last_query = query

         for cmd in self._conn._prepared.get_maintenance_commands():
@@ -218,20 +224,28 @@ class BaseCursor(Generic[ConnectionType, Row]):
             pgq.dump(params)

             results = yield from self._maybe_prepare_gen(pgq, prepare=True)
-            self._check_results(results)
-            if returning and results[0].status == ExecStatus.TUPLES_OK:
-                self._results.extend(results)

-            for res in results:
-                nrows += res.command_tuples or 0
+            if self._conn._pipeline:
+                yield from self._conn._pipeline._communicate_gen()
+            else:
+                assert results is not None
+                self._check_results(results)
+                if returning and results[0].status == ExecStatus.TUPLES_OK:
+                    self._results.extend(results)

-        if self._results:
-            self._set_current_result(0)
+                for res in results:
+                    nrows += res.command_tuples or 0
+
+        if not self._conn._pipeline:
+            if self._results:
+                self._set_current_result(0)
+
+            # Override rowcount for the first result. Calls to nextset() will
+            # change it to the value of that result only, but we hope nobody
+            # will notice.
+            # You haven't read this comment.
+            self._rowcount = nrows

-        # Override rowcount for the first result. Calls to nextset() will change
-        # it to the value of that result only, but we hope nobody will notice.
-        # You haven't read this comment.
-        self._rowcount = nrows
         self._last_query = query

         for cmd in self._conn._prepared.get_maintenance_commands():
@@ -243,7 +257,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> PQGen[List["PGresult"]]:
+    ) -> PQGen[Optional[List["PGresult"]]]:
         # Check if the query is prepared or needs preparing
         prep, name = self._conn._prepared.get(pgq, prepare)
         if prep is Prepare.NO:
@@ -253,19 +267,28 @@ class BaseCursor(Generic[ConnectionType, Row]):
         # If the query is not already prepared, prepare it.
         if prep is Prepare.SHOULD:
             self._send_prepare(name, pgq)
-            (result,) = yield from execute(self._pgconn)
-            if result.status == ExecStatus.FATAL_ERROR:
-                raise e.error_from_result(result, encoding=self._encoding)
+            if not self._conn._pipeline:
+                (result,) = yield from execute(self._pgconn)
+                if result.status == ExecStatus.FATAL_ERROR:
+                    raise e.error_from_result(result, encoding=self._encoding)

         # Then execute it.
         self._send_query_prepared(name, pgq, binary=binary)

-        # run the query
-        results = yield from execute(self._pgconn)
-
         # Update the prepare state of the query.
         # If an operation requires to flush our prepared statements cache,
         # it will be added to the maintenance commands to execute later.
         key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
+
+        if self._conn._pipeline:
+            queued = None
+            if key is not None:
+                queued = (key, prep, name)
+            self._conn._pipeline.result_queue.append((self, queued))
+            return None
+
+        # run the query
+        results = yield from execute(self._pgconn)
+
         if key is not None:
             self._conn._prepared.validate(key, prep, name, results)
@@ -363,17 +386,34 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._query = query

         if query.params or no_pqexec or fmt == Format.BINARY:
-            self._pgconn.send_query_params(
-                query.query,
-                query.params,
-                param_formats=query.formats,
-                param_types=query.types,
-                result_format=fmt,
-            )
+            if self._conn._pipeline:
+                self._conn._pipeline.command_queue.append(
+                    partial(
+                        self._pgconn.send_query_params,
+                        query.query,
+                        query.params,
+                        param_formats=query.formats,
+                        param_types=query.types,
+                        result_format=fmt,
+                    )
+                )
+            else:
+                self._pgconn.send_query_params(
+                    query.query,
+                    query.params,
+                    param_formats=query.formats,
+                    param_types=query.types,
+                    result_format=fmt,
+                )
         else:
             # if we don't have to, let's use exec_ as it can run more than
             # one query in one go
-            self._pgconn.send_query(query.query)
+            if self._conn._pipeline:
+                self._conn._pipeline.command_queue.append(
+                    partial(self._pgconn.send_query, query.query)
+                )
+            else:
+                self._pgconn.send_query(query.query)

     def _convert_query(
         self, query: Query, params: Optional[Params] = None
@@ -442,7 +482,18 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._rowcount = nrows if nrows is not None else -1

     def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
-        self._pgconn.send_prepare(name, query.query, param_types=query.types)
+        if self._conn._pipeline:
+            self._conn._pipeline.command_queue.append(
+                partial(
+                    self._pgconn.send_prepare,
+                    name,
+                    query.query,
+                    param_types=query.types,
+                )
+            )
+            self._conn._pipeline.result_queue.append(None)
+        else:
+            self._pgconn.send_prepare(name, query.query, param_types=query.types)

     def _send_query_prepared(
         self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
@@ -452,9 +503,20 @@ class BaseCursor(Generic[ConnectionType, Row]):
         else:
             fmt = Format.BINARY if binary else Format.TEXT

-        self._pgconn.send_query_prepared(
-            name, pgq.params, param_formats=pgq.formats, result_format=fmt
-        )
+        if self._conn._pipeline:
+            self._conn._pipeline.command_queue.append(
+                partial(
+                    self._pgconn.send_query_prepared,
+                    name,
+                    pgq.params,
+                    param_formats=pgq.formats,
+                    result_format=fmt,
+                )
+            )
+        else:
+            self._pgconn.send_query_prepared(
+                name, pgq.params, param_formats=pgq.formats, result_format=fmt
+            )

     def _check_result_for_fetch(self) -> None:
         if self.closed:
@@ -610,6 +672,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):

         :rtype: Optional[Row], with Row defined by `row_factory`
         """
+        self._fetch_pipeline()
         self._check_result_for_fetch()
         record = self._tx.load_row(self._pos, self._make_row)
         if record is not None:
@@ -624,6 +687,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):

         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
+        self._fetch_pipeline()
         self._check_result_for_fetch()
         assert self.pgresult
@@ -643,6 +707,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):

         :rtype: Sequence[Row], with Row defined by `row_factory`
         """
+        self._fetch_pipeline()
         self._check_result_for_fetch()
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
@@ -650,6 +715,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         return records

     def __iter__(self) -> Iterator[Row]:
+        self._fetch_pipeline()
         self._check_result_for_fetch()

         def load(pos: int) -> Optional[Row]:
@@ -673,6 +739,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         Raise `!IndexError` in case a scroll operation would leave the result
         set. In this case the position will not change.
         """
+        self._fetch_pipeline()
         self._scroll(value, mode)

     @contextmanager
@@ -687,3 +754,9 @@ class Cursor(BaseCursor["Connection[Any]", Row]):

         with Copy(self) as copy:
             yield copy
+
+    def _fetch_pipeline(self) -> None:
+        if not self.pgresult and self._conn._pipeline:
+            with self._conn.lock:
+                self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
+            assert self.pgresult
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py
index b1211e5aa..75131430d 100644
--- a/psycopg/psycopg/cursor_async.py
+++ b/psycopg/psycopg/cursor_async.py
@@ -113,6 +113,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
             first = False

     async def fetchone(self) -> Optional[Row]:
+        await self._fetch_pipeline()
         self._check_result_for_fetch()
         rv = self._tx.load_row(self._pos, self._make_row)
         if rv is not None:
@@ -120,6 +121,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return rv

     async def fetchmany(self, size: int = 0) -> List[Row]:
+        await self._fetch_pipeline()
         self._check_result_for_fetch()
         assert self.pgresult
@@ -134,6 +136,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return records

     async def fetchall(self) -> List[Row]:
+        await self._fetch_pipeline()
         self._check_result_for_fetch()
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
@@ -141,6 +144,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return records

     async def __aiter__(self) -> AsyncIterator[Row]:
+        await self._fetch_pipeline()
         self._check_result_for_fetch()

         def load(pos: int) -> Optional[Row]:
@@ -166,3 +170,9 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):

         async with AsyncCopy(self) as copy:
             yield copy
+
+    async def _fetch_pipeline(self) -> None:
+        if not self.pgresult and self._conn._pipeline:
+            async with self._conn.lock:
+                await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
+            assert self.pgresult
diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py
index 8439ecea1..9561297f1 100644
--- a/psycopg/psycopg/server_cursor.py
+++ b/psycopg/psycopg/server_cursor.py
@@ -108,6 +108,8 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
             "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
         ).format(sql.Literal(self.name))
         res = yield from cur._conn._exec_command(query)
+        # res is None only in pipeline mode, which is unsupported here.
+        assert res is not None
         if res.ntuples == 0:
             return
@@ -129,6 +131,8 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
             sql.Identifier(self.name),
         )
         res = yield from cur._conn._exec_command(query, result_format=self._format)
+        # res is None only in pipeline mode, which is unsupported here.
+        assert res is not None
         cur.pgresult = res
         cur._tx.set_pgresult(res, set_loaders=False)
diff --git a/tests/scripts/pipeline-demo.py b/tests/scripts/pipeline-demo.py
index b8068476b..c0636587d 100644
--- a/tests/scripts/pipeline-demo.py
+++ b/tests/scripts/pipeline-demo.py
@@ -203,6 +203,52 @@ async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> N
             raise e.error_from_result(r)


+def pipeline_demo(rows_to_send: int, logger: logging.Logger) -> None:
+    """Pipeline demo using sync API."""
+    conn = Connection.connect()
+    conn.autocommit = True
+    conn.pgconn = LoggingPGconn(conn.pgconn, logger)  # type: ignore[assignment]
+    with conn.pipeline():
+        with conn.transaction():
+            conn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+            conn.execute(
+                "CREATE UNLOGGED TABLE pq_pipeline_demo("
+                " id serial primary key,"
+                " itemno integer,"
+                " int8filler int8"
+                ")"
+            )
+            for r in range(rows_to_send, 0, -1):
+                conn.execute(
+                    "INSERT INTO pq_pipeline_demo(itemno, int8filler)"
+                    " VALUES (%s, %s)",
+                    (r, 1 << 62),
+                )
+
+
+async def pipeline_demo_async(rows_to_send: int, logger: logging.Logger) -> None:
+    """Pipeline demo using async API."""
+    aconn = await AsyncConnection.connect()
+    await aconn.set_autocommit(True)
+    aconn.pgconn = LoggingPGconn(aconn.pgconn, logger)  # type: ignore[assignment]
+    async with aconn.pipeline():
+        async with aconn.transaction():
+            await aconn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+            await aconn.execute(
+                "CREATE UNLOGGED TABLE pq_pipeline_demo("
+                " id serial primary key,"
+                " itemno integer,"
+                " int8filler int8"
+                ")"
+            )
+            for r in range(rows_to_send, 0, -1):
+                await aconn.execute(
+                    "INSERT INTO pq_pipeline_demo(itemno, int8filler)"
+                    " VALUES (%s, %s)",
+                    (r, 1 << 62),
+                )
+
+
 def main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -213,6 +259,9 @@ def main() -> None:
         type=int,
         help="number of rows to insert",
     )
+    parser.add_argument(
+        "--pq", action="store_true", help="use low-level psycopg.pq API"
+    )
     parser.add_argument(
         "--async", dest="async_", action="store_true", help="use async API"
     )
@@ -228,10 +277,20 @@ def main() -> None:
     else:
         logger.addHandler(logging.StreamHandler())
         pipeline_logger.addHandler(logging.StreamHandler())
-    if args.async_:
-        asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger))
+    if args.pq:
+        if args.async_:
+            asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger))
+        else:
+            pipeline_demo_pq(args.nrows, pipeline_logger)
     else:
-        pipeline_demo_pq(args.nrows, pipeline_logger)
+        if pq.__impl__ != "python":
+            parser.error(
+                "only supported for Python implementation (set PSYCOPG_IMPL=python)"
+            )
+        if args.async_:
+            asyncio.run(pipeline_demo_async(args.nrows, pipeline_logger))
+        else:
+            pipeline_demo(args.nrows, pipeline_logger)


 if __name__ == "__main__":
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 31662dd76..99f404663 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -1,8 +1,15 @@
+import concurrent.futures
+
 import pytest

 import psycopg
 from psycopg import pq
-from psycopg.errors import ProgrammingError
+from psycopg.errors import (
+    OperationalError,
+    ProgrammingError,
+    UndefinedColumn,
+    UndefinedTable,
+)

 pytestmark = pytest.mark.libpq(">= 14")

@@ -29,3 +36,220 @@ def test_server_cursor(conn):
     with conn.cursor(name="pipeline") as cur, conn.pipeline():
         with pytest.raises(psycopg.NotSupportedError):
             cur.execute("select 1")
+
+
+def test_cannot_insert_multiple_commands(conn):
+    with pytest.raises(psycopg.errors.SyntaxError) as cm:
+        with conn.pipeline():
+            conn.execute("select 1; select 2")
+    assert cm.value.sqlstate == "42601"
+
+
+def test_pipeline_processed_at_exit(conn):
+    with conn.cursor() as cur:
+        with conn.pipeline():
+            cur.execute("select 1")
+
+            # PQsendQuery[BEGIN], PQsendQuery
+            assert len(conn._pipeline.result_queue) == 2
+
+        assert cur.fetchone() == (1,)
+
+
+def test_pipeline_errors_processed_at_exit(conn):
+    conn.autocommit = True
+    with pytest.raises((OperationalError, UndefinedTable)):
+        with conn.pipeline():
+            conn.execute("select * from nosuchtable")
+            conn.execute("create table voila ()")
+    cur = conn.execute(
+        "select count(*) from pg_tables where tablename = %s", ("voila",)
+    )
+    (count,) = cur.fetchone()
+    assert count == 0
+
+
+def test_pipeline(conn):
+    with conn.pipeline():
+        c1 = conn.cursor()
+        c2 = conn.cursor()
+        c1.execute("select 1")
+        c2.execute("select 2")
+
+        # PQsendQuery[BEGIN], PQsendQuery(2)
+        assert len(conn._pipeline.result_queue) == 3
+
+        (r1,) = c1.fetchone()
+        assert r1 == 1
+
+        (r2,) = c2.fetchone()
+        assert r2 == 2
+
+
+def test_autocommit(conn):
+    conn.autocommit = True
+    with conn.pipeline(), conn.cursor() as c:
+        c.execute("select 1")
+
+        (r,) = c.fetchone()
+        assert r == 1
+
+
+def test_pipeline_aborted(conn):
+    conn.autocommit = True
+    with conn.pipeline():
+        c1 = conn.execute("select 1")
+        with pytest.raises(UndefinedTable):
+            conn.execute("select * from doesnotexist").fetchone()
+        with pytest.raises(OperationalError, match="pipeline aborted"):
+            conn.execute("select 'aborted'").fetchone()
+        # Sync restores the connection to a usable state.
+        conn._pipeline.sync()
+        c2 = conn.execute("select 2")
+
+        (r,) = c1.fetchone()
+        assert r == 1
+
+        (r,) = c2.fetchone()
+        assert r == 2
+
+
+def test_pipeline_commit_aborted(conn):
+    with pytest.raises((UndefinedColumn, OperationalError)):
+        with conn.pipeline():
+            conn.execute("select error")
+            conn.execute("create table voila ()")
+            conn.commit()
+
+
+def test_executemany(conn):
+    conn.autocommit = True
+    conn.execute("drop table if exists execmanypipeline")
+    conn.execute(
+        "create unlogged table execmanypipeline ("
+        " id serial primary key, num integer)"
+    )
+    with conn.pipeline(), conn.cursor() as cur:
+        cur.executemany(
+            "insert into execmanypipeline(num) values (%s) returning id",
+            [(10,), (20,)],
+        )
+        assert cur.fetchone() == (1,)
+        assert cur.nextset()
+        assert cur.fetchone() == (2,)
+        assert cur.nextset() is None
+
+
+def test_prepared(conn):
+    conn.autocommit = True
+    with conn.pipeline():
+        c1 = conn.execute("select %s::int", [10], prepare=True)
+        c2 = conn.execute("select count(*) from pg_prepared_statements")
+
+        (r,) = c1.fetchone()
+        assert r == 10
+
+        (r,) = c2.fetchone()
+        assert r == 1
+
+
+def test_auto_prepare(conn):
+    conn.autocommit = True
+    conn.prepared_threshold = 5
+    with conn.pipeline():
+        cursors = [
+            conn.execute("select count(*) from pg_prepared_statements")
+            for i in range(10)
+        ]
+
+        assert len(conn._prepared._names) == 1
+
+    res = [c.fetchone()[0] for c in cursors]
+    assert res == [0] * 5 + [1] * 5
+
+
+def test_transaction(conn):
+    with conn.pipeline():
+        with conn.transaction():
+            cur = conn.execute("select 'tx'")
+
+        (r,) = cur.fetchone()
+        assert r == "tx"
+
+        with conn.transaction():
+            cur = conn.execute("select 'rb'")
+            raise psycopg.Rollback()
+
+        (r,) = cur.fetchone()
+        assert r == "rb"
+
+
+def test_transaction_nested(conn):
+    with conn.pipeline():
+        with conn.transaction():
+            outer = conn.execute("select 'outer'")
+            with pytest.raises(ZeroDivisionError):
+                with conn.transaction():
+                    inner = conn.execute("select 'inner'")
+                    1 / 0
+
+        (r,) = outer.fetchone()
+        assert r == "outer"
+        (r,) = inner.fetchone()
+        assert r == "inner"
+
+
+def test_outer_transaction(conn):
+    with conn.transaction():
+        with conn.pipeline():
+            conn.execute("drop table if exists outertx")
+            conn.execute("create table outertx as (select 1)")
+            cur = conn.execute("select * from outertx")
+        (r,) = cur.fetchone()
+        assert r == 1
+    cur = conn.execute("select count(*) from pg_tables where tablename = 'outertx'")
+    assert cur.fetchone()[0] == 1
+
+
+def test_outer_transaction_error(conn):
+    with conn.transaction():
+        with pytest.raises((UndefinedColumn, OperationalError)):
+            with conn.pipeline():
+                conn.execute("select error")
+                conn.execute("create table voila ()")
+
+
+def test_concurrency(conn):
+    with conn.transaction():
+        conn.execute("drop table if exists pipeline_concurrency")
+        conn.execute(
+            "create unlogged table pipeline_concurrency ("
+            " id serial primary key,"
+            " value integer"
+            ")"
+        )
+        conn.execute("drop table if exists accessed")
+        conn.execute("create unlogged table accessed as (select now() as value)")
+
+    def update(value):
+        cur = conn.execute(
+            "insert into pipeline_concurrency(value) values (%s) returning id",
+            (value,),
+        )
+        conn.execute("update accessed set value = now()")
+        return cur
+
+    conn.autocommit = True
+
+    (before,) = conn.execute("select value from accessed").fetchone()
+
+    values = range(1, 10)
+    with conn.pipeline():
+        with concurrent.futures.ThreadPoolExecutor() as e:
+            cursors = e.map(update, values, timeout=len(values))
+
+        assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
+
+    (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone()
+    assert s == sum(values)
+    (after,) = conn.execute("select value from accessed").fetchone()
+    assert after > before
diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py
index 3aa6b2e59..5c3dedc94 100644
--- a/tests/test_pipeline_async.py
+++ b/tests/test_pipeline_async.py
@@ -1,8 +1,15 @@
+import asyncio
+
 import pytest

 import psycopg
 from psycopg import pq
-from psycopg.errors import ProgrammingError
+from psycopg.errors import (
+    OperationalError,
+    ProgrammingError,
+    UndefinedColumn,
+    UndefinedTable,
+)

 pytestmark = [
     pytest.mark.libpq(">= 14"),
@@ -32,3 +39,226 @@ async def test_server_cursor(aconn):
     async with aconn.cursor(name="pipeline") as cur, aconn.pipeline():
         with pytest.raises(psycopg.NotSupportedError):
             await cur.execute("select 1")
+
+
+async def test_cannot_insert_multiple_commands(aconn):
+    with pytest.raises(psycopg.errors.SyntaxError) as cm:
+        async with aconn.pipeline():
+            await aconn.execute("select 1; select 2")
+    assert cm.value.sqlstate == "42601"
+
+
+async def test_pipeline_processed_at_exit(aconn):
+    async with aconn.cursor() as cur:
+        async with aconn.pipeline():
+            await cur.execute("select 1")
+
+            # PQsendQuery[BEGIN], PQsendQuery
+            assert len(aconn._pipeline.result_queue) == 2
+
+        assert await cur.fetchone() == (1,)
+
+
+async def test_pipeline_errors_processed_at_exit(aconn):
+    await aconn.set_autocommit(True)
+    with pytest.raises((OperationalError, UndefinedTable)):
+        async with aconn.pipeline():
+            await aconn.execute("select * from nosuchtable")
+            await aconn.execute("create table voila ()")
+    cur = await aconn.execute(
+        "select count(*) from pg_tables where tablename = %s", ("voila",)
+    )
+    (count,) = await cur.fetchone()
+    assert count == 0
+
+
+async def test_pipeline(aconn):
+    async with aconn.pipeline():
+        c1 = aconn.cursor()
+        c2 = aconn.cursor()
+        await c1.execute("select 1")
+        await c2.execute("select 2")
+
+        # PQsendQuery[BEGIN], PQsendQuery(2)
+        assert len(aconn._pipeline.result_queue) == 3
+
+        (r1,) = await c1.fetchone()
+        assert r1 == 1
+
+        (r2,) = await c2.fetchone()
+        assert r2 == 2
+
+
+async def test_autocommit(aconn):
+    await aconn.set_autocommit(True)
+    async with aconn.pipeline(), aconn.cursor() as c:
+        await c.execute("select 1")
+
+        (r,) = await c.fetchone()
+        assert r == 1
+
+
+async def test_pipeline_aborted(aconn):
+    await aconn.set_autocommit(True)
+    async with aconn.pipeline():
+        c1 = await aconn.execute("select 1")
+        with pytest.raises(UndefinedTable):
+            await (await aconn.execute("select * from doesnotexist")).fetchone()
+        with pytest.raises(OperationalError, match="pipeline aborted"):
+            await (await aconn.execute("select 'aborted'")).fetchone()
+        # Sync restores the connection to a usable state.
+        aconn._pipeline.sync()
+        c2 = await aconn.execute("select 2")
+
+        (r,) = await c1.fetchone()
+        assert r == 1
+
+        (r,) = await c2.fetchone()
+        assert r == 2
+
+
+async def test_pipeline_commit_aborted(aconn):
+    with pytest.raises((UndefinedColumn, OperationalError)):
+        async with aconn.pipeline():
+            await aconn.execute("select error")
+            await aconn.execute("create table voila ()")
+            await aconn.commit()
+
+
+async def test_executemany(aconn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("drop table if exists execmanypipeline")
+    await aconn.execute(
+        "create unlogged table execmanypipeline ("
+        " id serial primary key, num integer)"
+    )
+    async with aconn.pipeline(), aconn.cursor() as cur:
+        await cur.executemany(
+            "insert into execmanypipeline(num) values (%s) returning id",
+            [(10,), (20,)],
+        )
+        assert (await cur.fetchone()) == (1,)
+        assert cur.nextset()
+        assert (await cur.fetchone()) == (2,)
+        assert cur.nextset() is None
+
+
+async def test_prepared(aconn):
+    await aconn.set_autocommit(True)
+    async with aconn.pipeline():
+        c1 = await aconn.execute("select %s::int", [10], prepare=True)
+        c2 = await aconn.execute("select count(*) from pg_prepared_statements")
+
+        (r,) = await c1.fetchone()
+        assert r == 10
+
+        (r,) = await c2.fetchone()
+        assert r == 1
+
+
+async def test_auto_prepare(aconn):
+    aconn.prepared_threshold = 5
+    async with aconn.pipeline():
+        cursors = [
+            await aconn.execute("select count(*) from pg_prepared_statements")
+            for i in range(10)
+        ]
+
+        assert len(aconn._prepared._names) == 1
+
+    res = [(await c.fetchone())[0] for c in cursors]
+    assert res == [0] * 5 + [1] * 5
+
+
+async def test_transaction(aconn):
+    async with aconn.pipeline():
+        async with aconn.transaction():
+            cur = await aconn.execute("select 'tx'")
+
+        (r,) = await cur.fetchone()
+        assert r == "tx"
+
+        async with aconn.transaction():
+            cur = await aconn.execute("select 'rb'")
+            raise psycopg.Rollback()
+
+        (r,) = await cur.fetchone()
+        assert r == "rb"
+
+
+async def test_transaction_nested(aconn):
+    async with aconn.pipeline():
+        async with aconn.transaction():
+            outer = await aconn.execute("select 'outer'")
+            with pytest.raises(ZeroDivisionError):
+                async with aconn.transaction():
+                    inner = await aconn.execute("select 'inner'")
+                    1 / 0
+
+        (r,) = await outer.fetchone()
+        assert r == "outer"
+        (r,) = await inner.fetchone()
+        assert r == "inner"
+
+
+async def test_outer_transaction(aconn):
+    async with aconn.transaction():
+        async with aconn.pipeline():
+            await aconn.execute("drop table if exists outertx")
table outertx as (select 1)") + cur = await aconn.execute("select * from outertx") + (r,) = await cur.fetchone() + assert r == 1 + cur = await aconn.execute( + "select count(*) from pg_tables where tablename = 'outertx'" + ) + assert (await cur.fetchone())[0] == 1 + + +async def test_outer_transaction_error(aconn): + async with aconn.transaction(): + with pytest.raises((UndefinedColumn, OperationalError)): + async with aconn.pipeline(): + await aconn.execute("select error") + await aconn.execute("create table voila ()") + + +async def test_concurrency(aconn): + async with aconn.transaction(): + await aconn.execute("drop table if exists pipeline_concurrency") + await aconn.execute( + "create unlogged table pipeline_concurrency (" + " id serial primary key," + " value integer" + ")" + ) + await aconn.execute("drop table if exists accessed") + await aconn.execute("create unlogged table accessed as (select now() as value)") + + async def update(value): + cur = await aconn.execute( + "insert into pipeline_concurrency(value) values (%s) returning id", + (value,), + ) + await aconn.execute("update accessed set value = now()") + return cur + + await aconn.set_autocommit(True) + + (before,) = await (await aconn.execute("select value from accessed")).fetchone() + + values = range(1, 10) + async with aconn.pipeline(): + cursors = await asyncio.wait_for( + asyncio.gather(*[update(value) for value in values]), + timeout=len(values), + ) + + assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values) + + (s,) = await ( + await aconn.execute("select sum(value) from pipeline_concurrency") + ).fetchone() + assert s == sum(values) + (after,) = await (await aconn.execute("select value from accessed")).fetchone() + assert after > before