]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: move pipeline finalisation code to Pipeline.__exit__
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 27 Mar 2022 15:01:11 +0000 (17:01 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:17:57 +0000 (01:17 +0200)
This code has more internal knowledge of the Pipeline object than the
Connection object.

For some reason I don't understand, had to declare 'command_queue' and
'result_queue' types explicitly to the class definition. Leaving just the
definitions in '__init__()' causes mypy (0.940) to complain in 'cursor.py'.

psycopg/psycopg/_pipeline.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py

index b01f269bab97e271b192ef2b9444602317c8b32e..76b6d7d20d24b4117de14b7cc6953db7cb8a3406 100644 (file)
@@ -17,8 +17,10 @@ from ._encodings import pgconn_encoding
 from ._preparing import Key, Prepare
 
 if TYPE_CHECKING:
-    from .pq.abc import PGconn, PGresult
+    from .pq.abc import PGresult
     from .cursor import BaseCursor
+    from .connection import BaseConnection, Connection
+    from .connection_async import AsyncConnection
 
 if _psycopg:
     pipeline_communicate = _psycopg.pipeline_communicate
@@ -38,8 +40,13 @@ PendingResult: TypeAlias = Union[
 
 
 class BasePipeline:
-    def __init__(self, pgconn: "PGconn") -> None:
-        self.pgconn = pgconn
+
+    command_queue: Deque[PipelineCommand]
+    result_queue: Deque[PendingResult]
+
+    def __init__(self, conn: "BaseConnection[Any]") -> None:
+        self._conn = conn
+        self.pgconn = conn.pgconn
         self.command_queue = Deque[PipelineCommand]()
         self.result_queue = Deque[PendingResult]()
 
@@ -124,6 +131,11 @@ class BasePipeline:
 class Pipeline(BasePipeline):
     """Handler for connection in pipeline mode."""
 
+    _conn: "Connection[Any]"
+
+    def __init__(self, conn: "Connection[Any]") -> None:
+        super().__init__(conn)
+
     def __enter__(self) -> "Pipeline":
         self._enter()
         return self
@@ -134,12 +146,29 @@ class Pipeline(BasePipeline):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
-        self._exit()
+        try:
+            with self._conn.lock:
+                self.sync()
+                try:
+                    # Send an pending commands (e.g. COMMIT or Sync);
+                    # while processing results, we might get errors...
+                    self._conn.wait(self._communicate_gen())
+                finally:
+                    # then fetch all remaining results but without forcing
+                    # flush since we emitted a sync just before.
+                    self._conn.wait(self._fetch_gen(flush=False))
+        finally:
+            self._exit()
 
 
 class AsyncPipeline(BasePipeline):
     """Handler for async connection in pipeline mode."""
 
+    _conn: "AsyncConnection[Any]"
+
+    def __init__(self, conn: "AsyncConnection[Any]") -> None:
+        super().__init__(conn)
+
     async def __aenter__(self) -> "AsyncPipeline":
         self._enter()
         return self
@@ -150,4 +179,16 @@ class AsyncPipeline(BasePipeline):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> None:
-        self._exit()
+        try:
+            async with self._conn.lock:
+                self.sync()
+                try:
+                    # Send an pending commands (e.g. COMMIT or Sync);
+                    # while processing results, we might get errors...
+                    await self._conn.wait(self._communicate_gen())
+                finally:
+                    # then fetch all remaining results but without forcing
+                    # flush since we emitted a sync just before.
+                    await self._conn.wait(self._fetch_gen(flush=False))
+        finally:
+            self._exit()
index 0bdfc6b026c1ccd841a9db7f44240ee59266a735..5826b4beb82f5969a43ff7f79ad0d7478fe6ba0b 100644 (file)
@@ -873,7 +873,8 @@ class Connection(BaseConnection[Row]):
         with self.lock:
             if self._pipeline is None:
                 # We must enter pipeline mode: create a new one
-                pipeline = self._pipeline = Pipeline(self.pgconn)
+                # WARNING: reference loop, broken ahead.
+                pipeline = self._pipeline = Pipeline(self)
             else:
                 # we are already in pipeline mode: bail out as soon as we
                 # leave the lock block.
@@ -886,19 +887,7 @@ class Connection(BaseConnection[Row]):
 
         try:
             with pipeline:
-                try:
-                    yield pipeline
-                finally:
-                    with self.lock:
-                        pipeline.sync()
-                        try:
-                            # Send an pending commands (e.g. COMMIT or Sync);
-                            # while processing results, we might get errors...
-                            self.wait(pipeline._communicate_gen())
-                        finally:
-                            # then fetch all remaining results but without forcing
-                            # flush since we emitted a sync just before.
-                            self.wait(pipeline._fetch_gen(flush=False))
+                yield pipeline
         finally:
             assert pipeline.status == pq.PipelineStatus.OFF, pipeline.status
             self._pipeline = None
index 592aa474fe5345e7a4e60a0b4b7f3fd3fa0e2f20..246dd3fc837558f3a87ef445802733a9ebd1b57d 100644 (file)
@@ -301,7 +301,8 @@ class AsyncConnection(BaseConnection[Row]):
         async with self.lock:
             if self._pipeline is None:
                 # We must enter pipeline mode: create a new one
-                pipeline = self._pipeline = AsyncPipeline(self.pgconn)
+                # WARNING: reference loop, broken ahead.
+                pipeline = self._pipeline = AsyncPipeline(self)
             else:
                 # we are already in pipeline mode: bail out as soon as we
                 # leave the lock block.
@@ -314,19 +315,7 @@ class AsyncConnection(BaseConnection[Row]):
 
         try:
             async with pipeline:
-                try:
-                    yield pipeline
-                finally:
-                    async with self.lock:
-                        pipeline.sync()
-                        try:
-                            # Send an pending commands (e.g. COMMIT or Sync);
-                            # while processing results, we might get errors...
-                            await self.wait(pipeline._communicate_gen())
-                        finally:
-                            # then fetch all remaining results but without forcing
-                            # flush since we emitted a sync just before.
-                            await self.wait(pipeline._fetch_gen(flush=False))
+                yield pipeline
         finally:
             assert pipeline.status == PipelineStatus.OFF, pipeline.status
             self._pipeline = None