From: Daniele Varrazzo Date: Thu, 10 Aug 2023 01:02:52 +0000 (+0100) Subject: refactor(tests): auto-generate test_copy from async counterpart X-Git-Tag: pool-3.2.0~12^2~54 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0062cf854abfb7dddae902cb664a18b237445ce3;p=thirdparty%2Fpsycopg.git refactor(tests): auto-generate test_copy from async counterpart --- diff --git a/tests/test_copy.py b/tests/test_copy.py index 7eae2197c..f4deaf98b 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,3 +1,6 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_copy_async.py' +# DO NOT CHANGE! Change the original file instead. import string import hashlib from io import BytesIO, StringIO @@ -89,9 +92,10 @@ def test_copy_out_param(conn, ph, params): def test_read_rows(conn, format, typetype): cur = conn.cursor() with cur.copy( - f"""copy ( - select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] - ) to stdout (format {format.name})""" + """copy ( + select 10::int4, 'hello'::text, '{0.0,1.0}'::float8[] + ) to stdout (format %s)""" + % format.name ) as copy: copy.set_types(["int4", "text", "float8[]"]) row = copy.read_row() @@ -161,7 +165,7 @@ def test_read_row_notypes(conn, format): break rows.append(row) - ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + ref = [tuple((py_to_raw(i, format) for i in record)) for record in sample_records] assert rows == ref @@ -170,7 +174,7 @@ def test_rows_notypes(conn, format): cur = conn.cursor() with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: rows = list(copy.rows()) - ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + ref = [tuple((py_to_raw(i, format) for i in record)) for record in sample_records] assert rows == ref @@ -185,8 +189,7 @@ def test_copy_out_badntypes(conn, format, err): @pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] ) def test_copy_in_buffers(conn, format, buffer): cur = conn.cursor() @@ -194,7 +197,8 @@ def test_copy_in_buffers(conn, format, buffer): with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: copy.write(globals()[buffer]) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -240,7 +244,8 @@ def test_copy_in_str(conn): with cur.copy("copy copy_in from stdin (format text)") as copy: copy.write(sample_text.decode()) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -269,12 +274,12 @@ def test_copy_in_empty(conn, format): def test_copy_big_size_record(conn): cur = conn.cursor() ensure_table(cur, sample_tabledef) - data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + data = "".join((chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))) with cur.copy("copy copy_in (data) from stdin") as copy: copy.write_row([data]) cur.execute("select data from copy_in limit 1") - assert cur.fetchone()[0] == data + assert cur.fetchone() == (data,) @pytest.mark.slow @@ -282,13 +287,13 @@ def test_copy_big_size_record(conn): def test_copy_big_size_block(conn, pytype): cur = conn.cursor() ensure_table(cur, sample_tabledef) - data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + data = "".join((choice(string.ascii_letters) for i in range(10 * 1024 * 1024))) copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") with cur.copy("copy copy_in (data) from stdin") as copy: copy.write(copy_data) cur.execute("select data from copy_in limit 1") - assert cur.fetchone()[0] == data + assert cur.fetchone() == (data,) @pytest.mark.parametrize("format", Format) @@ -312,7 +317,8 @@ def test_subclass_adapter(conn, format): with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy: copy.write_row(("hello",)) - rec = cur.execute("select data from copy_in").fetchone() + cur.execute("select data from copy_in") + rec = cur.fetchone() assert rec[0] == "hellohello" @@ -389,11 +395,12 @@ def test_copy_in_records(conn, format): with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: for row in sample_records: if format == Format.BINARY: - row2 = tuple(Int4(i) if isinstance(i, int) else i for i in row) + row2 = tuple((Int4(i) if isinstance(i, int) else i for i in row)) row = row2 # type: ignore[assignment] copy.write_row(row) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -407,7 +414,8 @@ def test_copy_in_records_set_types(conn, format): for row in sample_records: copy.write_row(row) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -422,7 +430,8 @@ def test_copy_in_records_binary(conn, format): for row in sample_records: copy.write_row((None, row[2])) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == [(1, None, "hello"), (2, None, "world")] @@ -436,12 +445,13 @@ def test_copy_in_allchars(conn): copy.write_row((i, None, chr(i))) copy.write_row((ord(eur), None, eur)) - data = cur.execute( + cur.execute( """ select col1 = ascii(data), col2 is null, length(data), count(*) from copy_in group by 1, 2, 3 """ - ).fetchall() + ) + data = cur.fetchall() assert data == [(True, True, 1, 256)] @@ -625,7 +635,8 @@ def test_worker_life(conn, format, buffer): assert copy.writer._worker assert not copy.writer._worker - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -656,14 +667,14 @@ def test_connection_writer(conn, format, buffer): assert copy.writer is writer copy.write(globals()[buffer]) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @pytest.mark.slow @pytest.mark.parametrize( - "fmt, set_types", - [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], + "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)] ) @pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method): @@ -718,8 +729,7 @@ def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method): @pytest.mark.slow @pytest.mark.parametrize( - "fmt, set_types", - [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], + "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)] ) def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types): faker.format = PyFormat.from_pq(fmt) @@ -765,7 +775,8 @@ def test_copy_table_across(conn_cls, dsn, faker, mode): faker.choose_schema(ncols=20) faker.make_records(20) - with conn_cls.connect(dsn) as conn1, conn_cls.connect(dsn) as conn2: + connect = conn_cls.connect + with connect(dsn) as conn1, connect(dsn) as conn2: faker.table_name = sql.Identifier("copy_src") conn1.execute(faker.drop_stmt) conn1.execute(faker.create_stmt) @@ -785,7 +796,8 @@ def test_copy_table_across(conn_cls, dsn, faker, mode): for data in copy1: copy2.write(data) - recs = conn2.execute(faker.select_stmt).fetchall() + cur = conn2.execute(faker.select_stmt) + recs = cur.fetchall() for got, want in zip(recs, faker.records): faker.assert_record(got, want) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 5589038cc..343522473 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -93,9 +93,10 @@ async def test_copy_out_param(aconn, ph, params): async def test_read_rows(aconn, format, typetype): cur = aconn.cursor() async with cur.copy( - f"""copy ( - select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] - ) to stdout (format {format.name})""" + """copy ( + select 10::int4, 'hello'::text, '{0.0,1.0}'::float8[] + ) to stdout (format %s)""" + % format.name ) as copy: copy.set_types(["int4", "text", "float8[]"]) row = await copy.read_row() diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index fce12d5c8..647bb4f75 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -120,6 +120,7 @@ class RenameAsyncToSync(ast.NodeTransformer): "aconn_set": "conn_set", "alist": "list", "anext": "next", + "ensure_table_async": "ensure_table", "find_insert_problem_async": "find_insert_problem", } diff --git a/tools/convert_async_to_sync.sh b/tools/convert_async_to_sync.sh index 637cd6083..9906b2042 100755 --- a/tools/convert_async_to_sync.sh +++ b/tools/convert_async_to_sync.sh @@ -10,6 +10,7 @@ cd "${dir}/.." for async in \ tests/test_client_cursor_async.py \ tests/test_connection_async.py \ + tests/test_copy_async.py \ tests/test_cursor_async.py \ tests/test_pipeline_async.py do