]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add Connection.pipeline() context manager and Pipeline object
authorDenis Laxalde <denis.laxalde@dalibo.com>
Wed, 8 Dec 2021 10:39:44 +0000 (11:39 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:17:57 +0000 (01:17 +0200)
On the connection, we store the pipeline object as _pipeline attribute
when in pipeline mode or None otherwise.

docs/api/connections.rst
psycopg/psycopg/__init__.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_pipeline.py [new file with mode: 0644]
tests/test_pipeline_async.py [new file with mode: 0644]

index 5f3230068dbd38eb419888c59772724df95a5597..666464473db90afa6061bc0f16f35bba84cfb2cd 100644 (file)
@@ -266,6 +266,8 @@ The `!Connection` class
 
     .. automethod:: fileno
 
+    .. automethod:: pipeline
+
 
     .. _tpc-methods:
 
@@ -447,6 +449,15 @@ The `!AsyncConnection` class
     .. automethod:: set_read_only
     .. automethod:: set_deferrable
 
+    .. automethod:: pipeline
+
+        .. note::
+
+            It must be called as::
+
+                async with conn.pipeline():
+                    ...
+
     .. automethod:: tpc_prepare
     .. automethod:: tpc_commit
     .. automethod:: tpc_rollback
index 6f4068a0b6dc5a301336e6aee63032d529c9f97b..4879bcdf6cfbbb760aef4750794d6f0fab8604ee 100644 (file)
@@ -18,11 +18,11 @@ from .errors import DataError, OperationalError, IntegrityError
 from .errors import InternalError, ProgrammingError, NotSupportedError
 from ._column import Column
 from .conninfo import ConnectionInfo
-from .connection import BaseConnection, Connection, Notify
+from .connection import BaseConnection, Connection, Notify, Pipeline
 from .transaction import Rollback, Transaction, AsyncTransaction
 from .cursor_async import AsyncCursor
 from .server_cursor import AsyncServerCursor, ServerCursor
-from .connection_async import AsyncConnection
+from .connection_async import AsyncConnection, AsyncPipeline
 
 from . import dbapi20
 from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
@@ -59,6 +59,7 @@ __all__ = [
     "AsyncConnection",
     "AsyncCopy",
     "AsyncCursor",
+    "AsyncPipeline",
     "AsyncServerCursor",
     "AsyncTransaction",
     "BaseConnection",
@@ -69,6 +70,7 @@ __all__ = [
     "Cursor",
     "IsolationLevel",
     "Notify",
+    "Pipeline",
     "Rollback",
     "ServerCursor",
     "Transaction",
index a7abef9aa226af64ba6f85617f96aa4cea411218..2fa4ecf0ac52feaf51edb900f6719a2cfd631c19 100644 (file)
@@ -80,6 +80,37 @@ NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
 NotifyHandler: TypeAlias = Callable[[Notify], None]
 
 
+class BasePipeline:
+    def __init__(self, pgconn: "PGconn") -> None:
+        self.pgconn = pgconn
+
+    @property
+    def status(self) -> pq.PipelineStatus:
+        return pq.PipelineStatus(self.pgconn.pipeline_status)
+
+    def _enter(self) -> None:
+        self.pgconn.enter_pipeline_mode()
+
+    def _exit(self) -> None:
+        self.pgconn.exit_pipeline_mode()
+
+
+class Pipeline(BasePipeline):
+    """Handler for connection in pipeline mode."""
+
+    def __enter__(self) -> "Pipeline":
+        self._enter()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self._exit()
+
+
 class BaseConnection(Generic[Row]):
     """
     Base class for different types of connections.
@@ -126,6 +157,8 @@ class BaseConnection(Generic[Row]):
         # apart a connection in the pool too (when _pool = None)
         self._pool: Optional["BasePool[Any]"]
 
+        self._pipeline: Optional[BasePipeline] = None
+
         # Time after which the connection should be closed
         self._expire_at: float
 
@@ -603,6 +636,8 @@ class Connection(BaseConnection[Row]):
     server_cursor_factory: Type[ServerCursor[Row]]
     row_factory: RowFactory[Row]
 
+    _pipeline: Optional[Pipeline]
+
     def __init__(self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None):
         super().__init__(pgconn)
         self.row_factory = row_factory or cast(RowFactory[Row], tuple_row)
@@ -851,6 +886,20 @@ class Connection(BaseConnection[Row]):
                 n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
 
+    @contextmanager
+    def pipeline(self) -> Iterator[None]:
+        """Context manager to switch the connection into pipeline mode."""
+        if self._pipeline is not None:
+            raise e.ProgrammingError("already in pipeline mode")
+
+        pipeline = self._pipeline = Pipeline(self.pgconn)
+        try:
+            with pipeline:
+                yield
+        finally:
+            assert pipeline.status == pq.PipelineStatus.OFF, pipeline.status
+            self._pipeline = None
+
     def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
         """
         Consume a generator operating on the connection.
index e344778654c2a671f63a036dd90c3e11d0b600b3..daa4f157cdac29002421d2e00eefb5a8941c52d3 100644 (file)
@@ -14,7 +14,7 @@ from contextlib import asynccontextmanager
 
 from . import errors as e
 from . import waiting
-from .pq import Format, TransactionStatus
+from .pq import Format, PipelineStatus, TransactionStatus
 from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
@@ -22,7 +22,7 @@ from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .conninfo import make_conninfo, conninfo_to_dict
 from ._encodings import pgconn_encoding
-from .connection import BaseConnection, CursorRow, Notify
+from .connection import BaseConnection, BasePipeline, CursorRow, Notify
 from .generators import notifies
 from .transaction import AsyncTransaction
 from .cursor_async import AsyncCursor
@@ -35,6 +35,22 @@ if TYPE_CHECKING:
 logger = logging.getLogger("psycopg")
 
 
+class AsyncPipeline(BasePipeline):
+    """Handler for async connection in pipeline mode."""
+
+    async def __aenter__(self) -> "AsyncPipeline":
+        self._enter()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self._exit()
+
+
 class AsyncConnection(BaseConnection[Row]):
     """
     Asynchronous wrapper for a connection to the database.
@@ -46,6 +62,8 @@ class AsyncConnection(BaseConnection[Row]):
     server_cursor_factory: Type[AsyncServerCursor[Row]]
     row_factory: AsyncRowFactory[Row]
 
+    _pipeline: Optional[AsyncPipeline]
+
     def __init__(
         self,
         pgconn: "PGconn",
@@ -292,6 +310,20 @@ class AsyncConnection(BaseConnection[Row]):
                 n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
 
+    @asynccontextmanager
+    async def pipeline(self) -> AsyncIterator[None]:
+        """Context manager to switch the connection into pipeline mode."""
+        if self._pipeline is not None:
+            raise e.ProgrammingError("already in pipeline mode")
+
+        pipeline = self._pipeline = AsyncPipeline(self.pgconn)
+        try:
+            async with pipeline:
+                yield
+        finally:
+            assert pipeline.status == PipelineStatus.OFF, pipeline.status
+            self._pipeline = None
+
     async def wait(self, gen: PQGen[RV]) -> RV:
         try:
             return await waiting.wait_async(gen, self.pgconn.socket)
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644 (file)
index 0000000..0e0aa0d
--- /dev/null
@@ -0,0 +1,18 @@
+import pytest
+
+from psycopg import pq
+from psycopg.errors import ProgrammingError
+
+pytestmark = pytest.mark.libpq(">= 14")
+
+
+def test_pipeline_status(conn):
+    with conn.pipeline():
+        p = conn._pipeline
+        assert p is not None
+        assert p.status == pq.PipelineStatus.ON
+        with pytest.raises(ProgrammingError):
+            with conn.pipeline():
+                pass
+    assert p.status == pq.PipelineStatus.OFF
+    assert not conn._pipeline
diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py
new file mode 100644 (file)
index 0000000..1b0843f
--- /dev/null
@@ -0,0 +1,21 @@
+import pytest
+
+from psycopg import pq
+from psycopg.errors import ProgrammingError
+
+pytestmark = [
+    pytest.mark.libpq(">= 14"),
+    pytest.mark.asyncio,
+]
+
+
+async def test_pipeline_status(aconn):
+    async with aconn.pipeline():
+        p = aconn._pipeline
+        assert p is not None
+        assert p.status == pq.PipelineStatus.ON
+        with pytest.raises(ProgrammingError):
+            async with aconn.pipeline():
+                pass
+    assert p.status == pq.PipelineStatus.OFF
+    assert not aconn._pipeline