]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: consistent sync/exit and error management in pipeline contexts
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 29 Mar 2022 23:57:06 +0000 (01:57 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:23:22 +0000 (01:23 +0200)
Don't clobber an exception on exit of the nested block too. In order to
simplify the code, make the pipeline count the number of time it is
entered, and call _exit() only the last time it exits.

Drop assert that we have left pipeline mode leaving the block. If we get
in unrecoverable state, we will have not. By now we should probably just
close the connection; however, leaving it this way is a better
indication that the connection is broken because of something about
pipeline mode; closing it would hide it, and even if we raised a
warning, it would be much easier to miss it than to miss the exceptions
raised in broken state.

psycopg/psycopg/_pipeline.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 26d4e1264246a9a30d414c0e77c858dc56fa551c..2523aaccdeaef02210c27a4f3dac447aa9c253f3 100644 (file)
@@ -53,6 +53,7 @@ class BasePipeline:
         self.pgconn = conn.pgconn
         self.command_queue = Deque[PipelineCommand]()
         self.result_queue = Deque[PendingResult]()
+        self.level = 0
 
     def __repr__(self) -> str:
         cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
@@ -74,10 +75,13 @@ class BasePipeline:
         return BasePipeline._is_supported
 
     def _enter(self) -> None:
-        self.pgconn.enter_pipeline_mode()
+        if self.level == 0:
+            self.pgconn.enter_pipeline_mode()
+        self.level += 1
 
     def _exit(self) -> None:
-        if self.pgconn.status != ConnStatus.BAD:
+        self.level -= 1
+        if self.level == 0 and self.pgconn.status != ConnStatus.BAD:
             self.pgconn.exit_pipeline_mode()
 
     def _sync_gen(self) -> PQGen[None]:
@@ -193,11 +197,20 @@ class Pipeline(BasePipeline):
         except Exception as exc2:
             # Don't clobber an exception raised in the block with this one
             if exc_val:
-                logger.warning("error ignored exiting %r: %s", self, exc2)
+                logger.warning("error ignored syncing %r: %s", self, exc2)
             else:
                 raise
         finally:
-            self._exit()
+            try:
+                self._exit()
+            except Exception as exc2:
+                # Notice that this error might be pretty irrecoverable. It
+                # happens on COPY, for insance: even if sync succeeds, exiting
+                # fails with "cannot exit pipeline mode with uncollected results"
+                if exc_val:
+                    logger.warning("error ignored exiting %r: %s", self, exc2)
+                else:
+                    raise
 
 
 class AsyncPipeline(BasePipeline):
@@ -234,8 +247,14 @@ class AsyncPipeline(BasePipeline):
         except Exception as exc2:
             # Don't clobber an exception raised in the block with this one
             if exc_val:
-                logger.warning("error ignored exiting %r: %s", self, exc2)
+                logger.warning("error ignored syncing %r: %s", self, exc2)
             else:
                 raise
         finally:
-            self._exit()
+            try:
+                self._exit()
+            except Exception as exc2:
+                if exc_val:
+                    logger.warning("error ignored exiting %r: %s", self, exc2)
+                else:
+                    raise
index bf2c584a47d7e8d0d4e114db2db7d21108debde4..2efd7c4b694fd63296d8a8dd2369cb32fafe037d 100644 (file)
@@ -871,29 +871,19 @@ class Connection(BaseConnection[Row]):
     def pipeline(self) -> Iterator[Pipeline]:
         """Context manager to switch the connection into pipeline mode."""
         with self.lock:
-            if self._pipeline is None:
-                # We must enter pipeline mode: create a new one
+            pipeline = self._pipeline
+            if pipeline is None:
                 # WARNING: reference loop, broken ahead.
                 pipeline = self._pipeline = Pipeline(self)
-            else:
-                # we are already in pipeline mode: bail out as soon as we
-                # leave the lock block.
-                pipeline = None
-
-        if not pipeline:
-            # No-op re-entered inner pipeline block.
-            try:
-                yield self._pipeline
-            finally:
-                self._pipeline.sync()
-            return
 
         try:
             with pipeline:
                 yield pipeline
         finally:
-            assert pipeline.status == pq.PipelineStatus.OFF, pipeline.status
-            self._pipeline = None
+            if pipeline.level == 0:
+                with self.lock:
+                    assert pipeline is self._pipeline
+                    self._pipeline = None
 
     def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
         """
index 7d4bebfb1325641ca3b240e5412caba26f605eb4..d0f0b181b6f6bf98741f5c236b3bc4704cbfe22f 100644 (file)
@@ -14,7 +14,7 @@ from contextlib import asynccontextmanager
 
 from . import errors as e
 from . import waiting
-from .pq import Format, PipelineStatus, TransactionStatus
+from .pq import Format, TransactionStatus
 from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
@@ -46,8 +46,7 @@ class AsyncConnection(BaseConnection[Row]):
     cursor_factory: Type[AsyncCursor[Row]]
     server_cursor_factory: Type[AsyncServerCursor[Row]]
     row_factory: AsyncRowFactory[Row]
-
-    _pipeline: "Optional[AsyncPipeline]"
+    _pipeline: Optional[AsyncPipeline]
 
     def __init__(
         self,
@@ -299,29 +298,19 @@ class AsyncConnection(BaseConnection[Row]):
     async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
         """Context manager to switch the connection into pipeline mode."""
         async with self.lock:
-            if self._pipeline is None:
-                # We must enter pipeline mode: create a new one
+            pipeline = self._pipeline
+            if pipeline is None:
                 # WARNING: reference loop, broken ahead.
                 pipeline = self._pipeline = AsyncPipeline(self)
-            else:
-                # we are already in pipeline mode: bail out as soon as we
-                # leave the lock block.
-                pipeline = None
-
-        if not pipeline:
-            # No-op re-entered inner pipeline block.
-            try:
-                yield self._pipeline
-            finally:
-                await self._pipeline.sync()
-            return
 
         try:
             async with pipeline:
                 yield pipeline
         finally:
-            assert pipeline.status == PipelineStatus.OFF, pipeline.status
-            self._pipeline = None
+            if pipeline.level == 0:
+                async with self.lock:
+                    assert pipeline is self._pipeline
+                    self._pipeline = None
 
     async def wait(self, gen: PQGen[RV]) -> RV:
         try:
index 43c862d389a37004c974cab0af24d3e710faa52b..2b4abe53f462d0288efcc3bb40c9a6f7540e97c3 100644 (file)
@@ -62,6 +62,17 @@ def test_pipeline_exit_error_noclobber(conn, caplog):
     assert len(caplog.records) == 1
 
 
+def test_pipeline_exit_error_noclobber_nested(conn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg")
+    with pytest.raises(ZeroDivisionError):
+        with conn.pipeline():
+            with conn.pipeline():
+                conn.close()
+                1 / 0
+
+    assert len(caplog.records) == 2
+
+
 def test_pipeline_exit_sync_trace(conn, trace):
     t = trace.trace(conn)
     with conn.pipeline():
index 668a8b3e0126e86024f609148451bfef0e62f048..f24cc144f4f2cbae4bf08000ac3b462d30af5882 100644 (file)
@@ -65,6 +65,17 @@ async def test_pipeline_exit_error_noclobber(aconn, caplog):
     assert len(caplog.records) == 1
 
 
+async def test_pipeline_exit_error_noclobber_nested(aconn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg")
+    with pytest.raises(ZeroDivisionError):
+        async with aconn.pipeline():
+            async with aconn.pipeline():
+                await aconn.close()
+                1 / 0
+
+    assert len(caplog.records) == 2
+
+
 async def test_pipeline_exit_sync_trace(aconn, trace):
     t = trace.trace(aconn)
     async with aconn.pipeline():