From: Daniele Varrazzo Date: Thu, 24 Mar 2022 15:52:05 +0000 (+0100) Subject: fix(copy): propagate errors raised in the worker thread X-Git-Tag: 3.1~163^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=47b51f952369f572f98a2a5ca525f94ae83d85ca;p=thirdparty%2Fpsycopg.git fix(copy): propagate errors raised in the worker thread Previously, an error in the worker thread was printed to stderr, but processing continued, for no result but no exception. Problem found in #255, but unrelated to it. --- diff --git a/docs/news.rst b/docs/news.rst index 549733cb9..a72290adc 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -26,6 +26,7 @@ Psycopg 3.0.11 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - Fix `DataError` loading arrays with dimensions information (:ticket:`#253`). +- Fix error propagation from COPY worker thread (mentioned in :ticket:`#255`). Current release diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index dd4734607..6a3da8e0d 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -182,6 +182,7 @@ class Copy(BaseCopy["Connection[Any]"]): super().__init__(cursor) self._queue: queue.Queue[bytes] = queue.Queue(maxsize=self.QUEUE_SIZE) self._worker: Optional[threading.Thread] = None + self._worker_error: Optional[BaseException] = None def __enter__(self) -> "Copy": self._enter() @@ -270,15 +271,20 @@ class Copy(BaseCopy["Connection[Any]"]): def worker(self) -> None: """Push data to the server when available from the copy queue. - Terminate reading when the queue receives a None. + Terminate reading when the queue receives a false-y value, or in case + of error. The function is designed to be run in a separate thread. """ - while True: - data = self._queue.get(block=True, timeout=24 * 60 * 60) - if not data: - break - self.connection.wait(copy_to(self._pgconn, data)) + try: + while True: + data = self._queue.get(block=True, timeout=24 * 60 * 60) + if not data: + break + self.connection.wait(copy_to(self._pgconn, data)) + except BaseException as ex: + # Propagate the error to the main thread. + self._worker_error = ex def _write(self, data: bytes) -> None: if not data: @@ -290,6 +296,10 @@ class Copy(BaseCopy["Connection[Any]"]): self._worker.daemon = True self._worker.start() + # If the worker thread raies an exception, re-raise it to the caller. + if self._worker_error: + raise self._worker_error + self._queue.put(data) def _write_end(self) -> None: @@ -301,6 +311,10 @@ class Copy(BaseCopy["Connection[Any]"]): self._worker.join() self._worker = None # break the loop + # Check if the worker thread raised any exception before terminating. + if self._worker_error: + raise self._worker_error + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" @@ -364,7 +378,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): async def worker(self) -> None: """Push data to the server when available from the copy queue. - Terminate reading when the queue receives a None. + Terminate reading when the queue receives a false-y value. The function is designed to be run in a separate thread. """ diff --git a/tests/test_copy.py b/tests/test_copy.py index 4cea6646a..64037fd62 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -563,6 +563,19 @@ def test_worker_life(conn, format, buffer): assert data == sample_records +def test_worker_error_propagated(conn, monkeypatch): + def copy_to_broken(pgconn, buffer): + raise ZeroDivisionError + yield + + monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken) + cur = conn.cursor() + cur.execute("create temp table wat (a text, b text)") + with pytest.raises(ZeroDivisionError): + with cur.copy("copy wat from stdin") as copy: + copy.write("a,b") + + @pytest.mark.slow @pytest.mark.parametrize( "fmt, set_types", diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index ba025e1a0..ad7ab7f44 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -563,6 +563,19 @@ async def test_worker_life(aconn, format, buffer): assert data == sample_records +async def test_worker_error_propagated(aconn, monkeypatch): + def copy_to_broken(pgconn, buffer): + raise ZeroDivisionError + yield + + monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken) + cur = aconn.cursor() + await cur.execute("create temp table wat (a text, b text)") + with pytest.raises(ZeroDivisionError): + async with cur.copy("copy wat from stdin") as copy: + await copy.write("a,b") + + @pytest.mark.slow @pytest.mark.parametrize( "fmt, set_types",