]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(crdb): add copy tests
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 20 May 2022 00:32:19 +0000 (02:32 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:33 +0000 (12:58 +0100)
Skip tests from psycopg, because there is no support for copy out and
copy in syntax is different. Data types are different too (int/serial
mean int8) so the binary data samples don't match the format.

tests/crdb/test_copy.py [new file with mode: 0644]
tests/fix_crdb.py
tests/test_copy.py
tests/test_copy_async.py

diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py
new file mode 100644 (file)
index 0000000..76ca461
--- /dev/null
@@ -0,0 +1,174 @@
+import pytest
+import string
+from random import randrange, choice
+
+from psycopg.pq import Format
+from psycopg import errors as e
+from psycopg.types.numeric import Int4
+
+from ..utils import eur
+from ..test_copy import sample_text, sample_binary  # noqa
+from ..test_copy import ensure_table, sample_tabledef, sample_records
+
+# CRDB int/serial are int8
+sample_tabledef = sample_tabledef.replace("int", "int4").replace("serial", "int4")
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_copy_in_buffers(conn, format, buffer):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+        copy.write(globals()[buffer])
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+def test_copy_in_buffers_pg_error(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.UniqueViolation):
+        with cur.copy("copy copy_in from stdin") as copy:
+            copy.write(sample_text)
+            copy.write(sample_text)
+    assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_str(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy("copy copy_in from stdin") as copy:
+        copy.write(sample_text.decode())
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
+def test_copy_in_error(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.QueryCanceled):
+        with cur.copy("copy copy_in from stdin with binary") as copy:
+            copy.write(sample_text.decode())
+
+    assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_empty(conn, format):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
+        pass
+
+    assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+    assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+def test_copy_big_size_record(conn):
+    cur = conn.cursor()
+    ensure_table(cur, "id serial primary key, data text")
+    data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+    with cur.copy("copy copy_in (data) from stdin") as copy:
+        copy.write_row([data])
+
+    cur.execute("select data from copy_in limit 1")
+    assert cur.fetchone()[0] == data
+
+
+@pytest.mark.slow
+def test_copy_big_size_block(conn):
+    cur = conn.cursor()
+    ensure_table(cur, "id serial primary key, data text")
+    data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+    copy_data = data + "\n"
+    with cur.copy("copy copy_in (data) from stdin") as copy:
+        copy.write(copy_data)
+
+    cur.execute("select data from copy_in limit 1")
+    assert cur.fetchone()[0] == data
+
+
+def test_copy_in_buffers_with_pg_error(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+    with pytest.raises(e.UniqueViolation):
+        with cur.copy("copy copy_in from stdin") as copy:
+            copy.write(sample_text)
+            copy.write(sample_text)
+
+    assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records(conn, format):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+
+    with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+        for row in sample_records:
+            if format == Format.BINARY:
+                row = tuple(
+                    Int4(i) if isinstance(i, int) else i for i in row
+                )  # type: ignore[assignment]
+            copy.write_row(row)
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_set_types(conn, format):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+
+    with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+        copy.set_types(["int4", "int4", "text"])
+        for row in sample_records:
+            copy.write_row(row)
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_binary(conn, format):
+    cur = conn.cursor()
+    ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+
+    with cur.copy(f"copy copy_in (col2, data) from stdin {copyopt(format)}") as copy:
+        for row in sample_records:
+            copy.write_row((None, row[2]))
+
+    data = cur.execute("select col2, data from copy_in order by 2").fetchall()
+    assert data == [(None, "hello"), (None, "world")]
+
+
+def test_copy_in_allchars(conn):
+    cur = conn.cursor()
+    ensure_table(cur, "col1 int primary key, col2 int, data text")
+
+    with cur.copy("copy copy_in from stdin") as copy:
+        for i in range(1, 256):
+            copy.write_row((i, None, chr(i)))
+        copy.write_row((ord(eur), None, eur))
+
+    data = cur.execute(
+        """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+    ).fetchall()
+    assert data == [(True, True, 1, 256)]
+
+
+def copyopt(format):
+    return "with binary" if format == Format.BINARY else ""
index 5c69f7e34766bd75f003ef2b5c968406a918a252..eb9aeb7473403f5ab85a5c0475dd945d05a2e588 100644 (file)
@@ -99,6 +99,7 @@ crdb_reasons = {
     "cidr": 18846,
     "composite": 27792,
     "copy": 41608,
+    "copy canceled": 81559,
     "cursor with hold": 77101,
     "deferrable": 48307,
     "do": 17511,
index e343cde0a224003fadd4dfc08b7e9b0d67a69f9f..73e88d7d2a1497b4bdc151e7d47cee0f9931e543 100644 (file)
@@ -19,10 +19,10 @@ from psycopg.types.numeric import Int4
 
 from .utils import eur, gc_collect
 
-sample_records = [(10, 20, "hello"), (40, None, "world")]
+pytestmark = pytest.mark.crdb("skip", reason="copy")
 
+sample_records = [(10, 20, "hello"), (40, None, "world")]
 sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
-
 sample_tabledef = "col1 serial primary key, col2 int, data text"
 
 sample_text = b"""\
@@ -254,7 +254,7 @@ def test_copy_in_str(conn):
     assert data == sample_records
 
 
-def test_copy_in_str_binary(conn):
+def test_copy_in_error(conn):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
     with pytest.raises(e.QueryCanceled):
index 36a10c01e8c134ed9465aa53d5bdafeebd8c4eea..12ae52d833cfbd4e749109ca8461057056d6efa1 100644 (file)
@@ -22,7 +22,10 @@ from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from .test_copy import sample_values, sample_records, sample_tabledef
 from .test_copy import py_to_raw
 
-pytestmark = pytest.mark.asyncio
+pytestmark = [
+    pytest.mark.asyncio,
+    pytest.mark.crdb("skip", reason="copy"),
+]
 
 
 @pytest.mark.parametrize("format", Format)
@@ -246,7 +249,7 @@ async def test_copy_in_str(aconn):
     assert data == sample_records
 
 
-async def test_copy_in_str_binary(aconn):
+async def test_copy_in_error(aconn):
     cur = aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     with pytest.raises(e.QueryCanceled):