From bdd8d49971cc1a48d922ebca56416356248ee8b7 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Fri, 21 Jul 2023 02:10:01 +0100 Subject: [PATCH] fix: don't clobber a Python exception on COPY FROM with QueryCanceled We trigger the server to raise the QueryCanceled; however, the original exception has more information (the traceback). We can consider the server exception just a notification that cancellation worked as expected. This is a mild change in behaviour, as the fixed tests state. However, raising QueryCanceled is not explicitly documented and not part of a strict interface, so we can probably change the exception raised without needing to wait for psycopg 4. Close #593 --- docs/news.rst | 7 +++++-- psycopg/psycopg/copy.py | 24 ++++++++++++++++++++---- tests/test_copy.py | 12 +++++------- tests/test_copy_async.py | 12 +++++------- 4 files changed, 35 insertions(+), 20 deletions(-) diff --git a/docs/news.rst b/docs/news.rst index 32da4ff5d..bd8e0c4ff 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -10,8 +10,8 @@ Future releases --------------- -Psycopg 3.1.10 -^^^^^^^^^^^^^^ +Psycopg 3.1.10 (unreleased) +^^^^^^^^^^^^^^^^^^^^^^^^^^^ - Fix prepared statement cache validation when exiting pipeline mode (or `~Cursor.executemany()`) in case an error occurred within the pipeline @@ -20,6 +20,9 @@ Psycopg 3.1.10 `OperationalError` in case of connection failure. `Error.pgconn` is now a shallow copy of the real libpq connection, and the latter is closed before the exception propagates (:ticket:`#565`). +- Don't clobber a Python exception raised during COPY FROM with the resulting + `!QueryCanceled` raised as a consequence (:ticket:`#593`). + Current release --------------- diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 26a2d9e96..7bae6d22d 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -376,8 +376,16 @@ class LibpqWriter(Writer): else: bmsg = None - res = self.connection.wait(copy_end(self._pgconn, bmsg)) - self.cursor._results = [res] + try: + res = self.connection.wait(copy_end(self._pgconn, bmsg)) + # The QueryCanceled is expected if we sent an exception message to + # pgconn.put_copy_end(). The Python exception that generated that + # cancelling is more important, so don't clobber it. + except e.QueryCanceled: + if not bmsg: + raise + else: + self.cursor._results = [res] class QueuedLibpqDriver(LibpqWriter): @@ -583,8 +591,16 @@ class AsyncLibpqWriter(AsyncWriter): else: bmsg = None - res = await self.connection.wait(copy_end(self._pgconn, bmsg)) - self.cursor._results = [res] + try: + res = await self.connection.wait(copy_end(self._pgconn, bmsg)) + # The QueryCanceled is expected if we sent an exception message to + # pgconn.put_copy_end(). The Python exception that generated that + # cancelling is more important, so don't clobber it. + except e.QueryCanceled: + if not bmsg: + raise + else: + self.cursor._results = [res] class AsyncQueuedLibpqWriter(AsyncLibpqWriter): diff --git a/tests/test_copy.py b/tests/test_copy.py index 3a26f7863..29fad4581 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -271,7 +271,7 @@ def test_copy_in_str(conn): def test_copy_in_error(conn): cur = conn.cursor() ensure_table(cur, sample_tabledef) - with pytest.raises(e.QueryCanceled): + with pytest.raises(TypeError): with cur.copy("copy copy_in from stdin (format binary)") as copy: copy.write(sample_text.decode()) @@ -344,11 +344,10 @@ def test_subclass_adapter(conn, format): def test_copy_in_error_empty(conn, format): cur = conn.cursor() ensure_table(cur, sample_tabledef) - with pytest.raises(e.QueryCanceled) as exc: + with pytest.raises(ZeroDivisionError, match="mannaggiamiseria"): with cur.copy(f"copy copy_in from stdin (format {format.name})"): - raise Exception("mannaggiamiseria") + raise ZeroDivisionError("mannaggiamiseria") - assert "mannaggiamiseria" in str(exc.value) assert conn.info.transaction_status == conn.TransactionStatus.INERROR @@ -366,12 +365,11 @@ def test_copy_in_buffers_with_pg_error(conn): def test_copy_in_buffers_with_py_error(conn): cur = conn.cursor() ensure_table(cur, sample_tabledef) - with pytest.raises(e.QueryCanceled) as exc: + with pytest.raises(ZeroDivisionError, match="nuttengoggenio"): with cur.copy("copy copy_in from stdin (format text)") as copy: copy.write(sample_text) - raise Exception("nuttengoggenio") + raise ZeroDivisionError("nuttengoggenio") - assert "nuttengoggenio" in str(exc.value) assert conn.info.transaction_status == conn.TransactionStatus.INERROR diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index e49d8ff65..dd11d4bd2 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -263,7 +263,7 @@ async def test_copy_in_str(aconn): async def test_copy_in_error(aconn): cur = aconn.cursor() await ensure_table(cur, sample_tabledef) - with pytest.raises(e.QueryCanceled): + with pytest.raises(TypeError): async with cur.copy("copy copy_in from stdin (format binary)") as copy: await copy.write(sample_text.decode()) @@ -339,11 +339,10 @@ async def test_subclass_adapter(aconn, format): async def test_copy_in_error_empty(aconn, format): cur = aconn.cursor() await ensure_table(cur, sample_tabledef) - with pytest.raises(e.QueryCanceled) as exc: + with pytest.raises(ZeroDivisionError, match="mannaggiamiseria"): async with cur.copy(f"copy copy_in from stdin (format {format.name})"): - raise Exception("mannaggiamiseria") + raise ZeroDivisionError("mannaggiamiseria") - assert "mannaggiamiseria" in str(exc.value) assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR @@ -361,12 +360,11 @@ async def test_copy_in_buffers_with_pg_error(aconn): async def test_copy_in_buffers_with_py_error(aconn): cur = aconn.cursor() await ensure_table(cur, sample_tabledef) - with pytest.raises(e.QueryCanceled) as exc: + with pytest.raises(ZeroDivisionError, match="nuttengoggenio"): async with cur.copy("copy copy_in from stdin (format text)") as copy: await copy.write(sample_text) - raise Exception("nuttengoggenio") + raise ZeroDivisionError("nuttengoggenio") - assert "nuttengoggenio" in str(exc.value) assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR -- 2.47.2