]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(copy): propagate errors raised in the worker thread
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 24 Mar 2022 15:52:05 +0000 (16:52 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Mar 2022 17:40:33 +0000 (18:40 +0100)
Previously, an error in the worker thread was printed to stderr, but
processing continued, for no result but no exception.

Problem found in #255, but unrelated to it.

docs/news.rst
psycopg/psycopg/copy.py
tests/test_copy.py
tests/test_copy_async.py

index 549733cb9dc73c84b71359537c1b378542640870..a72290adc9fd407ab7ca14edb75f192ff85f1ffa 100644 (file)
@@ -26,6 +26,7 @@ Psycopg 3.0.11 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 - Fix `DataError` loading arrays with dimensions information (:ticket:`#253`).
+- Fix error propagation from COPY worker thread (mentioned in :ticket:`#255`).
 
 
 Current release
index dd47346079be8bcddf8242ccbef3ac54785f50f6..6a3da8e0d81ffb74b49c5f7c03d4e65710e59699 100644 (file)
@@ -182,6 +182,7 @@ class Copy(BaseCopy["Connection[Any]"]):
         super().__init__(cursor)
         self._queue: queue.Queue[bytes] = queue.Queue(maxsize=self.QUEUE_SIZE)
         self._worker: Optional[threading.Thread] = None
+        self._worker_error: Optional[BaseException] = None
 
     def __enter__(self) -> "Copy":
         self._enter()
@@ -270,15 +271,20 @@ class Copy(BaseCopy["Connection[Any]"]):
     def worker(self) -> None:
         """Push data to the server when available from the copy queue.
 
-        Terminate reading when the queue receives a None.
+        Terminate reading when the queue receives a false-y value, or in case
+        of error.
 
         The function is designed to be run in a separate thread.
         """
-        while True:
-            data = self._queue.get(block=True, timeout=24 * 60 * 60)
-            if not data:
-                break
-            self.connection.wait(copy_to(self._pgconn, data))
+        try:
+            while True:
+                data = self._queue.get(block=True, timeout=24 * 60 * 60)
+                if not data:
+                    break
+                self.connection.wait(copy_to(self._pgconn, data))
+        except BaseException as ex:
+            # Propagate the error to the main thread.
+            self._worker_error = ex
 
     def _write(self, data: bytes) -> None:
         if not data:
@@ -290,6 +296,10 @@ class Copy(BaseCopy["Connection[Any]"]):
             self._worker.daemon = True
             self._worker.start()
 
+        # If the worker thread raies an exception, re-raise it to the caller.
+        if self._worker_error:
+            raise self._worker_error
+
         self._queue.put(data)
 
     def _write_end(self) -> None:
@@ -301,6 +311,10 @@ class Copy(BaseCopy["Connection[Any]"]):
             self._worker.join()
             self._worker = None  # break the loop
 
+        # Check if the worker thread raised any exception before terminating.
+        if self._worker_error:
+            raise self._worker_error
+
 
 class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     """Manage an asynchronous :sql:`COPY` operation."""
@@ -364,7 +378,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     async def worker(self) -> None:
         """Push data to the server when available from the copy queue.
 
-        Terminate reading when the queue receives a None.
+        Terminate reading when the queue receives a false-y value.
 
         The function is designed to be run in a separate thread.
         """
index 4cea6646a890da8e965e2ac091354530030eee46..64037fd628d30772ac1d747e9db3a3f6298d08b7 100644 (file)
@@ -563,6 +563,19 @@ def test_worker_life(conn, format, buffer):
     assert data == sample_records
 
 
+def test_worker_error_propagated(conn, monkeypatch):
+    def copy_to_broken(pgconn, buffer):
+        raise ZeroDivisionError
+        yield
+
+    monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+    cur = conn.cursor()
+    cur.execute("create temp table wat (a text, b text)")
+    with pytest.raises(ZeroDivisionError):
+        with cur.copy("copy wat from stdin") as copy:
+            copy.write("a,b")
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize(
     "fmt, set_types",
index ba025e1a05ba56c4b72829a160e55d02724c719c..ad7ab7f44a7f43b910b40c4b0fd497459804892d 100644 (file)
@@ -563,6 +563,19 @@ async def test_worker_life(aconn, format, buffer):
     assert data == sample_records
 
 
+async def test_worker_error_propagated(aconn, monkeypatch):
+    def copy_to_broken(pgconn, buffer):
+        raise ZeroDivisionError
+        yield
+
+    monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+    cur = aconn.cursor()
+    await cur.execute("create temp table wat (a text, b text)")
+    with pytest.raises(ZeroDivisionError):
+        async with cur.copy("copy wat from stdin") as copy:
+            await copy.write("a,b")
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize(
     "fmt, set_types",