From: Daniele Varrazzo Date: Thu, 10 Aug 2023 00:33:41 +0000 (+0100) Subject: refactor(tests): make test_copy and async counterpart more similar X-Git-Tag: pool-3.2.0~12^2~55 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fc09e3454747925954ac9f4b780d470acaa04829;p=thirdparty%2Fpsycopg.git refactor(tests): make test_copy and async counterpart more similar --- diff --git a/tests/_test_copy.py b/tests/_test_copy.py new file mode 100644 index 000000000..4ee9ee5f9 --- /dev/null +++ b/tests/_test_copy.py @@ -0,0 +1,63 @@ +import struct + +from psycopg.pq import Format +from psycopg.copy import AsyncWriter +from psycopg.copy import FileWriter as FileWriter # noqa: F401 + +sample_records = [(40010, 40020, "hello"), (40040, None, "world")] +sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')" +sample_tabledef = "col1 serial primary key, col2 int, data text" + +sample_text = b"""\ +40010\t40020\thello +40040\t\\N\tworld +""" + +sample_binary_str = """ +5047 434f 5059 0aff 0d0a 00 +00 0000 0000 0000 00 +00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f + +0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64 + +ff ff +""" + +sample_binary_rows = [ + bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n") +] +sample_binary = b"".join(sample_binary_rows) + +special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"} + + +def ensure_table(cur, tabledef, name="copy_in"): + cur.execute(f"drop table if exists {name}") + cur.execute(f"create table {name} ({tabledef})") + + +async def ensure_table_async(cur, tabledef, name="copy_in"): + await cur.execute(f"drop table if exists {name}") + await cur.execute(f"create table {name} ({tabledef})") + + +def py_to_raw(item, fmt): + """Convert from Python type to the expected result from the db""" + if fmt == Format.TEXT: + if isinstance(item, int): + return str(item) + else: + if isinstance(item, int): + # Assume int4 + return struct.pack("!i", item) + elif isinstance(item, str): + return item.encode() + return item + + +class AsyncFileWriter(AsyncWriter): + def __init__(self, file): + self.file = file + + async def write(self, data): + self.file.write(data) diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py index b7d26aa51..2bf714f1c 100644 --- a/tests/crdb/test_copy.py +++ b/tests/crdb/test_copy.py @@ -8,9 +8,9 @@ from psycopg.adapt import PyFormat from psycopg.types.numeric import Int4 from ..utils import eur, gc_collect, gc_count -from ..test_copy import sample_text, sample_binary # noqa -from ..test_copy import ensure_table, sample_records -from ..test_copy import sample_tabledef as sample_tabledef_pg +from .._test_copy import sample_text, sample_binary # noqa +from .._test_copy import ensure_table, sample_records +from .._test_copy import sample_tabledef as sample_tabledef_pg # CRDB int/serial are int8 sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4") diff --git a/tests/crdb/test_copy_async.py b/tests/crdb/test_copy_async.py index 45ee5eca0..a994d9071 100644 --- a/tests/crdb/test_copy_async.py +++ b/tests/crdb/test_copy_async.py @@ -8,9 +8,8 @@ from psycopg.adapt import PyFormat from psycopg.types.numeric import Int4 from ..utils import eur, gc_collect, gc_count -from ..test_copy import sample_text, sample_binary # noqa -from ..test_copy import sample_records -from ..test_copy_async import ensure_table +from .._test_copy import sample_text, sample_binary # noqa +from .._test_copy import ensure_table_async, sample_records from .test_copy import sample_tabledef, copyopt pytestmark = [pytest.mark.crdb, pytest.mark.anyio] @@ -22,7 +21,7 @@ pytestmark = [pytest.mark.crdb, pytest.mark.anyio] ) async def test_copy_in_buffers(aconn, format, buffer): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: await copy.write(globals()[buffer]) @@ -33,7 +32,7 @@ async def test_copy_in_buffers(aconn, format, buffer): async def test_copy_in_buffers_pg_error(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(e.UniqueViolation): async with cur.copy("copy copy_in from stdin") as copy: await copy.write(sample_text) @@ -43,7 +42,7 @@ async def test_copy_in_buffers_pg_error(aconn): async def test_copy_in_str(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy("copy copy_in from stdin") as copy: await copy.write(sample_text.decode()) @@ -55,7 +54,7 @@ async def test_copy_in_str(aconn): @pytest.mark.xfail(reason="bad sqlstate - CRDB #81559") async def test_copy_in_error(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(e.QueryCanceled): async with cur.copy("copy copy_in from stdin with binary") as copy: await copy.write(sample_text.decode()) @@ -66,7 +65,7 @@ async def test_copy_in_error(aconn): @pytest.mark.parametrize("format", Format) async def test_copy_in_empty(aconn, format): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"): pass @@ -77,7 +76,7 @@ async def test_copy_in_empty(aconn, format): @pytest.mark.slow async def test_copy_big_size_record(aconn): cur = aconn.cursor() - await ensure_table(cur, "id serial primary key, data text") + await ensure_table_async(cur, "id serial primary key, data text") data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) async with cur.copy("copy copy_in (data) from stdin") as copy: await copy.write_row([data]) @@ -89,7 +88,7 @@ async def test_copy_big_size_record(aconn): @pytest.mark.slow async def test_copy_big_size_block(aconn): cur = aconn.cursor() - await ensure_table(cur, "id serial primary key, data text") + await ensure_table_async(cur, "id serial primary key, data text") data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) copy_data = data + "\n" async with cur.copy("copy copy_in (data) from stdin") as copy: @@ -101,7 +100,7 @@ async def test_copy_big_size_block(aconn): async def test_copy_in_buffers_with_pg_error(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(e.UniqueViolation): async with cur.copy("copy copy_in from stdin") as copy: await copy.write(sample_text) @@ -113,7 +112,7 @@ async def test_copy_in_buffers_with_pg_error(aconn): @pytest.mark.parametrize("format", Format) async def test_copy_in_records(aconn, format): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: for row in sample_records: @@ -131,7 +130,7 @@ async def test_copy_in_records(aconn, format): @pytest.mark.parametrize("format", Format) async def test_copy_in_records_set_types(aconn, format): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: copy.set_types(["int4", "int4", "text"]) @@ -146,7 +145,7 @@ async def test_copy_in_records_set_types(aconn, format): @pytest.mark.parametrize("format", Format) async def test_copy_in_records_binary(aconn, format): cur = aconn.cursor() - await ensure_table(cur, "col1 serial primary key, col2 int4, data text") + await ensure_table_async(cur, "col1 serial primary key, col2 int4, data text") async with cur.copy( f"copy copy_in (col2, data) from stdin {copyopt(format)}" @@ -162,7 +161,7 @@ async def test_copy_in_records_binary(aconn, format): @pytest.mark.crdb_skip("copy canceled") async def test_copy_in_buffers_with_py_error(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(e.QueryCanceled) as exc: async with cur.copy("copy copy_in from stdin") as copy: await copy.write(sample_text) @@ -174,7 +173,7 @@ async def test_copy_in_buffers_with_py_error(aconn): async def test_copy_in_allchars(aconn): cur = aconn.cursor() - await ensure_table(cur, "col1 int primary key, col2 int, data text") + await ensure_table_async(cur, "col1 int primary key, col2 int, data text") async with cur.copy("copy copy_in from stdin") as copy: for i in range(1, 256): diff --git a/tests/test_copy.py b/tests/test_copy.py index 2c21368ae..7eae2197c 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,5 +1,4 @@ import string -import struct import hashlib from io import BytesIO, StringIO from random import choice, randrange @@ -12,42 +11,19 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e from psycopg.pq import Format -from psycopg.copy import Copy, LibpqWriter, QueuedLibpqWriter, FileWriter +from psycopg.copy import Copy, LibpqWriter, QueuedLibpqWriter from psycopg.adapt import PyFormat from psycopg.types import TypeInfo from psycopg.types.hstore import register_hstore from psycopg.types.numeric import Int4 from .utils import eur, gc_collect, gc_count +from ._test_copy import sample_text, sample_binary, sample_binary_rows # noqa +from ._test_copy import sample_values, sample_records, sample_tabledef +from ._test_copy import ensure_table, py_to_raw, special_chars, FileWriter pytestmark = pytest.mark.crdb_skip("copy") -sample_records = [(40010, 40020, "hello"), (40040, None, "world")] -sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')" -sample_tabledef = "col1 serial primary key, col2 int, data text" - -sample_text = b"""\ -40010\t40020\thello -40040\t\\N\tworld -""" - -sample_binary_str = """ -5047 434f 5059 0aff 0d0a 00 -00 0000 0000 0000 00 -00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f - -0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64 - -ff ff -""" - -sample_binary_rows = [ - bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n") -] -sample_binary = b"".join(sample_binary_rows) - -special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"} - @pytest.mark.parametrize("format", Format) def test_copy_out_read(conn, format): @@ -320,9 +296,9 @@ def test_subclass_adapter(conn, format): if format == Format.TEXT: from psycopg.types.string import StrDumper as BaseDumper else: - from psycopg.types.string import ( # type: ignore[assignment] - StrBinaryDumper as BaseDumper, - ) + from psycopg.types.string import StrBinaryDumper + + BaseDumper = StrBinaryDumper # type: ignore class MyStrDumper(BaseDumper): def dump(self, obj): @@ -413,9 +389,8 @@ def test_copy_in_records(conn, format): with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: for row in sample_records: if format == Format.BINARY: - row = tuple( - Int4(i) if isinstance(i, int) else i for i in row - ) # type: ignore[assignment] + row2 = tuple(Int4(i) if isinstance(i, int) else i for i in row) + row = row2 # type: ignore[assignment] copy.write_row(row) data = cur.execute("select * from copy_in order by 1").fetchall() @@ -815,25 +790,6 @@ def test_copy_table_across(conn_cls, dsn, faker, mode): faker.assert_record(got, want) -def py_to_raw(item, fmt): - """Convert from Python type to the expected result from the db""" - if fmt == Format.TEXT: - if isinstance(item, int): - return str(item) - else: - if isinstance(item, int): - # Assume int4 - return struct.pack("!i", item) - elif isinstance(item, str): - return item.encode() - return item - - -def ensure_table(cur, tabledef, name="copy_in"): - cur.execute(f"drop table if exists {name}") - cur.execute(f"create table {name} ({tabledef})") - - class DataGenerator: def __init__(self, conn, nrecs, srec, offset=0, block_size=8192): self.conn = conn diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index dd11d4bd2..5589038cc 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -11,21 +11,18 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e from psycopg.pq import Format -from psycopg.copy import AsyncCopy -from psycopg.copy import AsyncWriter, AsyncLibpqWriter, AsyncQueuedLibpqWriter -from psycopg.types import TypeInfo +from psycopg.copy import AsyncCopy, AsyncLibpqWriter, AsyncQueuedLibpqWriter from psycopg.adapt import PyFormat +from psycopg.types import TypeInfo from psycopg.types.hstore import register_hstore from psycopg.types.numeric import Int4 from .utils import alist, eur, gc_collect, gc_count -from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa -from .test_copy import sample_values, sample_records, sample_tabledef -from .test_copy import py_to_raw, special_chars +from ._test_copy import sample_text, sample_binary, sample_binary_rows # noqa +from ._test_copy import sample_values, sample_records, sample_tabledef +from ._test_copy import ensure_table_async, py_to_raw, special_chars, AsyncFileWriter -pytestmark = [ - pytest.mark.crdb_skip("copy"), -] +pytestmark = pytest.mark.crdb_skip("copy") @pytest.mark.parametrize("format", Format) @@ -114,7 +111,7 @@ async def test_rows(aconn, format): async with cur.copy( f"copy ({sample_values}) to stdout (format {format.name})" ) as copy: - copy.set_types("int4 int4 text".split()) + copy.set_types(["int4", "int4", "text"]) rows = await alist(copy.rows()) assert rows == sample_records @@ -204,7 +201,7 @@ async def test_copy_out_badntypes(aconn, format, err): ) async def test_copy_in_buffers(aconn, format, buffer): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: await copy.write(globals()[buffer]) @@ -215,7 +212,7 @@ async def test_copy_in_buffers(aconn, format, buffer): async def test_copy_in_buffers_pg_error(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(e.UniqueViolation): async with cur.copy("copy copy_in from stdin (format text)") as copy: await copy.write(sample_text) @@ -251,7 +248,7 @@ async def test_copy_bad_result(aconn): async def test_copy_in_str(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy("copy copy_in from stdin (format text)") as copy: await copy.write(sample_text.decode()) @@ -262,7 +259,7 @@ async def test_copy_in_str(aconn): async def test_copy_in_error(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(TypeError): async with cur.copy("copy copy_in from stdin (format binary)") as copy: await copy.write(sample_text.decode()) @@ -273,7 +270,7 @@ async def test_copy_in_error(aconn): @pytest.mark.parametrize("format", Format) async def test_copy_in_empty(aconn, format): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin (format {format.name})"): pass @@ -284,7 +281,7 @@ async def test_copy_in_empty(aconn, format): @pytest.mark.slow async def test_copy_big_size_record(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) async with cur.copy("copy copy_in (data) from stdin") as copy: await copy.write_row([data]) @@ -297,7 +294,7 @@ async def test_copy_big_size_record(aconn): @pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview]) async def test_copy_big_size_block(aconn, pytype): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") async with cur.copy("copy copy_in (data) from stdin") as copy: @@ -312,9 +309,9 @@ async def test_subclass_adapter(aconn, format): if format == Format.TEXT: from psycopg.types.string import StrDumper as BaseDumper else: - from psycopg.types.string import ( # type: ignore[assignment] - StrBinaryDumper as BaseDumper, - ) + from psycopg.types.string import StrBinaryDumper + + BaseDumper = StrBinaryDumper # type: ignore class MyStrDumper(BaseDumper): def dump(self, obj): @@ -323,7 +320,7 @@ async def test_subclass_adapter(aconn, format): aconn.adapters.register_dumper(str, MyStrDumper) cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy( f"copy copy_in (data) from stdin (format {format.name})" @@ -338,7 +335,7 @@ async def test_subclass_adapter(aconn, format): @pytest.mark.parametrize("format", Format) async def test_copy_in_error_empty(aconn, format): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(ZeroDivisionError, match="mannaggiamiseria"): async with cur.copy(f"copy copy_in from stdin (format {format.name})"): raise ZeroDivisionError("mannaggiamiseria") @@ -348,7 +345,7 @@ async def test_copy_in_error_empty(aconn, format): async def test_copy_in_buffers_with_pg_error(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(e.UniqueViolation): async with cur.copy("copy copy_in from stdin (format text)") as copy: await copy.write(sample_text) @@ -359,7 +356,7 @@ async def test_copy_in_buffers_with_pg_error(aconn): async def test_copy_in_buffers_with_py_error(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) with pytest.raises(ZeroDivisionError, match="nuttengoggenio"): async with cur.copy("copy copy_in from stdin (format text)") as copy: await copy.write(sample_text) @@ -405,14 +402,13 @@ async def test_copy_out_server_error(aconn): @pytest.mark.parametrize("format", Format) async def test_copy_in_records(aconn, format): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: for row in sample_records: if format == Format.BINARY: - row = tuple( - Int4(i) if isinstance(i, int) else i for i in row - ) # type: ignore[assignment] + row2 = tuple(Int4(i) if isinstance(i, int) else i for i in row) + row = row2 # type: ignore[assignment] await copy.write_row(row) await cur.execute("select * from copy_in order by 1") @@ -423,7 +419,7 @@ async def test_copy_in_records(aconn, format): @pytest.mark.parametrize("format", Format) async def test_copy_in_records_set_types(aconn, format): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: copy.set_types(["int4", "int4", "text"]) @@ -438,7 +434,7 @@ async def test_copy_in_records_set_types(aconn, format): @pytest.mark.parametrize("format", Format) async def test_copy_in_records_binary(aconn, format): cur = aconn.cursor() - await ensure_table(cur, "col1 serial primary key, col2 int, data text") + await ensure_table_async(cur, "col1 serial primary key, col2 int, data text") async with cur.copy( f"copy copy_in (col2, data) from stdin (format {format.name})" @@ -453,7 +449,7 @@ async def test_copy_in_records_binary(aconn, format): async def test_copy_in_allchars(aconn): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) await aconn.execute("set client_encoding to utf8") async with cur.copy("copy copy_in from stdin (format text)") as copy: @@ -642,7 +638,7 @@ async def test_description(aconn): ) async def test_worker_life(aconn, format, buffer): cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy( f"copy copy_in from stdin (format {format.name})", writer=AsyncQueuedLibpqWriter(cur), @@ -679,7 +675,7 @@ async def test_connection_writer(aconn, format, buffer): cur = aconn.cursor() writer = AsyncLibpqWriter(cur) - await ensure_table(cur, sample_tabledef) + await ensure_table_async(cur, sample_tabledef) async with cur.copy( f"copy copy_in from stdin (format {format.name})", writer=writer ) as copy: @@ -823,11 +819,6 @@ async def test_copy_table_across(aconn_cls, dsn, faker, mode): faker.assert_record(got, want) -async def ensure_table(cur, tabledef, name="copy_in"): - await cur.execute(f"drop table if exists {name}") - await cur.execute(f"create table {name} ({tabledef})") - - class DataGenerator: def __init__(self, conn, nrecs, srec, offset=0, block_size=8192): self.conn = conn @@ -838,7 +829,7 @@ class DataGenerator: async def ensure_table(self): cur = self.conn.cursor() - await ensure_table(cur, "id integer primary key, data text") + await ensure_table_async(cur, "id integer primary key, data text") def records(self): for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)): @@ -879,11 +870,3 @@ class DataGenerator: block = block.encode() m.update(block) return m.hexdigest() - - -class AsyncFileWriter(AsyncWriter): - def __init__(self, file): - self.file = file - - async def write(self, data): - self.file.write(data) diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 706dc66b8..fce12d5c8 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -106,7 +106,11 @@ class RenameAsyncToSync(ast.NodeTransformer): names_map = { "AsyncClientCursor": "ClientCursor", "AsyncConnection": "Connection", + "AsyncCopy": "Copy", "AsyncCursor": "Cursor", + "AsyncFileWriter": "FileWriter", + "AsyncLibpqWriter": "LibpqWriter", + "AsyncQueuedLibpqWriter": "QueuedLibpqWriter", "AsyncRawCursor": "RawCursor", "AsyncServerCursor": "ServerCursor", "aclose": "close",