From a9bfb5f349496293fccb70e72e05bfc2312fc506 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Fri, 20 May 2022 02:32:19 +0200 Subject: [PATCH] test(crdb): add copy tests Skip tests from psycopg, because there is no support for copy out and copy in syntax is different. Data types are different too (int/serial mean int8) so the binary data samples don't match the format. --- tests/crdb/test_copy.py | 174 +++++++++++++++++++++++++++++++++++++++ tests/fix_crdb.py | 1 + tests/test_copy.py | 6 +- tests/test_copy_async.py | 7 +- 4 files changed, 183 insertions(+), 5 deletions(-) create mode 100644 tests/crdb/test_copy.py diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py new file mode 100644 index 000000000..76ca461cc --- /dev/null +++ b/tests/crdb/test_copy.py @@ -0,0 +1,174 @@ +import pytest +import string +from random import randrange, choice + +from psycopg.pq import Format +from psycopg import errors as e +from psycopg.types.numeric import Int4 + +from ..utils import eur +from ..test_copy import sample_text, sample_binary # noqa +from ..test_copy import ensure_table, sample_tabledef, sample_records + +# CRDB int/serial are int8 +sample_tabledef = sample_tabledef.replace("int", "int4").replace("serial", "int4") + +pytestmark = pytest.mark.crdb + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +def test_copy_in_buffers(conn, format, buffer): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_copy_in_buffers_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + copy.write(sample_text) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_str(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text.decode()) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559") +def test_copy_in_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + with cur.copy("copy copy_in from stdin with binary") as copy: + copy.write(sample_text.decode()) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin {copyopt(format)}"): + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +def test_copy_big_size_record(conn): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data text") + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write_row([data]) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +@pytest.mark.slow +def test_copy_big_size_block(conn): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data text") + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write(copy_data) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +def test_copy_in_buffers_with_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + copy.write(sample_text) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_set_types(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_binary(conn, format): + cur = conn.cursor() + ensure_table(cur, "col1 serial primary key, col2 int4, data text") + + with cur.copy(f"copy copy_in (col2, data) from stdin {copyopt(format)}") as copy: + for row in sample_records: + copy.write_row((None, row[2])) + + data = cur.execute("select col2, data from copy_in order by 2").fetchall() + assert data == [(None, "hello"), (None, "world")] + + +def test_copy_in_allchars(conn): + cur = conn.cursor() + ensure_table(cur, "col1 int primary key, col2 int, data text") + + with cur.copy("copy copy_in from stdin") as copy: + for i in range(1, 256): + copy.write_row((i, None, chr(i))) + copy.write_row((ord(eur), None, eur)) + + data = cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ).fetchall() + assert data == [(True, True, 1, 256)] + + +def copyopt(format): + return "with binary" if format == Format.BINARY else "" diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py index 5c69f7e34..eb9aeb747 100644 --- a/tests/fix_crdb.py +++ b/tests/fix_crdb.py @@ -99,6 +99,7 @@ crdb_reasons = { "cidr": 18846, "composite": 27792, "copy": 41608, + "copy canceled": 81559, "cursor with hold": 77101, "deferrable": 48307, "do": 17511, diff --git a/tests/test_copy.py b/tests/test_copy.py index e343cde0a..73e88d7d2 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -19,10 +19,10 @@ from psycopg.types.numeric import Int4 from .utils import eur, gc_collect -sample_records = [(10, 20, "hello"), (40, None, "world")] +pytestmark = pytest.mark.crdb("skip", reason="copy") +sample_records = [(10, 20, "hello"), (40, None, "world")] sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')" - sample_tabledef = "col1 serial primary key, col2 int, data text" sample_text = b"""\ @@ -254,7 +254,7 @@ def test_copy_in_str(conn): assert data == sample_records -def test_copy_in_str_binary(conn): +def test_copy_in_error(conn): cur = conn.cursor() ensure_table(cur, sample_tabledef) with pytest.raises(e.QueryCanceled): diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 36a10c01e..12ae52d83 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -22,7 +22,10 @@ from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa from .test_copy import sample_values, sample_records, sample_tabledef from .test_copy import py_to_raw -pytestmark = pytest.mark.asyncio +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb("skip", reason="copy"), +] @pytest.mark.parametrize("format", Format) @@ -246,7 +249,7 @@ async def test_copy_in_str(aconn): assert data == sample_records -async def test_copy_in_str_binary(aconn): +async def test_copy_in_error(aconn): cur = aconn.cursor() await ensure_table(cur, sample_tabledef) with pytest.raises(e.QueryCanceled): -- 2.47.2