super().__init__(cursor)
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=self.QUEUE_SIZE)
self._worker: Optional[threading.Thread] = None
+ self._worker_error: Optional[BaseException] = None
def __enter__(self) -> "Copy":
self._enter()
def worker(self) -> None:
"""Push data to the server when available from the copy queue.
- Terminate reading when the queue receives a None.
+ Terminate reading when the queue receives a false-y value, or in case
+ of error.
The function is designed to be run in a separate thread.
"""
- while True:
- data = self._queue.get(block=True, timeout=24 * 60 * 60)
- if not data:
- break
- self.connection.wait(copy_to(self._pgconn, data))
+ try:
+ while True:
+ data = self._queue.get(block=True, timeout=24 * 60 * 60)
+ if not data:
+ break
+ self.connection.wait(copy_to(self._pgconn, data))
+ except BaseException as ex:
+ # Propagate the error to the main thread.
+ self._worker_error = ex
def _write(self, data: bytes) -> None:
if not data:
self._worker.daemon = True
self._worker.start()
+ # If the worker thread raies an exception, re-raise it to the caller.
+ if self._worker_error:
+ raise self._worker_error
+
self._queue.put(data)
def _write_end(self) -> None:
self._worker.join()
self._worker = None # break the loop
+ # Check if the worker thread raised any exception before terminating.
+ if self._worker_error:
+ raise self._worker_error
+
class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
"""Manage an asynchronous :sql:`COPY` operation."""
async def worker(self) -> None:
"""Push data to the server when available from the copy queue.
- Terminate reading when the queue receives a None.
+ Terminate reading when the queue receives a false-y value.
The function is designed to be run in a separate thread.
"""
assert data == sample_records
+def test_worker_error_propagated(conn, monkeypatch):
+ def copy_to_broken(pgconn, buffer):
+ raise ZeroDivisionError
+ yield
+
+ monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+ cur = conn.cursor()
+ cur.execute("create temp table wat (a text, b text)")
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy wat from stdin") as copy:
+ copy.write("a,b")
+
+
@pytest.mark.slow
@pytest.mark.parametrize(
"fmt, set_types",
assert data == sample_records
+async def test_worker_error_propagated(aconn, monkeypatch):
+ def copy_to_broken(pgconn, buffer):
+ raise ZeroDivisionError
+ yield
+
+ monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+ cur = aconn.cursor()
+ await cur.execute("create temp table wat (a text, b text)")
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy("copy wat from stdin") as copy:
+ await copy.write("a,b")
+
+
@pytest.mark.slow
@pytest.mark.parametrize(
"fmt, set_types",