From a22db4a45f908d2d9ffda6ae7112894bf8999f59 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 14 Oct 2025 03:41:34 +0200 Subject: [PATCH] test: add test to reveal the buffer corruption on dump error --- tests/test_copy.py | 15 +++++++++++++++ tests/test_copy_async.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/tests/test_copy.py b/tests/test_copy.py index a0cad5aae..e78636819 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -686,6 +686,21 @@ def test_binary_partial_row(conn): copy.write_row([16, [[None], None]]) +@pytest.mark.parametrize("format", pq.Format) +def test_clean_buffer_on_error(conn, format): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, num int4, obj jsonb") + with cur.copy(f"copy copy_in (num, obj) from stdin (format {format.name})") as copy: + copy.set_types(["int4", "jsonb"]) + copy.write_row([15, {}]) + with pytest.raises(TypeError): + copy.write_row([16, 1j]) + copy.write_row([17, []]) + + cur.execute("select num, obj from copy_in order by id") + assert cur.fetchall() == [(15, {}), (17, [])] + + @pytest.mark.parametrize( "format, buffer", [(pq.Format.TEXT, "sample_text"), (pq.Format.BINARY, "sample_binary")], diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 70d30dbda..1c9998b31 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -702,6 +702,23 @@ async def test_binary_partial_row(aconn): await copy.write_row([16, [[None], None]]) +@pytest.mark.parametrize("format", pq.Format) +async def test_clean_buffer_on_error(aconn, format): + cur = aconn.cursor() + await ensure_table_async(cur, "id serial primary key, num int4, obj jsonb") + async with cur.copy( + f"copy copy_in (num, obj) from stdin (format {format.name})" + ) as copy: + copy.set_types(["int4", "jsonb"]) + await copy.write_row([15, {}]) + with pytest.raises(TypeError): + await copy.write_row([16, 1j]) + await copy.write_row([17, []]) + + await cur.execute("select num, obj from copy_in order by id") + assert (await cur.fetchall()) == [(15, {}), (17, [])] + + @pytest.mark.parametrize( "format, buffer", [(pq.Format.TEXT, "sample_text"), (pq.Format.BINARY, "sample_binary")], -- 2.47.3