From: Denis Laxalde Date: Wed, 8 Dec 2021 10:39:44 +0000 (+0100) Subject: Add Connection.pipeline() context manager and Pipeline object X-Git-Tag: 3.1~146^2~20 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=54b91ab2067f77972a4c663defbed17abf29e839;p=thirdparty%2Fpsycopg.git Add Connection.pipeline() context manager and Pipeline object On the connection, we store the pipeline object as _pipeline attribute when in pipeline mode or None otherwise. --- diff --git a/docs/api/connections.rst b/docs/api/connections.rst index 5f3230068..666464473 100644 --- a/docs/api/connections.rst +++ b/docs/api/connections.rst @@ -266,6 +266,8 @@ The `!Connection` class .. automethod:: fileno + .. automethod:: pipeline + .. _tpc-methods: @@ -447,6 +449,15 @@ The `!AsyncConnection` class .. automethod:: set_read_only .. automethod:: set_deferrable + .. automethod:: pipeline + + .. note:: + + It must be called as:: + + async with conn.pipeline(): + ... + .. automethod:: tpc_prepare .. automethod:: tpc_commit .. automethod:: tpc_rollback diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py index 6f4068a0b..4879bcdf6 100644 --- a/psycopg/psycopg/__init__.py +++ b/psycopg/psycopg/__init__.py @@ -18,11 +18,11 @@ from .errors import DataError, OperationalError, IntegrityError from .errors import InternalError, ProgrammingError, NotSupportedError from ._column import Column from .conninfo import ConnectionInfo -from .connection import BaseConnection, Connection, Notify +from .connection import BaseConnection, Connection, Notify, Pipeline from .transaction import Rollback, Transaction, AsyncTransaction from .cursor_async import AsyncCursor from .server_cursor import AsyncServerCursor, ServerCursor -from .connection_async import AsyncConnection +from .connection_async import AsyncConnection, AsyncPipeline from . import dbapi20 from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING @@ -59,6 +59,7 @@ __all__ = [ "AsyncConnection", "AsyncCopy", "AsyncCursor", + "AsyncPipeline", "AsyncServerCursor", "AsyncTransaction", "BaseConnection", @@ -69,6 +70,7 @@ __all__ = [ "Cursor", "IsolationLevel", "Notify", + "Pipeline", "Rollback", "ServerCursor", "Transaction", diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index a7abef9aa..2fa4ecf0a 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -80,6 +80,37 @@ NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None] NotifyHandler: TypeAlias = Callable[[Notify], None] +class BasePipeline: + def __init__(self, pgconn: "PGconn") -> None: + self.pgconn = pgconn + + @property + def status(self) -> pq.PipelineStatus: + return pq.PipelineStatus(self.pgconn.pipeline_status) + + def _enter(self) -> None: + self.pgconn.enter_pipeline_mode() + + def _exit(self) -> None: + self.pgconn.exit_pipeline_mode() + + +class Pipeline(BasePipeline): + """Handler for connection in pipeline mode.""" + + def __enter__(self) -> "Pipeline": + self._enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self._exit() + + class BaseConnection(Generic[Row]): """ Base class for different types of connections. @@ -126,6 +157,8 @@ class BaseConnection(Generic[Row]): # apart a connection in the pool too (when _pool = None) self._pool: Optional["BasePool[Any]"] + self._pipeline: Optional[BasePipeline] = None + # Time after which the connection should be closed self._expire_at: float @@ -603,6 +636,8 @@ class Connection(BaseConnection[Row]): server_cursor_factory: Type[ServerCursor[Row]] row_factory: RowFactory[Row] + _pipeline: Optional[Pipeline] + def __init__(self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None): super().__init__(pgconn) self.row_factory = row_factory or cast(RowFactory[Row], tuple_row) @@ -851,6 +886,20 @@ class Connection(BaseConnection[Row]): n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) yield n + @contextmanager + def pipeline(self) -> Iterator[None]: + """Context manager to switch the connection into pipeline mode.""" + if self._pipeline is not None: + raise e.ProgrammingError("already in pipeline mode") + + pipeline = self._pipeline = Pipeline(self.pgconn) + try: + with pipeline: + yield + finally: + assert pipeline.status == pq.PipelineStatus.OFF, pipeline.status + self._pipeline = None + def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: """ Consume a generator operating on the connection. diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index e34477865..daa4f157c 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -14,7 +14,7 @@ from contextlib import asynccontextmanager from . import errors as e from . import waiting -from .pq import Format, TransactionStatus +from .pq import Format, PipelineStatus, TransactionStatus from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row @@ -22,7 +22,7 @@ from .adapt import AdaptersMap from ._enums import IsolationLevel from .conninfo import make_conninfo, conninfo_to_dict from ._encodings import pgconn_encoding -from .connection import BaseConnection, CursorRow, Notify +from .connection import BaseConnection, BasePipeline, CursorRow, Notify from .generators import notifies from .transaction import AsyncTransaction from .cursor_async import AsyncCursor @@ -35,6 +35,22 @@ if TYPE_CHECKING: logger = logging.getLogger("psycopg") +class AsyncPipeline(BasePipeline): + """Handler for async connection in pipeline mode.""" + + async def __aenter__(self) -> "AsyncPipeline": + self._enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self._exit() + + class AsyncConnection(BaseConnection[Row]): """ Asynchronous wrapper for a connection to the database. @@ -46,6 +62,8 @@ class AsyncConnection(BaseConnection[Row]): server_cursor_factory: Type[AsyncServerCursor[Row]] row_factory: AsyncRowFactory[Row] + _pipeline: Optional[AsyncPipeline] + def __init__( self, pgconn: "PGconn", @@ -292,6 +310,20 @@ class AsyncConnection(BaseConnection[Row]): n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) yield n + @asynccontextmanager + async def pipeline(self) -> AsyncIterator[None]: + """Context manager to switch the connection into pipeline mode.""" + if self._pipeline is not None: + raise e.ProgrammingError("already in pipeline mode") + + pipeline = self._pipeline = AsyncPipeline(self.pgconn) + try: + async with pipeline: + yield + finally: + assert pipeline.status == PipelineStatus.OFF, pipeline.status + self._pipeline = None + async def wait(self, gen: PQGen[RV]) -> RV: try: return await waiting.wait_async(gen, self.pgconn.socket) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 000000000..0e0aa0d56 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,18 @@ +import pytest + +from psycopg import pq +from psycopg.errors import ProgrammingError + +pytestmark = pytest.mark.libpq(">= 14") + + +def test_pipeline_status(conn): + with conn.pipeline(): + p = conn._pipeline + assert p is not None + assert p.status == pq.PipelineStatus.ON + with pytest.raises(ProgrammingError): + with conn.pipeline(): + pass + assert p.status == pq.PipelineStatus.OFF + assert not conn._pipeline diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py new file mode 100644 index 000000000..1b0843fe8 --- /dev/null +++ b/tests/test_pipeline_async.py @@ -0,0 +1,21 @@ +import pytest + +from psycopg import pq +from psycopg.errors import ProgrammingError + +pytestmark = [ + pytest.mark.libpq(">= 14"), + pytest.mark.asyncio, +] + + +async def test_pipeline_status(aconn): + async with aconn.pipeline(): + p = aconn._pipeline + assert p is not None + assert p.status == pq.PipelineStatus.ON + with pytest.raises(ProgrammingError): + async with aconn.pipeline(): + pass + assert p.status == pq.PipelineStatus.OFF + assert not aconn._pipeline