]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: handle exit exception in BasePipeline._exit()
authorDenis Laxalde <denis@laxalde.org>
Tue, 4 Oct 2022 19:41:22 +0000 (21:41 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 5 Oct 2022 11:24:48 +0000 (12:24 +0100)
We remove repeated code in __(a)exit__() method of (Async)Pipeline by
passing the exception value to _exit() method of BasePipeline.

psycopg/psycopg/_pipeline.py

index 75bb6b2eb31aeabe9c1bb99e75ab27bd3cfcae2f..fe4bf54fcf1a4099df45a230d7b844f56d21292c 100644 (file)
@@ -104,10 +104,19 @@ class BasePipeline:
             yield from self._sync_gen()
         self.level += 1
 
-    def _exit(self) -> None:
+    def _exit(self, exc: Optional[BaseException]) -> None:
         self.level -= 1
         if self.level == 0 and self.pgconn.status != BAD:
-            self.pgconn.exit_pipeline_mode()
+            try:
+                self.pgconn.exit_pipeline_mode()
+            except e.OperationalError as exc2:
+                # Notice that this error might be pretty irrecoverable. It
+                # happens on COPY, for instance: even if sync succeeds, exiting
+                # fails with "cannot exit pipeline mode with uncollected results"
+                if exc:
+                    logger.warning("error ignored exiting %r: %s", self, exc2)
+                else:
+                    raise exc2.with_traceback(None)
 
     def _sync_gen(self) -> PQGen[None]:
         self._enqueue_sync()
@@ -234,16 +243,7 @@ class Pipeline(BasePipeline):
             else:
                 raise exc2.with_traceback(None)
         finally:
-            try:
-                self._exit()
-            except Exception as exc2:
-                # Notice that this error might be pretty irrecoverable. It
-                # happens on COPY, for instance: even if sync succeeds, exiting
-                # fails with "cannot exit pipeline mode with uncollected results"
-                if exc_val:
-                    logger.warning("error ignored exiting %r: %s", self, exc2)
-                else:
-                    raise exc2.with_traceback(None)
+            self._exit(exc_val)
 
 
 class AsyncPipeline(BasePipeline):
@@ -284,10 +284,4 @@ class AsyncPipeline(BasePipeline):
             else:
                 raise exc2.with_traceback(None)
         finally:
-            try:
-                self._exit()
-            except Exception as exc2:
-                if exc_val:
-                    logger.warning("error ignored exiting %r: %s", self, exc2)
-                else:
-                    raise exc2.with_traceback(None)
+            self._exit(exc_val)