]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: allow re-entering pipeline mode
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 26 Mar 2022 17:56:31 +0000 (18:56 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:17:57 +0000 (01:17 +0200)
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index dc440967b18dc2bd38766db3299c2820a17601df..6c09c17223a4b8199ad5fa5e28cbc7b79008fead 100644 (file)
@@ -871,7 +871,9 @@ class Connection(BaseConnection[Row]):
     def pipeline(self) -> Iterator[None]:
         """Context manager to switch the connection into pipeline mode."""
         if self._pipeline is not None:
-            raise e.ProgrammingError("already in pipeline mode")
+            # calling pipeline recursively is no-op.
+            yield
+            return
 
         pipeline = self._pipeline = Pipeline(self.pgconn)
         try:
index a5129148384916adfdff17da79a7e90e2c7911bc..52709d67024e7d65f40fcccb0cf4b8c4f68f7998 100644 (file)
@@ -299,7 +299,9 @@ class AsyncConnection(BaseConnection[Row]):
     async def pipeline(self) -> AsyncIterator[None]:
         """Context manager to switch the connection into pipeline mode."""
         if self._pipeline is not None:
-            raise e.ProgrammingError("already in pipeline mode")
+            # calling pipeline recursively is no-op.
+            yield
+            return
 
         pipeline = self._pipeline = AsyncPipeline(self.pgconn)
         try:
index 99f4046634f28a1980b2a1f0f6bead8549b2cee7..2e0f8f6ce2881833440bf62e7686f251fad3ed2f 100644 (file)
@@ -4,28 +4,33 @@ import pytest
 
 import psycopg
 from psycopg import pq
-from psycopg.errors import (
-    OperationalError,
-    ProgrammingError,
-    UndefinedColumn,
-    UndefinedTable,
-)
+from psycopg import errors as e
 
 pytestmark = pytest.mark.libpq(">= 14")
 
 
 def test_pipeline_status(conn):
+    assert conn._pipeline is None
     with conn.pipeline():
         p = conn._pipeline
         assert p is not None
         assert p.status == pq.PipelineStatus.ON
-        with pytest.raises(ProgrammingError):
-            with conn.pipeline():
-                pass
     assert p.status == pq.PipelineStatus.OFF
     assert not conn._pipeline
 
 
+def test_pipeline_reenter(conn):
+    with conn.pipeline():
+        p = conn._pipeline
+        with conn.pipeline():
+            assert conn._pipeline is p
+            assert p.status == pq.PipelineStatus.ON
+        assert conn._pipeline is p
+        assert p.status == pq.PipelineStatus.ON
+    assert conn._pipeline is None
+    assert p.status == pq.PipelineStatus.OFF
+
+
 def test_cursor_stream(conn):
     with conn.pipeline(), conn.cursor() as cur:
         with pytest.raises(psycopg.ProgrammingError):
@@ -58,7 +63,7 @@ def test_pipeline_processed_at_exit(conn):
 
 def test_pipeline_errors_processed_at_exit(conn):
     conn.autocommit = True
-    with pytest.raises((OperationalError, UndefinedTable)):
+    with pytest.raises((e.OperationalError, e.UndefinedTable)):
         with conn.pipeline():
             conn.execute("select * from nosuchtable")
             conn.execute("create table voila ()")
@@ -99,9 +104,9 @@ def test_pipeline_aborted(conn):
     conn.autocommit = True
     with conn.pipeline():
         c1 = conn.execute("select 1")
-        with pytest.raises(UndefinedTable):
+        with pytest.raises(e.UndefinedTable):
             conn.execute("select * from doesnotexist").fetchone()
-        with pytest.raises(OperationalError, match="pipeline aborted"):
+        with pytest.raises(e.OperationalError, match="pipeline aborted"):
             conn.execute("select 'aborted'").fetchone()
         # Sync restore the connection in usable state.
         conn._pipeline.sync()
@@ -115,7 +120,7 @@ def test_pipeline_aborted(conn):
 
 
 def test_pipeline_commit_aborted(conn):
-    with pytest.raises((UndefinedColumn, OperationalError)):
+    with pytest.raises((e.UndefinedColumn, e.OperationalError)):
         with conn.pipeline():
             conn.execute("select error")
             conn.execute("create table voila ()")
@@ -213,7 +218,7 @@ def test_outer_transaction(conn):
 
 def test_outer_transaction_error(conn):
     with conn.transaction():
-        with pytest.raises((UndefinedColumn, OperationalError)):
+        with pytest.raises((e.UndefinedColumn, e.OperationalError)):
             with conn.pipeline():
                 conn.execute("select error")
                 conn.execute("create table voila ()")
index 5c3dedc9490bf36ea8421273f428f629880c83b0..1480e0592d73c5d68b3ac7b0f04b693f43c62d8f 100644 (file)
@@ -4,12 +4,7 @@ import pytest
 
 import psycopg
 from psycopg import pq
-from psycopg.errors import (
-    OperationalError,
-    ProgrammingError,
-    UndefinedColumn,
-    UndefinedTable,
-)
+from psycopg import errors as e
 
 pytestmark = [
     pytest.mark.libpq(">= 14"),
@@ -18,17 +13,27 @@ pytestmark = [
 
 
 async def test_pipeline_status(aconn):
+    assert aconn._pipeline is None
     async with aconn.pipeline():
         p = aconn._pipeline
         assert p is not None
         assert p.status == pq.PipelineStatus.ON
-        with pytest.raises(ProgrammingError):
-            async with aconn.pipeline():
-                pass
     assert p.status == pq.PipelineStatus.OFF
     assert not aconn._pipeline
 
 
+async def test_pipeline_reenter(aconn):
+    async with aconn.pipeline():
+        p = aconn._pipeline
+        async with aconn.pipeline():
+            assert aconn._pipeline is p
+            assert p.status == pq.PipelineStatus.ON
+        assert aconn._pipeline is p
+        assert p.status == pq.PipelineStatus.ON
+    assert aconn._pipeline is None
+    assert p.status == pq.PipelineStatus.OFF
+
+
 async def test_cursor_stream(aconn):
     async with aconn.pipeline(), aconn.cursor() as cur:
         with pytest.raises(psycopg.ProgrammingError):
@@ -61,7 +66,7 @@ async def test_pipeline_processed_at_exit(aconn):
 
 async def test_pipeline_errors_processed_at_exit(aconn):
     await aconn.set_autocommit(True)
-    with pytest.raises((OperationalError, UndefinedTable)):
+    with pytest.raises((e.OperationalError, e.UndefinedTable)):
         async with aconn.pipeline():
             await aconn.execute("select * from nosuchtable")
             await aconn.execute("create table voila ()")
@@ -102,9 +107,9 @@ async def test_pipeline_aborted(aconn):
     await aconn.set_autocommit(True)
     async with aconn.pipeline():
         c1 = await aconn.execute("select 1")
-        with pytest.raises(UndefinedTable):
+        with pytest.raises(e.UndefinedTable):
             await (await aconn.execute("select * from doesnotexist")).fetchone()
-        with pytest.raises(OperationalError, match="pipeline aborted"):
+        with pytest.raises(e.OperationalError, match="pipeline aborted"):
             await (await aconn.execute("select 'aborted'")).fetchone()
         # Sync restore the connection in usable state.
         aconn._pipeline.sync()
@@ -118,7 +123,7 @@ async def test_pipeline_aborted(aconn):
 
 
 async def test_pipeline_commit_aborted(aconn):
-    with pytest.raises((UndefinedColumn, OperationalError)):
+    with pytest.raises((e.UndefinedColumn, e.OperationalError)):
         async with aconn.pipeline():
             await aconn.execute("select error")
             await aconn.execute("create table voila ()")
@@ -217,7 +222,7 @@ async def test_outer_transaction(aconn):
 
 async def test_outer_transaction_error(aconn):
     async with aconn.transaction():
-        with pytest.raises((UndefinedColumn, OperationalError)):
+        with pytest.raises((e.UndefinedColumn, e.OperationalError)):
             async with aconn.pipeline():
                 await aconn.execute("select error")
                 await aconn.execute("create table voila ()")