]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): auto-generate test_copy from async counterpart
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 10 Aug 2023 01:02:52 +0000 (02:02 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
tests/test_copy.py
tests/test_copy_async.py
tools/async_to_sync.py
tools/convert_async_to_sync.sh

index 7eae2197cff0105d935e7fdd664aa8ac9cc9db5d..f4deaf98beb53e7cc1415cc4ce6acee4c6b6a4ce 100644 (file)
@@ -1,3 +1,6 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_copy_async.py'
+# DO NOT CHANGE! Change the original file instead.
 import string
 import hashlib
 from io import BytesIO, StringIO
@@ -89,9 +92,10 @@ def test_copy_out_param(conn, ph, params):
 def test_read_rows(conn, format, typetype):
     cur = conn.cursor()
     with cur.copy(
-        f"""copy (
-            select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
-        ) to stdout (format {format.name})"""
+        """copy (
+            select 10::int4, 'hello'::text, '{0.0,1.0}'::float8[]
+        ) to stdout (format %s)"""
+        % format.name
     ) as copy:
         copy.set_types(["int4", "text", "float8[]"])
         row = copy.read_row()
@@ -161,7 +165,7 @@ def test_read_row_notypes(conn, format):
                 break
             rows.append(row)
 
-    ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+    ref = [tuple((py_to_raw(i, format) for i in record)) for record in sample_records]
     assert rows == ref
 
 
@@ -170,7 +174,7 @@ def test_rows_notypes(conn, format):
     cur = conn.cursor()
     with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         rows = list(copy.rows())
-    ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+    ref = [tuple((py_to_raw(i, format) for i in record)) for record in sample_records]
     assert rows == ref
 
 
@@ -185,8 +189,7 @@ def test_copy_out_badntypes(conn, format, err):
 
 
 @pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
 )
 def test_copy_in_buffers(conn, format, buffer):
     cur = conn.cursor()
@@ -194,7 +197,8 @@ def test_copy_in_buffers(conn, format, buffer):
     with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         copy.write(globals()[buffer])
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -240,7 +244,8 @@ def test_copy_in_str(conn):
     with cur.copy("copy copy_in from stdin (format text)") as copy:
         copy.write(sample_text.decode())
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -269,12 +274,12 @@ def test_copy_in_empty(conn, format):
 def test_copy_big_size_record(conn):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
-    data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+    data = "".join((chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)))
     with cur.copy("copy copy_in (data) from stdin") as copy:
         copy.write_row([data])
 
     cur.execute("select data from copy_in limit 1")
-    assert cur.fetchone()[0] == data
+    assert cur.fetchone() == (data,)
 
 
 @pytest.mark.slow
@@ -282,13 +287,13 @@ def test_copy_big_size_record(conn):
 def test_copy_big_size_block(conn, pytype):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
-    data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+    data = "".join((choice(string.ascii_letters) for i in range(10 * 1024 * 1024)))
     copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
     with cur.copy("copy copy_in (data) from stdin") as copy:
         copy.write(copy_data)
 
     cur.execute("select data from copy_in limit 1")
-    assert cur.fetchone()[0] == data
+    assert cur.fetchone() == (data,)
 
 
 @pytest.mark.parametrize("format", Format)
@@ -312,7 +317,8 @@ def test_subclass_adapter(conn, format):
     with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy:
         copy.write_row(("hello",))
 
-    rec = cur.execute("select data from copy_in").fetchone()
+    cur.execute("select data from copy_in")
+    rec = cur.fetchone()
     assert rec[0] == "hellohello"
 
 
@@ -389,11 +395,12 @@ def test_copy_in_records(conn, format):
     with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         for row in sample_records:
             if format == Format.BINARY:
