]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test: add test to reveal the buffer corruption on dump error
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 14 Oct 2025 01:41:34 +0000 (03:41 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 14 Oct 2025 01:42:29 +0000 (03:42 +0200)
tests/test_copy.py
tests/test_copy_async.py

index a0cad5aaef4ca54fe03f89681d37058b61740a0a..e786368199ecf017517f95311ae0ac3c92461328 100644 (file)
@@ -686,6 +686,21 @@ def test_binary_partial_row(conn):
             copy.write_row([16, [[None], None]])
 
 
+@pytest.mark.parametrize("format", pq.Format)
+def test_clean_buffer_on_error(conn, format):
+    cur = conn.cursor()
+    ensure_table(cur, "id serial primary key, num int4, obj jsonb")
+    with cur.copy(f"copy copy_in (num, obj) from stdin (format {format.name})") as copy:
+        copy.set_types(["int4", "jsonb"])
+        copy.write_row([15, {}])
+        with pytest.raises(TypeError):
+            copy.write_row([16, 1j])
+        copy.write_row([17, []])
+
+    cur.execute("select num, obj from copy_in order by id")
+    assert cur.fetchall() == [(15, {}), (17, [])]
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(pq.Format.TEXT, "sample_text"), (pq.Format.BINARY, "sample_binary")],
index 70d30dbda9d4951962b9a7a7ba85a0b589835d4e..1c9998b312f851c7a6d115429e2896a1c0bdd45f 100644 (file)
@@ -702,6 +702,23 @@ async def test_binary_partial_row(aconn):
             await copy.write_row([16, [[None], None]])
 
 
+@pytest.mark.parametrize("format", pq.Format)
+async def test_clean_buffer_on_error(aconn, format):
+    cur = aconn.cursor()
+    await ensure_table_async(cur, "id serial primary key, num int4, obj jsonb")
+    async with cur.copy(
+        f"copy copy_in (num, obj) from stdin (format {format.name})"
+    ) as copy:
+        copy.set_types(["int4", "jsonb"])
+        await copy.write_row([15, {}])
+        with pytest.raises(TypeError):
+            await copy.write_row([16, 1j])
+        await copy.write_row([17, []])
+
+    await cur.execute("select num, obj from copy_in order by id")
+    assert (await cur.fetchall()) == [(15, {}), (17, [])]
+
+
 @pytest.mark.parametrize(
     "format, buffer",
     [(pq.Format.TEXT, "sample_text"), (pq.Format.BINARY, "sample_binary")],