From: Daniele Varrazzo Date: Sat, 26 Mar 2022 17:56:31 +0000 (+0100) Subject: fix: allow re-entering pipeline mode X-Git-Tag: 3.1~146^2~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=384d831874f79f2544a878eb35184adecb7be47e;p=thirdparty%2Fpsycopg.git fix: allow re-entering pipeline mode --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index dc440967b..6c09c1722 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -871,7 +871,9 @@ class Connection(BaseConnection[Row]): def pipeline(self) -> Iterator[None]: """Context manager to switch the connection into pipeline mode.""" if self._pipeline is not None: - raise e.ProgrammingError("already in pipeline mode") + # calling pipeline recursively is no-op. + yield + return pipeline = self._pipeline = Pipeline(self.pgconn) try: diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index a51291483..52709d670 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -299,7 +299,9 @@ class AsyncConnection(BaseConnection[Row]): async def pipeline(self) -> AsyncIterator[None]: """Context manager to switch the connection into pipeline mode.""" if self._pipeline is not None: - raise e.ProgrammingError("already in pipeline mode") + # calling pipeline recursively is no-op. + yield + return pipeline = self._pipeline = AsyncPipeline(self.pgconn) try: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 99f404663..2e0f8f6ce 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -4,28 +4,33 @@ import pytest import psycopg from psycopg import pq -from psycopg.errors import ( - OperationalError, - ProgrammingError, - UndefinedColumn, - UndefinedTable, -) +from psycopg import errors as e pytestmark = pytest.mark.libpq(">= 14") def test_pipeline_status(conn): + assert conn._pipeline is None with conn.pipeline(): p = conn._pipeline assert p is not None assert p.status == pq.PipelineStatus.ON - with pytest.raises(ProgrammingError): - with conn.pipeline(): - pass assert p.status == pq.PipelineStatus.OFF assert not conn._pipeline +def test_pipeline_reenter(conn): + with conn.pipeline(): + p = conn._pipeline + with conn.pipeline(): + assert conn._pipeline is p + assert p.status == pq.PipelineStatus.ON + assert conn._pipeline is p + assert p.status == pq.PipelineStatus.ON + assert conn._pipeline is None + assert p.status == pq.PipelineStatus.OFF + + def test_cursor_stream(conn): with conn.pipeline(), conn.cursor() as cur: with pytest.raises(psycopg.ProgrammingError): @@ -58,7 +63,7 @@ def test_pipeline_processed_at_exit(conn): def test_pipeline_errors_processed_at_exit(conn): conn.autocommit = True - with pytest.raises((OperationalError, UndefinedTable)): + with pytest.raises((e.OperationalError, e.UndefinedTable)): with conn.pipeline(): conn.execute("select * from nosuchtable") conn.execute("create table voila ()") @@ -99,9 +104,9 @@ def test_pipeline_aborted(conn): conn.autocommit = True with conn.pipeline(): c1 = conn.execute("select 1") - with pytest.raises(UndefinedTable): + with pytest.raises(e.UndefinedTable): conn.execute("select * from doesnotexist").fetchone() - with pytest.raises(OperationalError, match="pipeline aborted"): + with pytest.raises(e.OperationalError, match="pipeline aborted"): conn.execute("select 'aborted'").fetchone() # Sync restore the connection in usable state. conn._pipeline.sync() @@ -115,7 +120,7 @@ def test_pipeline_aborted(conn): def test_pipeline_commit_aborted(conn): - with pytest.raises((UndefinedColumn, OperationalError)): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): with conn.pipeline(): conn.execute("select error") conn.execute("create table voila ()") @@ -213,7 +218,7 @@ def test_outer_transaction(conn): def test_outer_transaction_error(conn): with conn.transaction(): - with pytest.raises((UndefinedColumn, OperationalError)): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): with conn.pipeline(): conn.execute("select error") conn.execute("create table voila ()") diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py index 5c3dedc94..1480e0592 100644 --- a/tests/test_pipeline_async.py +++ b/tests/test_pipeline_async.py @@ -4,12 +4,7 @@ import pytest import psycopg from psycopg import pq -from psycopg.errors import ( - OperationalError, - ProgrammingError, - UndefinedColumn, - UndefinedTable, -) +from psycopg import errors as e pytestmark = [ pytest.mark.libpq(">= 14"), @@ -18,17 +13,27 @@ pytestmark = [ async def test_pipeline_status(aconn): + assert aconn._pipeline is None async with aconn.pipeline(): p = aconn._pipeline assert p is not None assert p.status == pq.PipelineStatus.ON - with pytest.raises(ProgrammingError): - async with aconn.pipeline(): - pass assert p.status == pq.PipelineStatus.OFF assert not aconn._pipeline +async def test_pipeline_reenter(aconn): + async with aconn.pipeline(): + p = aconn._pipeline + async with aconn.pipeline(): + assert aconn._pipeline is p + assert p.status == pq.PipelineStatus.ON + assert aconn._pipeline is p + assert p.status == pq.PipelineStatus.ON + assert aconn._pipeline is None + assert p.status == pq.PipelineStatus.OFF + + async def test_cursor_stream(aconn): async with aconn.pipeline(), aconn.cursor() as cur: with pytest.raises(psycopg.ProgrammingError): @@ -61,7 +66,7 @@ async def test_pipeline_processed_at_exit(aconn): async def test_pipeline_errors_processed_at_exit(aconn): await aconn.set_autocommit(True) - with pytest.raises((OperationalError, UndefinedTable)): + with pytest.raises((e.OperationalError, e.UndefinedTable)): async with aconn.pipeline(): await aconn.execute("select * from nosuchtable") await aconn.execute("create table voila ()") @@ -102,9 +107,9 @@ async def test_pipeline_aborted(aconn): await aconn.set_autocommit(True) async with aconn.pipeline(): c1 = await aconn.execute("select 1") - with pytest.raises(UndefinedTable): + with pytest.raises(e.UndefinedTable): await (await aconn.execute("select * from doesnotexist")).fetchone() - with pytest.raises(OperationalError, match="pipeline aborted"): + with pytest.raises(e.OperationalError, match="pipeline aborted"): await (await aconn.execute("select 'aborted'")).fetchone() # Sync restore the connection in usable state. aconn._pipeline.sync() @@ -118,7 +123,7 @@ async def test_pipeline_aborted(aconn): async def test_pipeline_commit_aborted(aconn): - with pytest.raises((UndefinedColumn, OperationalError)): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): async with aconn.pipeline(): await aconn.execute("select error") await aconn.execute("create table voila ()") @@ -217,7 +222,7 @@ async def test_outer_transaction(aconn): async def test_outer_transaction_error(aconn): async with aconn.transaction(): - with pytest.raises((UndefinedColumn, OperationalError)): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): async with aconn.pipeline(): await aconn.execute("select error") await aconn.execute("create table voila ()")