-                row2 = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+                row2 = tuple((Int4(i) if isinstance(i, int) else i for i in row))
                 row = row2  # type: ignore[assignment]
             copy.write_row(row)
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -407,7 +414,8 @@ def test_copy_in_records_set_types(conn, format):
         for row in sample_records:
             copy.write_row(row)
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -422,7 +430,8 @@ def test_copy_in_records_binary(conn, format):
         for row in sample_records:
             copy.write_row((None, row[2]))
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == [(1, None, "hello"), (2, None, "world")]
 
 
@@ -436,12 +445,13 @@ def test_copy_in_allchars(conn):
             copy.write_row((i, None, chr(i)))
         copy.write_row((ord(eur), None, eur))
 
-    data = cur.execute(
+    cur.execute(
         """
 select col1 = ascii(data), col2 is null, length(data), count(*)
 from copy_in group by 1, 2, 3
 """
-    ).fetchall()
+    )
+    data = cur.fetchall()
     assert data == [(True, True, 1, 256)]
 
 
@@ -625,7 +635,8 @@ def test_worker_life(conn, format, buffer):
         assert copy.writer._worker
 
     assert not copy.writer._worker
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -656,14 +667,14 @@ def test_connection_writer(conn, format, buffer):
         assert copy.writer is writer
         copy.write(globals()[buffer])
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
-    "fmt, set_types",
-    [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+    "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)]
 )
 @pytest.mark.parametrize("method", ["read", "iter", "row", "rows"])
 def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method):
@@ -718,8 +729,7 @@ def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method):
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
-    "fmt, set_types",
-    [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+    "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)]
 )
 def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
     faker.format = PyFormat.from_pq(fmt)
@@ -765,7 +775,8 @@ def test_copy_table_across(conn_cls, dsn, faker, mode):
     faker.choose_schema(ncols=20)
     faker.make_records(20)
 
-    with conn_cls.connect(dsn) as conn1, conn_cls.connect(dsn) as conn2:
+    connect = conn_cls.connect
+    with connect(dsn) as conn1, connect(dsn) as conn2:
         faker.table_name = sql.Identifier("copy_src")
         conn1.execute(faker.drop_stmt)
         conn1.execute(faker.create_stmt)
@@ -785,7 +796,8 @@ def test_copy_table_across(conn_cls, dsn, faker, mode):
                     for data in copy1:
                         copy2.write(data)
 
-        recs = conn2.execute(faker.select_stmt).fetchall()
+        cur = conn2.execute(faker.select_stmt)
+        recs = cur.fetchall()
         for got, want in zip(recs, faker.records):
             faker.assert_record(got, want)
 
index 5589038cc69c56c7a15d33807a06908f85d6d2c2..3435224738ffa63bd610feec526712d94ac4e0c5 100644 (file)
@@ -93,9 +93,10 @@ async def test_copy_out_param(aconn, ph, params):
 async def test_read_rows(aconn, format, typetype):
     cur = aconn.cursor()
     async with cur.copy(
-        f"""copy (
-            select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
-        ) to stdout (format {format.name})"""
+        """copy (
+            select 10::int4, 'hello'::text, '{0.0,1.0}'::float8[]
+        ) to stdout (format %s)"""
+        % format.name
     ) as copy:
         copy.set_types(["int4", "text", "float8[]"])
         row = await copy.read_row()
index fce12d5c82fa8837559185a2601be6292e2637b6..647bb4f75621a4ffcfeb3fbcf30e9b8b958c3781 100755 (executable)
@@ -120,6 +120,7 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "aconn_set": "conn_set",
         "alist": "list",
         "anext": "next",
+        "ensure_table_async": "ensure_table",
         "find_insert_problem_async": "find_insert_problem",
     }
 
index 637cd6083265c2bee141700492c9b9ceacab1e05..9906b2042097aa3848bf741243eafa75fd3b78b5 100755 (executable)
@@ -10,6 +10,7 @@ cd "${dir}/.."
 for async in \
     tests/test_client_cursor_async.py \
     tests/test_connection_async.py \
+    tests/test_copy_async.py \
     tests/test_cursor_async.py \
     tests/test_pipeline_async.py
 do