refactor: rename pipeline.communicate() -> sync()
author     Daniele Varrazzo <daniele.varrazzo@gmail.com>  Tue, 29 Mar 2022 12:21:03 +0000 (14:21 +0200)
committer  Daniele Varrazzo <daniele.varrazzo@gmail.com>  Sat, 2 Apr 2022 23:23:22 +0000 (01:23 +0200)
sync() has the async/sync interface, so it is more apt to expose as the
public interface meaning "call this to restore the state" (as the changed
test does).
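
For illustration, a minimal migration sketch for user code, assuming the old
method was called directly (the connection arguments and variable names are
placeholders, not part of this commit):

    import psycopg

    with psycopg.connect() as conn:
        with conn.pipeline() as p:
            conn.execute("select 1")
            # before this commit:
            #   p.communicate()
            # after this commit:
            p.sync()  # flush pending commands and process all results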

Expose the try/finally logic behind sync() as the _sync_gen() method, which
can be useful when the sync has to be performed while the connection lock is
already held.
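
A hedged sketch of that pattern, reconstructed from the diff below: sync()
acquires the connection lock itself, while code that already holds the lock
can wait on the generator directly instead of calling sync():

    # Inside Pipeline.sync(): take the lock, then drive the generator.
    with self._conn.lock:
        self._conn.wait(self._sync_gen())

    # From code that already holds self._conn.lock, skip re-acquiring it
    # and wait on the generator directly:
    self._conn.wait(self._sync_gen())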

psycopg/psycopg/_pipeline.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_pipeline_async.py

diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py
index 4a73286e66d786e296ad0c38eef6124eb2731238..26d4e1264246a9a30d414c0e77c858dc56fa551c 100644
@@ -63,11 +63,6 @@ class BasePipeline:
     def status(self) -> pq.PipelineStatus:
         return pq.PipelineStatus(self.pgconn.pipeline_status)
 
-    def sync(self) -> None:
-        """Enqueue a PQpipelineSync() command."""
-        self.command_queue.append(self.pgconn.pipeline_sync)
-        self.result_queue.append(None)
-
     @staticmethod
     def is_supported() -> bool:
         """Return `True` if the psycopg libpq wrapper suports pipeline mode."""
@@ -85,6 +80,17 @@ class BasePipeline:
         if self.pgconn.status != ConnStatus.BAD:
             self.pgconn.exit_pipeline_mode()
 
+    def _sync_gen(self) -> PQGen[None]:
+        self._enqueue_sync()
+        try:
+            # Send any pending commands (e.g. COMMIT or Sync);
+            # while processing results, we might get errors...
+            yield from self._communicate_gen()
+        finally:
+            # then fetch all remaining results but without forcing
+            # flush since we emitted a sync just before.
+            yield from self._fetch_gen(flush=False)
+
     def _communicate_gen(self) -> PQGen[None]:
         """Communicate with pipeline to send commands and possibly fetch
         results, which are then processed.
@@ -107,7 +113,8 @@ class BasePipeline:
             return
 
         if flush:
-            yield from self._flush_gen()
+            self.pgconn.send_flush_request()
+            yield from send(self.pgconn)
 
         to_process = []
         while self.result_queue:
@@ -122,10 +129,6 @@ class BasePipeline:
         for queued, results in to_process:
             self._process_results(queued, results)
 
-    def _flush_gen(self) -> PQGen[None]:
-        self.pgconn.send_flush_request()
-        yield from send(self.pgconn)
-
     def _process_results(
         self, queued: PendingResult, results: List["PGresult"]
     ) -> None:
@@ -150,6 +153,11 @@ class BasePipeline:
                 # Update the prepare state of the query.
                 cursor._conn._prepared.validate(key, prep, name, results)
 
+    def _enqueue_sync(self) -> None:
+        """Enqueue a PQpipelineSync() command."""
+        self.command_queue.append(self.pgconn.pipeline_sync)
+        self.result_queue.append(None)
+
 
 class Pipeline(BasePipeline):
     """Handler for connection in pipeline mode."""
@@ -160,7 +168,7 @@ class Pipeline(BasePipeline):
     def __init__(self, conn: "Connection[Any]") -> None:
         super().__init__(conn)
 
-    def communicate(self) -> None:
+    def sync(self) -> None:
         """Sync the pipeline, send any pending command and fetch and process
         all available results.
 
@@ -168,15 +176,7 @@ class Pipeline(BasePipeline):
         purposes (e.g. in nested pipelines).
         """
         with self._conn.lock:
-            self.sync()
-            try:
-                # Send any pending commands (e.g. COMMIT or Sync);
-                # while processing results, we might get errors...
-                self._conn.wait(self._communicate_gen())
-            finally:
-                # then fetch all remaining results but without forcing
-                # flush since we emitted a sync just before.
-                self._conn.wait(self._fetch_gen(flush=False))
+            self._conn.wait(self._sync_gen())
 
     def __enter__(self) -> "Pipeline":
         self._enter()
@@ -189,7 +189,7 @@ class Pipeline(BasePipeline):
         exc_tb: Optional[TracebackType],
     ) -> None:
         try:
-            self.communicate()
+            self.sync()
         except Exception as exc2:
             # Don't clobber an exception raised in the block with this one
             if exc_val:
@@ -209,7 +209,7 @@ class AsyncPipeline(BasePipeline):
     def __init__(self, conn: "AsyncConnection[Any]") -> None:
         super().__init__(conn)
 
-    async def communicate(self) -> None:
+    async def sync(self) -> None:
         """Sync the pipeline, send any pending command and fetch and process
         all available results.
 
@@ -217,15 +217,7 @@ class AsyncPipeline(BasePipeline):
         purposes (e.g. in nested pipelines).
         """
         async with self._conn.lock:
-            self.sync()
-            try:
-                # Send any pending commands (e.g. COMMIT or Sync);
-                # while processing results, we might get errors...
-                await self._conn.wait(self._communicate_gen())
-            finally:
-                # then fetch all remaining results but without forcing
-                # flush since we emitted a sync just before.
-                await self._conn.wait(self._fetch_gen(flush=False))
+            await self._conn.wait(self._sync_gen())
 
     async def __aenter__(self) -> "AsyncPipeline":
         self._enter()
@@ -238,7 +230,7 @@ class AsyncPipeline(BasePipeline):
         exc_tb: Optional[TracebackType],
     ) -> None:
         try:
-            await self.communicate()
+            await self.sync()
         except Exception as exc2:
             # Don't clobber an exception raised in the block with this one
             if exc_val:
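
Putting the hunks above together, the resulting shape of _pipeline.py is
roughly the following (reconstructed from the hunks only, surrounding
context omitted):

    class BasePipeline:
        def _sync_gen(self) -> PQGen[None]:
            self._enqueue_sync()
            try:
                # Send any pending commands (e.g. COMMIT or Sync);
                # while processing results, we might get errors...
                yield from self._communicate_gen()
            finally:
                # then fetch all remaining results but without forcing
                # flush since we emitted a sync just before.
                yield from self._fetch_gen(flush=False)

        def _enqueue_sync(self) -> None:
            """Enqueue a PQpipelineSync() command."""
            self.command_queue.append(self.pgconn.pipeline_sync)
            self.result_queue.append(None)


    class Pipeline(BasePipeline):
        def sync(self) -> None:
            with self._conn.lock:
                self._conn.wait(self._sync_gen())


    class AsyncPipeline(BasePipeline):
        async def sync(self) -> None:
            async with self._conn.lock:
                await self._conn.wait(self._sync_gen())
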
diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
index 68e8306a6a6b35d0ed579d8e6b288737d12059e5..bf2c584a47d7e8d0d4e114db2db7d21108debde4 100644
@@ -885,7 +885,7 @@ class Connection(BaseConnection[Row]):
             try:
                 yield self._pipeline
             finally:
-                self._pipeline.communicate()
+                self._pipeline.sync()
             return
 
         try:
diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py
index d0cbfea1051f6080ac7bf7b649f8a34e5ee8c45d..7d4bebfb1325641ca3b240e5412caba26f605eb4 100644
@@ -313,7 +313,7 @@ class AsyncConnection(BaseConnection[Row]):
             try:
                 yield self._pipeline
             finally:
-                await self._pipeline.communicate()
+                await self._pipeline.sync()
             return
 
         try:
diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py
index 783814f3b0ff89f9bb0ca927db6be4dfd5c88f41..fbb2c9f980da1cb9796f2d877b84b61c93fb4e3a 100644
@@ -141,7 +141,7 @@ async def test_pipeline_aborted(aconn):
         with pytest.raises(e.OperationalError, match="pipeline aborted"):
             await (await aconn.execute("select 'aborted'")).fetchone()
         # Sync restores the connection to a usable state.
-        p.sync()
+        await p.sync()
         c2 = await aconn.execute("select 2")
 
     (r,) = await c1.fetchone()