from . import waiting
from . import postgres
from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
-from .abc import AdaptContext, ConnectionType, Params, Query, RV
+from .abc import AdaptContext, Command, ConnectionType, Params, Query, RV
from .abc import PQGen, PQGenConn
from .sql import Composable, SQL
from ._tpc import Xid
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from .cursor import Cursor
-from ._compat import TypeAlias
+from ._compat import Deque, TypeAlias
from ._cmodule import _psycopg
from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
from .generators import notifies
from ._encodings import pgconn_encoding
-from ._preparing import PrepareManager
+from ._preparing import Key, Prepare, PrepareManager
from .transaction import Transaction
from .server_cursor import ServerCursor
if TYPE_CHECKING:
from .pq.abc import PGconn, PGresult
+ from .cursor import BaseCursor
from psycopg_pool.base import BasePool
logger = logging.getLogger("psycopg")
if _psycopg:
connect = _psycopg.connect
execute = _psycopg.execute
+ fetch_many = _psycopg.fetch_many
+ pipeline_communicate = _psycopg.pipeline_communicate
+ send = _psycopg.send
else:
from . import generators
connect = generators.connect
execute = generators.execute
+ fetch_many = generators.fetch_many
+ pipeline_communicate = generators.pipeline_communicate
+ send = generators.send
class Notify(NamedTuple):
NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
NotifyHandler: TypeAlias = Callable[[Notify], None]
+PipelinePendingResult = Union[
+ None,
+ Tuple[
+ "BaseCursor[Any, Any]",
+ Optional[Tuple[Key, Prepare, bytes]],
+ ],
+]
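+# A None item in the result queue denotes a command (e.g. a sync) whose
+# single result is checked directly; a tuple holds the cursor awaiting the
+# results and, optionally, the prepared-statement cache entry to validate
+# once the results arrive.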
+
class BasePipeline:
def __init__(self, pgconn: "PGconn") -> None:
self.pgconn = pgconn
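+        # Callables (e.g. partial PQsend*() calls) not yet sent to the
+        # server, and the matching queue of items awaiting their results.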
+ self.command_queue = Deque[Command]()
+ self.result_queue = Deque[PipelinePendingResult]()
@property
def status(self) -> pq.PipelineStatus:
return pq.PipelineStatus(self.pgconn.pipeline_status)
+ def sync(self) -> None:
+ """Enqueue a PQpipelineSync() command."""
+ self.command_queue.append(self.pgconn.pipeline_sync)
+ self.result_queue.append(None)
+
def _enter(self) -> None:
self.pgconn.enter_pipeline_mode()
def _exit(self) -> None:
self.pgconn.exit_pipeline_mode()
+ def _communicate_gen(self) -> PQGen[None]:
+ """Communicate with pipeline to send commands and possibly fetch
+ results, which are then processed.
+ """
+ fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
+ to_process = [(self.result_queue.popleft(), results) for results in fetched]
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
+ """Fetch available results from the connection and process them with
+ pipeline queued items.
+
+ If 'flush' is True, a PQsendFlushRequest() is issued in order to make
+ sure results can be fetched. Otherwise, the caller may emit a
+ PQpipelineSync() call to ensure the output buffer gets flushed before
+ fetching.
+ """
+ if not self.result_queue:
+ return
+
+ if flush:
+ self.pgconn.send_flush_request()
+ yield from send(self.pgconn)
+
+ to_process = []
+ while self.result_queue:
+ results = yield from fetch_many(self.pgconn)
+ if not results:
+ # No more results to fetch, but there may still be pending
+ # commands.
+ break
+ queued = self.result_queue.popleft()
+ to_process.append((queued, results))
+
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _process_results(
+ self, queued: PipelinePendingResult, results: List["PGresult"]
+ ) -> None:
+ """Process a results set fetched from the current pipeline.
+
+ This matchs 'results' with its respective element in the pipeline
+ queue. For commands (None value in the pipeline queue), results are
+ checked directly. For prepare statement creation requests, update the
+ cache. Otherwise, results are attached to their respective cursor.
+ """
+ if queued is None:
+ (result,) = results
+ if result.status == ExecStatus.FATAL_ERROR:
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+ elif result.status == ExecStatus.PIPELINE_ABORTED:
+ raise e.OperationalError("pipeline aborted")
+ else:
+ cursor, prepinfo = queued
+ cursor._check_results(results)
+ if not cursor._results:
+ cursor._results = results
+ cursor._set_current_result(0)
+ else:
+ cursor._results.extend(results)
+ if prepinfo:
+ key, prep, name = prepinfo
+ # Update the prepare state of the query.
+ cursor._conn._prepared.validate(key, prep, name, results)
+
class Pipeline(BasePipeline):
"""Handler for connection in pipeline mode."""
def _exec_command(
self, command: Query, result_format: Format = Format.TEXT
- ) -> PQGen["PGresult"]:
+ ) -> PQGen[Optional["PGresult"]]:
"""
    Generator to send a command to the backend and receive the result.
elif isinstance(command, Composable):
command = command.as_bytes(self)
+ if self._pipeline:
+ if result_format == Format.TEXT:
+ cmd = partial(self.pgconn.send_query, command)
+ else:
+ cmd = partial(
+ self.pgconn.send_query_params,
+ command,
+ None,
+ result_format=result_format,
+ )
+ self._pipeline.command_queue.append(cmd)
+ self._pipeline.result_queue.append(None)
+ return None
+
if result_format == Format.TEXT:
self.pgconn.send_query(command)
else:
pipeline = self._pipeline = Pipeline(self.pgconn)
try:
with pipeline:
- yield
+ try:
+ yield
+ finally:
+ with self.lock:
+ pipeline.sync()
+ try:
+                        # Send any pending commands (e.g. COMMIT or Sync);
+ # while processing results, we might get errors...
+ self.wait(pipeline._communicate_gen())
+ finally:
+ # then fetch all remaining results but without forcing
+ # flush since we emitted a sync just before.
+ self.wait(pipeline._fetch_gen(flush=False))
finally:
assert pipeline.status == pq.PipelineStatus.OFF, pipeline.status
self._pipeline = None
pipeline = self._pipeline = AsyncPipeline(self.pgconn)
try:
async with pipeline:
- yield
+ try:
+ yield
+ finally:
+ async with self.lock:
+ pipeline.sync()
+ try:
+                        # Send any pending commands (e.g. COMMIT or Sync);
+ # while processing results, we might get errors...
+ await self.wait(pipeline._communicate_gen())
+ finally:
+ # then fetch all remaining results but without forcing
+ # flush since we emitted a sync just before.
+ await self.wait(pipeline._fetch_gen(flush=False))
finally:
assert pipeline.status == PipelineStatus.OFF, pipeline.status
self._pipeline = None
# Copyright (C) 2020 The Psycopg Team
+from functools import partial
from types import TracebackType
from typing import Any, Generic, Iterable, Iterator, List
from typing import Optional, NoReturn, Sequence, Type, TypeVar, TYPE_CHECKING
results = yield from self._maybe_prepare_gen(
pgq, prepare=prepare, binary=binary
)
- self._check_results(results)
- self._results = results
- self._set_current_result(0)
+ if self._conn._pipeline:
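+        # In pipeline mode the query has only been queued: send it (and any
+        # other queued commands) and process whatever results are already
+        # available; the rest are fetched later, on demand.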
+ yield from self._conn._pipeline._communicate_gen()
+ else:
+ assert results is not None
+ self._check_results(results)
+ self._results = results
+ self._set_current_result(0)
+
self._last_query = query
for cmd in self._conn._prepared.get_maintenance_commands():
pgq.dump(params)
results = yield from self._maybe_prepare_gen(pgq, prepare=True)
- self._check_results(results)
- if returning and results[0].status == ExecStatus.TUPLES_OK:
- self._results.extend(results)
- for res in results:
- nrows += res.command_tuples or 0
+ if self._conn._pipeline:
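+        # As in execute(): the queries were only queued, so just flush the
+        # pipeline; results are processed later, on demand.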
+ yield from self._conn._pipeline._communicate_gen()
+ else:
+ assert results is not None
+ self._check_results(results)
+ if returning and results[0].status == ExecStatus.TUPLES_OK:
+ self._results.extend(results)
- if self._results:
- self._set_current_result(0)
+ for res in results:
+ nrows += res.command_tuples or 0
+
+ if not self._conn._pipeline:
+ if self._results:
+ self._set_current_result(0)
+
+ # Override rowcount for the first result. Calls to nextset() will
+ # change it to the value of that result only, but we hope nobody
+ # will notice.
+ # You haven't read this comment.
+ self._rowcount = nrows
- # Override rowcount for the first result. Calls to nextset() will change
- # it to the value of that result only, but we hope nobody will notice.
- # You haven't read this comment.
- self._rowcount = nrows
self._last_query = query
for cmd in self._conn._prepared.get_maintenance_commands():
*,
prepare: Optional[bool] = None,
binary: Optional[bool] = None,
- ) -> PQGen[List["PGresult"]]:
+ ) -> PQGen[Optional[List["PGresult"]]]:
# Check if the query is prepared or needs preparing
prep, name = self._conn._prepared.get(pgq, prepare)
if prep is Prepare.NO:
# If the query is not already prepared, prepare it.
if prep is Prepare.SHOULD:
self._send_prepare(name, pgq)
- (result,) = yield from execute(self._pgconn)
- if result.status == ExecStatus.FATAL_ERROR:
- raise e.error_from_result(result, encoding=self._encoding)
+ if not self._conn._pipeline:
+ (result,) = yield from execute(self._pgconn)
+ if result.status == ExecStatus.FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
# Then execute it.
self._send_query_prepared(name, pgq, binary=binary)
- # run the query
- results = yield from execute(self._pgconn)
-
# Update the prepare state of the query.
# If an operation requires to flush our prepared statements cache,
# it will be added to the maintenance commands to execute later.
key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
+
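+        # In pipeline mode, register this cursor (with the prepared-statement
+        # info, if any) to receive the results once they are fetched.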
+ if self._conn._pipeline:
+ queued = None
+ if key is not None:
+ queued = (key, prep, name)
+ self._conn._pipeline.result_queue.append((self, queued))
+ return None
+
+ # run the query
+ results = yield from execute(self._pgconn)
+
if key is not None:
self._conn._prepared.validate(key, prep, name, results)
self._query = query
if query.params or no_pqexec or fmt == Format.BINARY:
- self._pgconn.send_query_params(
- query.query,
- query.params,
- param_formats=query.formats,
- param_types=query.types,
- result_format=fmt,
- )
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_query_params,
+ query.query,
+ query.params,
+ param_formats=query.formats,
+ param_types=query.types,
+ result_format=fmt,
+ )
+ )
+ else:
+ self._pgconn.send_query_params(
+ query.query,
+ query.params,
+ param_formats=query.formats,
+ param_types=query.types,
+ result_format=fmt,
+ )
else:
# if we don't have to, let's use exec_ as it can run more than
# one query in one go
- self._pgconn.send_query(query.query)
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(self._pgconn.send_query, query.query)
+ )
+ else:
+ self._pgconn.send_query(query.query)
def _convert_query(
self, query: Query, params: Optional[Params] = None
self._rowcount = nrows if nrows is not None else -1
def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
- self._pgconn.send_prepare(name, query.query, param_types=query.types)
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_prepare,
+ name,
+ query.query,
+ param_types=query.types,
+ )
+ )
+ self._conn._pipeline.result_queue.append(None)
+ else:
+ self._pgconn.send_prepare(name, query.query, param_types=query.types)
def _send_query_prepared(
self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
else:
fmt = Format.BINARY if binary else Format.TEXT
- self._pgconn.send_query_prepared(
- name, pgq.params, param_formats=pgq.formats, result_format=fmt
- )
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_query_prepared,
+ name,
+ pgq.params,
+ param_formats=pgq.formats,
+ result_format=fmt,
+ )
+ )
+ else:
+ self._pgconn.send_query_prepared(
+ name, pgq.params, param_formats=pgq.formats, result_format=fmt
+ )
def _check_result_for_fetch(self) -> None:
if self.closed:
:rtype: Optional[Row], with Row defined by `row_factory`
"""
+ self._fetch_pipeline()
self._check_result_for_fetch()
record = self._tx.load_row(self._pos, self._make_row)
if record is not None:
:rtype: Sequence[Row], with Row defined by `row_factory`
"""
+ self._fetch_pipeline()
self._check_result_for_fetch()
assert self.pgresult
:rtype: Sequence[Row], with Row defined by `row_factory`
"""
+ self._fetch_pipeline()
self._check_result_for_fetch()
assert self.pgresult
records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
return records
def __iter__(self) -> Iterator[Row]:
+ self._fetch_pipeline()
self._check_result_for_fetch()
def load(pos: int) -> Optional[Row]:
Raise `!IndexError` in case a scroll operation would leave the result
set. In this case the position will not change.
"""
+ self._fetch_pipeline()
self._scroll(value, mode)
@contextmanager
with Copy(self) as copy:
yield copy
+
+ def _fetch_pipeline(self) -> None:
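+        # Fetch the queued results on demand: flush=True issues a
+        # PQsendFlushRequest() so the server sends pending results even if
+        # no sync was emitted yet.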
+ if not self.pgresult and self._conn._pipeline:
+ with self._conn.lock:
+ self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
+ assert self.pgresult
first = False
async def fetchone(self) -> Optional[Row]:
+ await self._fetch_pipeline()
self._check_result_for_fetch()
rv = self._tx.load_row(self._pos, self._make_row)
if rv is not None:
return rv
async def fetchmany(self, size: int = 0) -> List[Row]:
+ await self._fetch_pipeline()
self._check_result_for_fetch()
assert self.pgresult
return records
async def fetchall(self) -> List[Row]:
+ await self._fetch_pipeline()
self._check_result_for_fetch()
assert self.pgresult
records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
return records
async def __aiter__(self) -> AsyncIterator[Row]:
+ await self._fetch_pipeline()
self._check_result_for_fetch()
def load(pos: int) -> Optional[Row]:
async with AsyncCopy(self) as copy:
yield copy
+
+ async def _fetch_pipeline(self) -> None:
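+        # See the sync _fetch_pipeline() above: flush=True makes the server
+        # send pending results even if no sync was emitted yet.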
+ if not self.pgresult and self._conn._pipeline:
+ async with self._conn.lock:
+ await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
+ assert self.pgresult
"SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
).format(sql.Literal(self.name))
res = yield from cur._conn._exec_command(query)
+    # _exec_command() returns None only in pipeline mode, which is not
+    # supported for server-side cursors.
+ assert res is not None
if res.ntuples == 0:
return
sql.Identifier(self.name),
)
res = yield from cur._conn._exec_command(query, result_format=self._format)
+    # _exec_command() returns None only in pipeline mode, which is not
+    # supported for server-side cursors.
+ assert res is not None
cur.pgresult = res
cur._tx.set_pgresult(res, set_loaders=False)
raise e.error_from_result(r)
+def pipeline_demo(rows_to_send: int, logger: logging.Logger) -> None:
+ """Pipeline demo using sync API."""
+ conn = Connection.connect()
+ conn.autocommit = True
+ conn.pgconn = LoggingPGconn(conn.pgconn, logger) # type: ignore[assignment]
+ with conn.pipeline():
+ with conn.transaction():
+ conn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+ conn.execute(
+ "CREATE UNLOGGED TABLE pq_pipeline_demo("
+ " id serial primary key,"
+ " itemno integer,"
+ " int8filler int8"
+ ")"
+ )
+ for r in range(rows_to_send, 0, -1):
+ conn.execute(
+ "INSERT INTO pq_pipeline_demo(itemno, int8filler)"
+ " VALUES (%s, %s)",
+ (r, 1 << 62),
+ )
+
+
+async def pipeline_demo_async(rows_to_send: int, logger: logging.Logger) -> None:
+ """Pipeline demo using async API."""
+ aconn = await AsyncConnection.connect()
+ await aconn.set_autocommit(True)
+ aconn.pgconn = LoggingPGconn(aconn.pgconn, logger) # type: ignore[assignment]
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ await aconn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+ await aconn.execute(
+ "CREATE UNLOGGED TABLE pq_pipeline_demo("
+ " id serial primary key,"
+ " itemno integer,"
+ " int8filler int8"
+ ")"
+ )
+ for r in range(rows_to_send, 0, -1):
+ await aconn.execute(
+ "INSERT INTO pq_pipeline_demo(itemno, int8filler)"
+ " VALUES (%s, %s)",
+ (r, 1 << 62),
+ )
+
+
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
type=int,
help="number of rows to insert",
)
+ parser.add_argument(
+ "--pq", action="store_true", help="use low-level psycopg.pq API"
+ )
parser.add_argument(
"--async", dest="async_", action="store_true", help="use async API"
)
else:
logger.addHandler(logging.StreamHandler())
pipeline_logger.addHandler(logging.StreamHandler())
- if args.async_:
- asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger))
+ if args.pq:
+ if args.async_:
+ asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger))
+ else:
+ pipeline_demo_pq(args.nrows, pipeline_logger)
else:
- pipeline_demo_pq(args.nrows, pipeline_logger)
+ if pq.__impl__ != "python":
+ parser.error(
+ "only supported for Python implementation (set PSYCOPG_IMPL=python)"
+ )
+ if args.async_:
+ asyncio.run(pipeline_demo_async(args.nrows, pipeline_logger))
+ else:
+ pipeline_demo(args.nrows, pipeline_logger)
if __name__ == "__main__":
+import concurrent.futures
+
import pytest
import psycopg
from psycopg import pq
-from psycopg.errors import ProgrammingError
+from psycopg.errors import (
+ OperationalError,
+ ProgrammingError,
+ UndefinedColumn,
+ UndefinedTable,
+)
pytestmark = pytest.mark.libpq(">= 14")
with conn.cursor(name="pipeline") as cur, conn.pipeline():
with pytest.raises(psycopg.NotSupportedError):
cur.execute("select 1")
+
+
+def test_cannot_insert_multiple_commands(conn):
+ with pytest.raises(psycopg.errors.SyntaxError) as cm:
+ with conn.pipeline():
+ conn.execute("select 1; select 2")
+ assert cm.value.sqlstate == "42601"
+
+
+def test_pipeline_processed_at_exit(conn):
+ with conn.cursor() as cur:
+ with conn.pipeline():
+ cur.execute("select 1")
+
+ # PQsendQuery[BEGIN], PQsendQuery
+ assert len(conn._pipeline.result_queue) == 2
+
+ assert cur.fetchone() == (1,)
+
+
+def test_pipeline_errors_processed_at_exit(conn):
+ conn.autocommit = True
+ with pytest.raises((OperationalError, UndefinedTable)):
+ with conn.pipeline():
+ conn.execute("select * from nosuchtable")
+ conn.execute("create table voila ()")
+ cur = conn.execute(
+ "select count(*) from pg_tables where tablename = %s", ("voila",)
+ )
+ (count,) = cur.fetchone()
+ assert count == 0
+
+
+def test_pipeline(conn):
+ with conn.pipeline():
+ c1 = conn.cursor()
+ c2 = conn.cursor()
+ c1.execute("select 1")
+ c2.execute("select 2")
+
+ # PQsendQuery[BEGIN], PQsendQuery(2)
+ assert len(conn._pipeline.result_queue) == 3
+
+ (r1,) = c1.fetchone()
+ assert r1 == 1
+
+ (r2,) = c2.fetchone()
+ assert r2 == 2
+
+
+def test_autocommit(conn):
+ conn.autocommit = True
+ with conn.pipeline(), conn.cursor() as c:
+ c.execute("select 1")
+
+ (r,) = c.fetchone()
+ assert r == 1
+
+
+def test_pipeline_aborted(conn):
+ conn.autocommit = True
+ with conn.pipeline():
+ c1 = conn.execute("select 1")
+ with pytest.raises(UndefinedTable):
+ conn.execute("select * from doesnotexist").fetchone()
+ with pytest.raises(OperationalError, match="pipeline aborted"):
+ conn.execute("select 'aborted'").fetchone()
+        # A sync restores the connection to a usable state.
+ conn._pipeline.sync()
+ c2 = conn.execute("select 2")
+
+ (r,) = c1.fetchone()
+ assert r == 1
+
+ (r,) = c2.fetchone()
+ assert r == 2
+
+
+def test_pipeline_commit_aborted(conn):
+ with pytest.raises((UndefinedColumn, OperationalError)):
+ with conn.pipeline():
+ conn.execute("select error")
+ conn.execute("create table voila ()")
+ conn.commit()
+
+
+def test_executemany(conn):
+ conn.autocommit = True
+ conn.execute("drop table if exists execmanypipeline")
+ conn.execute(
+ "create unlogged table execmanypipeline ("
+ " id serial primary key, num integer)"
+ )
+ with conn.pipeline(), conn.cursor() as cur:
+ cur.executemany(
+ "insert into execmanypipeline(num) values (%s) returning id",
+ [(10,), (20,)],
+ )
+ assert cur.fetchone() == (1,)
+ assert cur.nextset()
+ assert cur.fetchone() == (2,)
+ assert cur.nextset() is None
+
+
+def test_prepared(conn):
+ conn.autocommit = True
+ with conn.pipeline():
+ c1 = conn.execute("select %s::int", [10], prepare=True)
+ c2 = conn.execute("select count(*) from pg_prepared_statements")
+
+ (r,) = c1.fetchone()
+ assert r == 10
+
+ (r,) = c2.fetchone()
+ assert r == 1
+
+
+def test_auto_prepare(conn):
+ conn.autocommit = True
+ conn.prepared_threshold = 5
+ with conn.pipeline():
+ cursors = [
+ conn.execute("select count(*) from pg_prepared_statements")
+ for i in range(10)
+ ]
+
+ assert len(conn._prepared._names) == 1
+
+ res = [c.fetchone()[0] for c in cursors]
+ assert res == [0] * 5 + [1] * 5
+
+
+def test_transaction(conn):
+ with conn.pipeline():
+ with conn.transaction():
+ cur = conn.execute("select 'tx'")
+
+ (r,) = cur.fetchone()
+ assert r == "tx"
+
+ with conn.transaction():
+ cur = conn.execute("select 'rb'")
+ raise psycopg.Rollback()
+
+ (r,) = cur.fetchone()
+ assert r == "rb"
+
+
+def test_transaction_nested(conn):
+ with conn.pipeline():
+ with conn.transaction():
+ outer = conn.execute("select 'outer'")
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction():
+ inner = conn.execute("select 'inner'")
+ 1 / 0
+
+ (r,) = outer.fetchone()
+ assert r == "outer"
+ (r,) = inner.fetchone()
+ assert r == "inner"
+
+
+def test_outer_transaction(conn):
+ with conn.transaction():
+ with conn.pipeline():
+ conn.execute("drop table if exists outertx")
+ conn.execute("create table outertx as (select 1)")
+ cur = conn.execute("select * from outertx")
+ (r,) = cur.fetchone()
+ assert r == 1
+ cur = conn.execute("select count(*) from pg_tables where tablename = 'outertx'")
+ assert cur.fetchone()[0] == 1
+
+
+def test_outer_transaction_error(conn):
+ with conn.transaction():
+ with pytest.raises((UndefinedColumn, OperationalError)):
+ with conn.pipeline():
+ conn.execute("select error")
+ conn.execute("create table voila ()")
+
+
+def test_concurrency(conn):
+ with conn.transaction():
+ conn.execute("drop table if exists pipeline_concurrency")
+ conn.execute(
+ "create unlogged table pipeline_concurrency ("
+ " id serial primary key,"
+ " value integer"
+ ")"
+ )
+ conn.execute("drop table if exists accessed")
+ conn.execute("create unlogged table accessed as (select now() as value)")
+
+ def update(value):
+ cur = conn.execute(
+ "insert into pipeline_concurrency(value) values (%s) returning id",
+ (value,),
+ )
+ conn.execute("update accessed set value = now()")
+ return cur
+
+ conn.autocommit = True
+
+ (before,) = conn.execute("select value from accessed").fetchone()
+
+ values = range(1, 10)
+ with conn.pipeline():
+ with concurrent.futures.ThreadPoolExecutor() as e:
+ cursors = e.map(update, values, timeout=len(values))
+ assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
+
+ (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone()
+ assert s == sum(values)
+ (after,) = conn.execute("select value from accessed").fetchone()
+ assert after > before
+import asyncio
+
import pytest
import psycopg
from psycopg import pq
-from psycopg.errors import ProgrammingError
+from psycopg.errors import (
+ OperationalError,
+ ProgrammingError,
+ UndefinedColumn,
+ UndefinedTable,
+)
pytestmark = [
pytest.mark.libpq(">= 14"),
async with aconn.cursor(name="pipeline") as cur, aconn.pipeline():
with pytest.raises(psycopg.NotSupportedError):
await cur.execute("select 1")
+
+
+async def test_cannot_insert_multiple_commands(aconn):
+ with pytest.raises(psycopg.errors.SyntaxError) as cm:
+ async with aconn.pipeline():
+ await aconn.execute("select 1; select 2")
+ assert cm.value.sqlstate == "42601"
+
+
+async def test_pipeline_processed_at_exit(aconn):
+ async with aconn.cursor() as cur:
+ async with aconn.pipeline():
+ await cur.execute("select 1")
+
+ # PQsendQuery[BEGIN], PQsendQuery
+ assert len(aconn._pipeline.result_queue) == 2
+
+ assert await cur.fetchone() == (1,)
+
+
+async def test_pipeline_errors_processed_at_exit(aconn):
+ await aconn.set_autocommit(True)
+ with pytest.raises((OperationalError, UndefinedTable)):
+ async with aconn.pipeline():
+ await aconn.execute("select * from nosuchtable")
+ await aconn.execute("create table voila ()")
+ cur = await aconn.execute(
+ "select count(*) from pg_tables where tablename = %s", ("voila",)
+ )
+ (count,) = await cur.fetchone()
+ assert count == 0
+
+
+async def test_pipeline(aconn):
+ async with aconn.pipeline():
+ c1 = aconn.cursor()
+ c2 = aconn.cursor()
+ await c1.execute("select 1")
+ await c2.execute("select 2")
+
+ # PQsendQuery[BEGIN], PQsendQuery(2)
+ assert len(aconn._pipeline.result_queue) == 3
+
+ (r1,) = await c1.fetchone()
+ assert r1 == 1
+
+ (r2,) = await c2.fetchone()
+ assert r2 == 2
+
+
+async def test_autocommit(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline(), aconn.cursor() as c:
+ await c.execute("select 1")
+
+ (r,) = await c.fetchone()
+ assert r == 1
+
+
+async def test_pipeline_aborted(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ c1 = await aconn.execute("select 1")
+ with pytest.raises(UndefinedTable):
+ await (await aconn.execute("select * from doesnotexist")).fetchone()
+ with pytest.raises(OperationalError, match="pipeline aborted"):
+ await (await aconn.execute("select 'aborted'")).fetchone()
+        # A sync restores the connection to a usable state.
+ aconn._pipeline.sync()
+ c2 = await aconn.execute("select 2")
+
+ (r,) = await c1.fetchone()
+ assert r == 1
+
+ (r,) = await c2.fetchone()
+ assert r == 2
+
+
+async def test_pipeline_commit_aborted(aconn):
+ with pytest.raises((UndefinedColumn, OperationalError)):
+ async with aconn.pipeline():
+ await aconn.execute("select error")
+ await aconn.execute("create table voila ()")
+ await aconn.commit()
+
+
+async def test_executemany(aconn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("drop table if exists execmanypipeline")
+ await aconn.execute(
+ "create unlogged table execmanypipeline ("
+ " id serial primary key, num integer)"
+ )
+ async with aconn.pipeline(), aconn.cursor() as cur:
+ await cur.executemany(
+ "insert into execmanypipeline(num) values (%s) returning id",
+ [(10,), (20,)],
+ )
+ assert (await cur.fetchone()) == (1,)
+ assert cur.nextset()
+ assert (await cur.fetchone()) == (2,)
+ assert cur.nextset() is None
+
+
+async def test_prepared(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ c1 = await aconn.execute("select %s::int", [10], prepare=True)
+ c2 = await aconn.execute("select count(*) from pg_prepared_statements")
+
+ (r,) = await c1.fetchone()
+ assert r == 10
+
+ (r,) = await c2.fetchone()
+ assert r == 1
+
+
+async def test_auto_prepare(aconn):
+ aconn.prepared_threshold = 5
+ async with aconn.pipeline():
+ cursors = [
+ await aconn.execute("select count(*) from pg_prepared_statements")
+ for i in range(10)
+ ]
+
+ assert len(aconn._prepared._names) == 1
+
+ res = [(await c.fetchone())[0] for c in cursors]
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_transaction(aconn):
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ cur = await aconn.execute("select 'tx'")
+
+ (r,) = await cur.fetchone()
+ assert r == "tx"
+
+ async with aconn.transaction():
+ cur = await aconn.execute("select 'rb'")
+ raise psycopg.Rollback()
+
+ (r,) = await cur.fetchone()
+ assert r == "rb"
+
+
+async def test_transaction_nested(aconn):
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ outer = await aconn.execute("select 'outer'")
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.transaction():
+ inner = await aconn.execute("select 'inner'")
+ 1 / 0
+
+ (r,) = await outer.fetchone()
+ assert r == "outer"
+ (r,) = await inner.fetchone()
+ assert r == "inner"
+
+
+async def test_outer_transaction(aconn):
+ async with aconn.transaction():
+ async with aconn.pipeline():
+ await aconn.execute("drop table if exists outertx")
+ await aconn.execute("create table outertx as (select 1)")
+ cur = await aconn.execute("select * from outertx")
+ (r,) = await cur.fetchone()
+ assert r == 1
+ cur = await aconn.execute(
+ "select count(*) from pg_tables where tablename = 'outertx'"
+ )
+ assert (await cur.fetchone())[0] == 1
+
+
+async def test_outer_transaction_error(aconn):
+ async with aconn.transaction():
+ with pytest.raises((UndefinedColumn, OperationalError)):
+ async with aconn.pipeline():
+ await aconn.execute("select error")
+ await aconn.execute("create table voila ()")
+
+
+async def test_concurrency(aconn):
+ async with aconn.transaction():
+ await aconn.execute("drop table if exists pipeline_concurrency")
+ await aconn.execute(
+ "create unlogged table pipeline_concurrency ("
+ " id serial primary key,"
+ " value integer"
+ ")"
+ )
+ await aconn.execute("drop table if exists accessed")
+ await aconn.execute("create unlogged table accessed as (select now() as value)")
+
+ async def update(value):
+ cur = await aconn.execute(
+ "insert into pipeline_concurrency(value) values (%s) returning id",
+ (value,),
+ )
+ await aconn.execute("update accessed set value = now()")
+ return cur
+
+ await aconn.set_autocommit(True)
+
+ (before,) = await (await aconn.execute("select value from accessed")).fetchone()
+
+ values = range(1, 10)
+ async with aconn.pipeline():
+ cursors = await asyncio.wait_for(
+ asyncio.gather(*[update(value) for value in values]),
+ timeout=len(values),
+ )
+
+ assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
+
+ (s,) = await (
+ await aconn.execute("select sum(value) from pipeline_concurrency")
+ ).fetchone()
+ assert s == sum(values)
+ (after,) = await (await aconn.execute("select value from accessed")).fetchone()
+ assert after > before