From: Denis Laxalde Date: Sun, 27 Mar 2022 10:56:40 +0000 (+0200) Subject: feat: let Connection.pipeline() return the Pipeline object X-Git-Tag: 3.1~146^2~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=862ae37db5a6b980192b9e8bcb899af7944a2e41;p=thirdparty%2Fpsycopg.git feat: let Connection.pipeline() return the Pipeline object In tests, add a type annotation on 'conn'/'aconn' fixture so that mypy understands that pipeline() yields a Pipeline object. --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 1a0b9393a..0bdfc6b02 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -868,7 +868,7 @@ class Connection(BaseConnection[Row]): yield n @contextmanager - def pipeline(self) -> Iterator[None]: + def pipeline(self) -> Iterator[Pipeline]: """Context manager to switch the connection into pipeline mode.""" with self.lock: if self._pipeline is None: @@ -881,13 +881,13 @@ class Connection(BaseConnection[Row]): if not pipeline: # No-op re-entered inner pipeline block. - yield + yield self._pipeline return try: with pipeline: try: - yield + yield pipeline finally: with self.lock: pipeline.sync() diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 2c0321bb6..592aa474f 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -296,7 +296,7 @@ class AsyncConnection(BaseConnection[Row]): yield n @asynccontextmanager - async def pipeline(self) -> AsyncIterator[None]: + async def pipeline(self) -> AsyncIterator[AsyncPipeline]: """Context manager to switch the connection into pipeline mode.""" async with self.lock: if self._pipeline is None: @@ -309,13 +309,13 @@ class AsyncConnection(BaseConnection[Row]): if not pipeline: # No-op re-entered inner pipeline block. - yield + yield self._pipeline return try: async with pipeline: try: - yield + yield pipeline finally: async with self.lock: pipeline.sync() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 2e0f8f6ce..331db905e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,3 +1,4 @@ +from typing import Any import concurrent.futures import pytest @@ -9,26 +10,24 @@ from psycopg import errors as e pytestmark = pytest.mark.libpq(">= 14") -def test_pipeline_status(conn): +def test_pipeline_status(conn: psycopg.Connection[Any]) -> None: assert conn._pipeline is None - with conn.pipeline(): - p = conn._pipeline - assert p is not None + with conn.pipeline() as p: + assert conn._pipeline is p assert p.status == pq.PipelineStatus.ON assert p.status == pq.PipelineStatus.OFF assert not conn._pipeline -def test_pipeline_reenter(conn): - with conn.pipeline(): - p = conn._pipeline - with conn.pipeline(): - assert conn._pipeline is p - assert p.status == pq.PipelineStatus.ON - assert conn._pipeline is p - assert p.status == pq.PipelineStatus.ON +def test_pipeline_reenter(conn: psycopg.Connection[Any]) -> None: + with conn.pipeline() as p1: + with conn.pipeline() as p2: + assert p2 is p1 + assert p1.status == pq.PipelineStatus.ON + assert p2 is p1 + assert p2.status == pq.PipelineStatus.ON assert conn._pipeline is None - assert p.status == pq.PipelineStatus.OFF + assert p1.status == pq.PipelineStatus.OFF def test_cursor_stream(conn): @@ -52,11 +51,11 @@ def test_cannot_insert_multiple_commands(conn): def test_pipeline_processed_at_exit(conn): with conn.cursor() as cur: - with conn.pipeline(): + with conn.pipeline() as p: cur.execute("select 1") # PQsendQuery[BEGIN], PQsendQuery - assert len(conn._pipeline.result_queue) == 2 + assert len(p.result_queue) == 2 assert cur.fetchone() == (1,) @@ -75,14 +74,14 @@ def test_pipeline_errors_processed_at_exit(conn): def test_pipeline(conn): - with conn.pipeline(): + with conn.pipeline() as p: c1 = conn.cursor() c2 = conn.cursor() c1.execute("select 1") c2.execute("select 2") # PQsendQuery[BEGIN], PQsendQuery(2) - assert len(conn._pipeline.result_queue) == 3 + assert len(p.result_queue) == 3 (r1,) = c1.fetchone() assert r1 == 1 @@ -102,14 +101,14 @@ def test_autocommit(conn): def test_pipeline_aborted(conn): conn.autocommit = True - with conn.pipeline(): + with conn.pipeline() as p: c1 = conn.execute("select 1") with pytest.raises(e.UndefinedTable): conn.execute("select * from doesnotexist").fetchone() with pytest.raises(e.OperationalError, match="pipeline aborted"): conn.execute("select 'aborted'").fetchone() # Sync restore the connection in usable state. - conn._pipeline.sync() + p.sync() c2 = conn.execute("select 2") (r,) = c1.fetchone() diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py index 1480e0592..bb78d0822 100644 --- a/tests/test_pipeline_async.py +++ b/tests/test_pipeline_async.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any import pytest @@ -12,26 +13,24 @@ pytestmark = [ ] -async def test_pipeline_status(aconn): +async def test_pipeline_status(aconn: psycopg.AsyncConnection[Any]) -> None: assert aconn._pipeline is None - async with aconn.pipeline(): - p = aconn._pipeline - assert p is not None + async with aconn.pipeline() as p: + assert aconn._pipeline is p assert p.status == pq.PipelineStatus.ON assert p.status == pq.PipelineStatus.OFF assert not aconn._pipeline -async def test_pipeline_reenter(aconn): - async with aconn.pipeline(): - p = aconn._pipeline - async with aconn.pipeline(): - assert aconn._pipeline is p - assert p.status == pq.PipelineStatus.ON - assert aconn._pipeline is p - assert p.status == pq.PipelineStatus.ON +async def test_pipeline_reenter(aconn: psycopg.AsyncConnection[Any]) -> None: + async with aconn.pipeline() as p1: + async with aconn.pipeline() as p2: + assert p2 is p1 + assert p1.status == pq.PipelineStatus.ON + assert p2 is p1 + assert p2.status == pq.PipelineStatus.ON assert aconn._pipeline is None - assert p.status == pq.PipelineStatus.OFF + assert p1.status == pq.PipelineStatus.OFF async def test_cursor_stream(aconn): @@ -55,11 +54,11 @@ async def test_cannot_insert_multiple_commands(aconn): async def test_pipeline_processed_at_exit(aconn): async with aconn.cursor() as cur: - async with aconn.pipeline(): + async with aconn.pipeline() as p: await cur.execute("select 1") # PQsendQuery[BEGIN], PQsendQuery - assert len(aconn._pipeline.result_queue) == 2 + assert len(p.result_queue) == 2 assert await cur.fetchone() == (1,) @@ -78,14 +77,14 @@ async def test_pipeline_errors_processed_at_exit(aconn): async def test_pipeline(aconn): - async with aconn.pipeline(): + async with aconn.pipeline() as p: c1 = aconn.cursor() c2 = aconn.cursor() await c1.execute("select 1") await c2.execute("select 2") # PQsendQuery[BEGIN], PQsendQuery(2) - assert len(aconn._pipeline.result_queue) == 3 + assert len(p.result_queue) == 3 (r1,) = await c1.fetchone() assert r1 == 1 @@ -105,14 +104,14 @@ async def test_autocommit(aconn): async def test_pipeline_aborted(aconn): await aconn.set_autocommit(True) - async with aconn.pipeline(): + async with aconn.pipeline() as p: c1 = await aconn.execute("select 1") with pytest.raises(e.UndefinedTable): await (await aconn.execute("select * from doesnotexist")).fetchone() with pytest.raises(e.OperationalError, match="pipeline aborted"): await (await aconn.execute("select 'aborted'")).fetchone() # Sync restore the connection in usable state. - aconn._pipeline.sync() + p.sync() c2 = await aconn.execute("select 2") (r,) = await c1.fetchone()