]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: let Connection.pipeline() return the Pipeline object
authorDenis Laxalde <denis@laxalde.org>
Sun, 27 Mar 2022 10:56:40 +0000 (12:56 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:17:57 +0000 (01:17 +0200)
In tests, add a type annotation on 'conn'/'aconn' fixture so that mypy
understands that pipeline() yields a Pipeline object.

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 1a0b9393aa57490413766d42e87513537b941b9b..0bdfc6b026c1ccd841a9db7f44240ee59266a735 100644 (file)
@@ -868,7 +868,7 @@ class Connection(BaseConnection[Row]):
                 yield n
 
     @contextmanager
-    def pipeline(self) -> Iterator[None]:
+    def pipeline(self) -> Iterator[Pipeline]:
         """Context manager to switch the connection into pipeline mode."""
         with self.lock:
             if self._pipeline is None:
@@ -881,13 +881,13 @@ class Connection(BaseConnection[Row]):
 
         if not pipeline:
             # No-op re-entered inner pipeline block.
-            yield
+            yield self._pipeline
             return
 
         try:
             with pipeline:
                 try:
-                    yield
+                    yield pipeline
                 finally:
                     with self.lock:
                         pipeline.sync()
index 2c0321bb67a49bc8b43f29e98cee83f49693dc73..592aa474fe5345e7a4e60a0b4b7f3fd3fa0e2f20 100644 (file)
@@ -296,7 +296,7 @@ class AsyncConnection(BaseConnection[Row]):
                 yield n
 
     @asynccontextmanager
-    async def pipeline(self) -> AsyncIterator[None]:
+    async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
         """Context manager to switch the connection into pipeline mode."""
         async with self.lock:
             if self._pipeline is None:
@@ -309,13 +309,13 @@ class AsyncConnection(BaseConnection[Row]):
 
         if not pipeline:
             # No-op re-entered inner pipeline block.
-            yield
+            yield self._pipeline
             return
 
         try:
             async with pipeline:
                 try:
-                    yield
+                    yield pipeline
                 finally:
                     async with self.lock:
                         pipeline.sync()
index 2e0f8f6ce2881833440bf62e7686f251fad3ed2f..331db905e75478aa5bb472e8d38174805217cd25 100644 (file)
@@ -1,3 +1,4 @@
+from typing import Any
 import concurrent.futures
 
 import pytest
@@ -9,26 +10,24 @@ from psycopg import errors as e
 pytestmark = pytest.mark.libpq(">= 14")
 
 
-def test_pipeline_status(conn):
+def test_pipeline_status(conn: psycopg.Connection[Any]) -> None:
     assert conn._pipeline is None
-    with conn.pipeline():
-        p = conn._pipeline
-        assert p is not None
+    with conn.pipeline() as p:
+        assert conn._pipeline is p
         assert p.status == pq.PipelineStatus.ON
     assert p.status == pq.PipelineStatus.OFF
     assert not conn._pipeline
 
 
-def test_pipeline_reenter(conn):
-    with conn.pipeline():
-        p = conn._pipeline
-        with conn.pipeline():
-            assert conn._pipeline is p
-            assert p.status == pq.PipelineStatus.ON
-        assert conn._pipeline is p
-        assert p.status == pq.PipelineStatus.ON
+def test_pipeline_reenter(conn: psycopg.Connection[Any]) -> None:
+    with conn.pipeline() as p1:
+        with conn.pipeline() as p2:
+            assert p2 is p1
+            assert p1.status == pq.PipelineStatus.ON
+        assert p2 is p1
+        assert p2.status == pq.PipelineStatus.ON
     assert conn._pipeline is None
-    assert p.status == pq.PipelineStatus.OFF
+    assert p1.status == pq.PipelineStatus.OFF
 
 
 def test_cursor_stream(conn):
@@ -52,11 +51,11 @@ def test_cannot_insert_multiple_commands(conn):
 
 def test_pipeline_processed_at_exit(conn):
     with conn.cursor() as cur:
-        with conn.pipeline():
+        with conn.pipeline() as p:
             cur.execute("select 1")
 
             # PQsendQuery[BEGIN], PQsendQuery
-            assert len(conn._pipeline.result_queue) == 2
+            assert len(p.result_queue) == 2
 
         assert cur.fetchone() == (1,)
 
@@ -75,14 +74,14 @@ def test_pipeline_errors_processed_at_exit(conn):
 
 
 def test_pipeline(conn):
-    with conn.pipeline():
+    with conn.pipeline() as p:
         c1 = conn.cursor()
         c2 = conn.cursor()
         c1.execute("select 1")
         c2.execute("select 2")
 
         # PQsendQuery[BEGIN], PQsendQuery(2)
-        assert len(conn._pipeline.result_queue) == 3
+        assert len(p.result_queue) == 3
 
         (r1,) = c1.fetchone()
         assert r1 == 1
@@ -102,14 +101,14 @@ def test_autocommit(conn):
 
 def test_pipeline_aborted(conn):
     conn.autocommit = True
-    with conn.pipeline():
+    with conn.pipeline() as p:
         c1 = conn.execute("select 1")
         with pytest.raises(e.UndefinedTable):
             conn.execute("select * from doesnotexist").fetchone()
         with pytest.raises(e.OperationalError, match="pipeline aborted"):
             conn.execute("select 'aborted'").fetchone()
         # Sync restore the connection in usable state.
-        conn._pipeline.sync()
+        p.sync()
         c2 = conn.execute("select 2")
 
     (r,) = c1.fetchone()
index 1480e0592d73c5d68b3ac7b0f04b693f43c62d8f..bb78d08225f01ffc965c8ea410297239fabf3f96 100644 (file)
@@ -1,4 +1,5 @@
 import asyncio
+from typing import Any
 
 import pytest
 
@@ -12,26 +13,24 @@ pytestmark = [
 ]
 
 
-async def test_pipeline_status(aconn):
+async def test_pipeline_status(aconn: psycopg.AsyncConnection[Any]) -> None:
     assert aconn._pipeline is None
-    async with aconn.pipeline():
-        p = aconn._pipeline
-        assert p is not None
+    async with aconn.pipeline() as p:
+        assert aconn._pipeline is p
         assert p.status == pq.PipelineStatus.ON
     assert p.status == pq.PipelineStatus.OFF
     assert not aconn._pipeline
 
 
-async def test_pipeline_reenter(aconn):
-    async with aconn.pipeline():
-        p = aconn._pipeline
-        async with aconn.pipeline():
-            assert aconn._pipeline is p
-            assert p.status == pq.PipelineStatus.ON
-        assert aconn._pipeline is p
-        assert p.status == pq.PipelineStatus.ON
+async def test_pipeline_reenter(aconn: psycopg.AsyncConnection[Any]) -> None:
+    async with aconn.pipeline() as p1:
+        async with aconn.pipeline() as p2:
+            assert p2 is p1
+            assert p1.status == pq.PipelineStatus.ON
+        assert p2 is p1
+        assert p2.status == pq.PipelineStatus.ON
     assert aconn._pipeline is None
-    assert p.status == pq.PipelineStatus.OFF
+    assert p1.status == pq.PipelineStatus.OFF
 
 
 async def test_cursor_stream(aconn):
@@ -55,11 +54,11 @@ async def test_cannot_insert_multiple_commands(aconn):
 
 async def test_pipeline_processed_at_exit(aconn):
     async with aconn.cursor() as cur:
-        async with aconn.pipeline():
+        async with aconn.pipeline() as p:
             await cur.execute("select 1")
 
             # PQsendQuery[BEGIN], PQsendQuery
-            assert len(aconn._pipeline.result_queue) == 2
+            assert len(p.result_queue) == 2
 
         assert await cur.fetchone() == (1,)
 
@@ -78,14 +77,14 @@ async def test_pipeline_errors_processed_at_exit(aconn):
 
 
 async def test_pipeline(aconn):
-    async with aconn.pipeline():
+    async with aconn.pipeline() as p:
         c1 = aconn.cursor()
         c2 = aconn.cursor()
         await c1.execute("select 1")
         await c2.execute("select 2")
 
         # PQsendQuery[BEGIN], PQsendQuery(2)
-        assert len(aconn._pipeline.result_queue) == 3
+        assert len(p.result_queue) == 3
 
         (r1,) = await c1.fetchone()
         assert r1 == 1
@@ -105,14 +104,14 @@ async def test_autocommit(aconn):
 
 async def test_pipeline_aborted(aconn):
     await aconn.set_autocommit(True)
-    async with aconn.pipeline():
+    async with aconn.pipeline() as p:
         c1 = await aconn.execute("select 1")
         with pytest.raises(e.UndefinedTable):
             await (await aconn.execute("select * from doesnotexist")).fetchone()
         with pytest.raises(e.OperationalError, match="pipeline aborted"):
             await (await aconn.execute("select 'aborted'")).fetchone()
         # Sync restore the connection in usable state.
-        aconn._pipeline.sync()
+        p.sync()
         c2 = await aconn.execute("select 2")
 
     (r,) = await c1.fetchone()