]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): make test_copy and async counterpart more similar
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 10 Aug 2023 00:33:41 +0000 (01:33 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
tests/_test_copy.py [new file with mode: 0644]
tests/crdb/test_copy.py
tests/crdb/test_copy_async.py
tests/test_copy.py
tests/test_copy_async.py
tools/async_to_sync.py

diff --git a/tests/_test_copy.py b/tests/_test_copy.py
new file mode 100644 (file)
index 0000000..4ee9ee5
--- /dev/null
@@ -0,0 +1,63 @@
+import struct
+
+from psycopg.pq import Format
+from psycopg.copy import AsyncWriter
+from psycopg.copy import FileWriter as FileWriter  # noqa: F401
+
+sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
+sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
+sample_tabledef = "col1 serial primary key, col2 int, data text"
+
+sample_text = b"""\
+40010\t40020\thello
+40040\t\\N\tworld
+"""
+
+sample_binary_str = """
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
+"""
+
+sample_binary_rows = [
+    bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
+]
+sample_binary = b"".join(sample_binary_rows)
+
+special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
+
+
+def ensure_table(cur, tabledef, name="copy_in"):
+    cur.execute(f"drop table if exists {name}")
+    cur.execute(f"create table {name} ({tabledef})")
+
+
+async def ensure_table_async(cur, tabledef, name="copy_in"):
+    await cur.execute(f"drop table if exists {name}")
+    await cur.execute(f"create table {name} ({tabledef})")
+
+
+def py_to_raw(item, fmt):
+    """Convert from Python type to the expected result from the db"""
+    if fmt == Format.TEXT:
+        if isinstance(item, int):
+            return str(item)
+    else:
+        if isinstance(item, int):
+            # Assume int4
+            return struct.pack("!i", item)
+        elif isinstance(item, str):
+            return item.encode()
+    return item
+
+
+class AsyncFileWriter(AsyncWriter):
+    def __init__(self, file):
+        self.file = file
+
+    async def write(self, data):
+        self.file.write(data)
index b7d26aa516c07aa30ac10111732dacf92cf0215d..2bf714f1c3389ebb4525b52a56300ab8b87e8ac1 100644 (file)
@@ -8,9 +8,9 @@ from psycopg.adapt import PyFormat
 from psycopg.types.numeric import Int4
 
 from ..utils import eur, gc_collect, gc_count
-from ..test_copy import sample_text, sample_binary  # noqa
-from ..test_copy import ensure_table, sample_records
-from ..test_copy import sample_tabledef as sample_tabledef_pg
+from .._test_copy import sample_text, sample_binary  # noqa
+from .._test_copy import ensure_table, sample_records
+from .._test_copy import sample_tabledef as sample_tabledef_pg
 
 # CRDB int/serial are int8
 sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4")
index 45ee5eca08c44a05f773aedfaddfc77ac3fbf815..a994d9071fad16cdf7bc9325d97ad27783416256 100644 (file)
@@ -8,9 +8,8 @@ from psycopg.adapt import PyFormat
 from psycopg.types.numeric import Int4
 
 from ..utils import eur, gc_collect, gc_count
-from ..test_copy import sample_text, sample_binary  # noqa
-from ..test_copy import sample_records
-from ..test_copy_async import ensure_table
+from .._test_copy import sample_text, sample_binary  # noqa
+from .._test_copy import ensure_table_async, sample_records
 from .test_copy import sample_tabledef, copyopt
 
 pytestmark = [pytest.mark.crdb, pytest.mark.anyio]
@@ -22,7 +21,7 @@ pytestmark = [pytest.mark.crdb, pytest.mark.anyio]
 )
 async def test_copy_in_buffers(aconn, format, buffer):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
         await copy.write(globals()[buffer])
 
@@ -33,7 +32,7 @@ async def test_copy_in_buffers(aconn, format, buffer):
 
 async def test_copy_in_buffers_pg_error(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(e.UniqueViolation):
         async with cur.copy("copy copy_in from stdin") as copy:
             await copy.write(sample_text)
@@ -43,7 +42,7 @@ async def test_copy_in_buffers_pg_error(aconn):
 
 async def test_copy_in_str(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     async with cur.copy("copy copy_in from stdin") as copy:
         await copy.write(sample_text.decode())
 
@@ -55,7 +54,7 @@ async def test_copy_in_str(aconn):
 @pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
 async def test_copy_in_error(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(e.QueryCanceled):
         async with cur.copy("copy copy_in from stdin with binary") as copy:
             await copy.write(sample_text.decode())
@@ -66,7 +65,7 @@ async def test_copy_in_error(aconn):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_empty(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
         pass
 
@@ -77,7 +76,7 @@ async def test_copy_in_empty(aconn, format):
 @pytest.mark.slow
 async def test_copy_big_size_record(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, "id serial primary key, data text")
+    await ensure_table_async(cur, "id serial primary key, data text")
     data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
     async with cur.copy("copy copy_in (data) from stdin") as copy:
         await copy.write_row([data])
@@ -89,7 +88,7 @@ async def test_copy_big_size_record(aconn):
 @pytest.mark.slow
 async def test_copy_big_size_block(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, "id serial primary key, data text")
+    await ensure_table_async(cur, "id serial primary key, data text")
     data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
     copy_data = data + "\n"
     async with cur.copy("copy copy_in (data) from stdin") as copy:
@@ -101,7 +100,7 @@ async def test_copy_big_size_block(aconn):
 
 async def test_copy_in_buffers_with_pg_error(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(e.UniqueViolation):
         async with cur.copy("copy copy_in from stdin") as copy:
             await copy.write(sample_text)
@@ -113,7 +112,7 @@ async def test_copy_in_buffers_with_pg_error(aconn):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_records(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
 
     async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
         for row in sample_records:
@@ -131,7 +130,7 @@ async def test_copy_in_records(aconn, format):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_records_set_types(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
 
     async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
         copy.set_types(["int4", "int4", "text"])
@@ -146,7 +145,7 @@ async def test_copy_in_records_set_types(aconn, format):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_records_binary(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+    await ensure_table_async(cur, "col1 serial primary key, col2 int4, data text")
 
     async with cur.copy(
         f"copy copy_in (col2, data) from stdin {copyopt(format)}"
@@ -162,7 +161,7 @@ async def test_copy_in_records_binary(aconn, format):
 @pytest.mark.crdb_skip("copy canceled")
 async def test_copy_in_buffers_with_py_error(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(e.QueryCanceled) as exc:
         async with cur.copy("copy copy_in from stdin") as copy:
             await copy.write(sample_text)
@@ -174,7 +173,7 @@ async def test_copy_in_buffers_with_py_error(aconn):
 
 async def test_copy_in_allchars(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, "col1 int primary key, col2 int, data text")
+    await ensure_table_async(cur, "col1 int primary key, col2 int, data text")
 
     async with cur.copy("copy copy_in from stdin") as copy:
         for i in range(1, 256):
index 2c21368ae5c7d9915ddcbff0a3149cbb0edf3fbe..7eae2197cff0105d935e7fdd664aa8ac9cc9db5d 100644 (file)
@@ -1,5 +1,4 @@
 import string
-import struct
 import hashlib
 from io import BytesIO, StringIO
 from random import choice, randrange
@@ -12,42 +11,19 @@ from psycopg import pq
 from psycopg import sql
 from psycopg import errors as e
 from psycopg.pq import Format
-from psycopg.copy import Copy, LibpqWriter, QueuedLibpqWriter, FileWriter
+from psycopg.copy import Copy, LibpqWriter, QueuedLibpqWriter
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 from psycopg.types.hstore import register_hstore
 from psycopg.types.numeric import Int4
 
 from .utils import eur, gc_collect, gc_count
+from ._test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
+from ._test_copy import sample_values, sample_records, sample_tabledef
+from ._test_copy import ensure_table, py_to_raw, special_chars, FileWriter
 
 pytestmark = pytest.mark.crdb_skip("copy")
 
-sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
-sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
-sample_tabledef = "col1 serial primary key, col2 int, data text"
-
-sample_text = b"""\
-40010\t40020\thello
-40040\t\\N\tworld
-"""
-
-sample_binary_str = """
-5047 434f 5059 0aff 0d0a 00
-00 0000 0000 0000 00
-00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
-
-0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
-
-ff ff
-"""
-
-sample_binary_rows = [
-    bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
-]
-sample_binary = b"".join(sample_binary_rows)
-
-special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
-
 
 @pytest.mark.parametrize("format", Format)
 def test_copy_out_read(conn, format):
@@ -320,9 +296,9 @@ def test_subclass_adapter(conn, format):
     if format == Format.TEXT:
         from psycopg.types.string import StrDumper as BaseDumper
     else:
-        from psycopg.types.string import (  # type: ignore[assignment]
-            StrBinaryDumper as BaseDumper,
-        )
+        from psycopg.types.string import StrBinaryDumper
+
+        BaseDumper = StrBinaryDumper  # type: ignore
 
     class MyStrDumper(BaseDumper):
         def dump(self, obj):
@@ -413,9 +389,8 @@ def test_copy_in_records(conn, format):
     with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         for row in sample_records:
             if format == Format.BINARY:
-                row = tuple(
-                    Int4(i) if isinstance(i, int) else i for i in row
-                )  # type: ignore[assignment]
+                row2 = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+                row = row2  # type: ignore[assignment]
             copy.write_row(row)
 
     data = cur.execute("select * from copy_in order by 1").fetchall()
@@ -815,25 +790,6 @@ def test_copy_table_across(conn_cls, dsn, faker, mode):
             faker.assert_record(got, want)
 
 
-def py_to_raw(item, fmt):
-    """Convert from Python type to the expected result from the db"""
-    if fmt == Format.TEXT:
-        if isinstance(item, int):
-            return str(item)
-    else:
-        if isinstance(item, int):
-            # Assume int4
-            return struct.pack("!i", item)
-        elif isinstance(item, str):
-            return item.encode()
-    return item
-
-
-def ensure_table(cur, tabledef, name="copy_in"):
-    cur.execute(f"drop table if exists {name}")
-    cur.execute(f"create table {name} ({tabledef})")
-
-
 class DataGenerator:
     def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
         self.conn = conn
index dd11d4bd278aa6429b428968ea21889fde897e55..5589038cc69c56c7a15d33807a06908f85d6d2c2 100644 (file)
@@ -11,21 +11,18 @@ from psycopg import pq
 from psycopg import sql
 from psycopg import errors as e
 from psycopg.pq import Format
-from psycopg.copy import AsyncCopy
-from psycopg.copy import AsyncWriter, AsyncLibpqWriter, AsyncQueuedLibpqWriter
-from psycopg.types import TypeInfo
+from psycopg.copy import AsyncCopy, AsyncLibpqWriter, AsyncQueuedLibpqWriter
 from psycopg.adapt import PyFormat
+from psycopg.types import TypeInfo
 from psycopg.types.hstore import register_hstore
 from psycopg.types.numeric import Int4
 
 from .utils import alist, eur, gc_collect, gc_count
-from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
-from .test_copy import sample_values, sample_records, sample_tabledef
-from .test_copy import py_to_raw, special_chars
+from ._test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
+from ._test_copy import sample_values, sample_records, sample_tabledef
+from ._test_copy import ensure_table_async, py_to_raw, special_chars, AsyncFileWriter
 
-pytestmark = [
-    pytest.mark.crdb_skip("copy"),
-]
+pytestmark = pytest.mark.crdb_skip("copy")
 
 
 @pytest.mark.parametrize("format", Format)
@@ -114,7 +111,7 @@ async def test_rows(aconn, format):
     async with cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
-        copy.set_types("int4 int4 text".split())
+        copy.set_types(["int4", "int4", "text"])
         rows = await alist(copy.rows())
 
     assert rows == sample_records
@@ -204,7 +201,7 @@ async def test_copy_out_badntypes(aconn, format, err):
 )
 async def test_copy_in_buffers(aconn, format, buffer):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         await copy.write(globals()[buffer])
 
@@ -215,7 +212,7 @@ async def test_copy_in_buffers(aconn, format, buffer):
 
 async def test_copy_in_buffers_pg_error(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(e.UniqueViolation):
         async with cur.copy("copy copy_in from stdin (format text)") as copy:
             await copy.write(sample_text)
@@ -251,7 +248,7 @@ async def test_copy_bad_result(aconn):
 
 async def test_copy_in_str(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     async with cur.copy("copy copy_in from stdin (format text)") as copy:
         await copy.write(sample_text.decode())
 
@@ -262,7 +259,7 @@ async def test_copy_in_str(aconn):
 
 async def test_copy_in_error(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(TypeError):
         async with cur.copy("copy copy_in from stdin (format binary)") as copy:
             await copy.write(sample_text.decode())
@@ -273,7 +270,7 @@ async def test_copy_in_error(aconn):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_empty(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
         pass
 
@@ -284,7 +281,7 @@ async def test_copy_in_empty(aconn, format):
 @pytest.mark.slow
 async def test_copy_big_size_record(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
     async with cur.copy("copy copy_in (data) from stdin") as copy:
         await copy.write_row([data])
@@ -297,7 +294,7 @@ async def test_copy_big_size_record(aconn):
 @pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
 async def test_copy_big_size_block(aconn, pytype):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
     copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
     async with cur.copy("copy copy_in (data) from stdin") as copy:
@@ -312,9 +309,9 @@ async def test_subclass_adapter(aconn, format):
     if format == Format.TEXT:
         from psycopg.types.string import StrDumper as BaseDumper
     else:
-        from psycopg.types.string import (  # type: ignore[assignment]
-            StrBinaryDumper as BaseDumper,
-        )
+        from psycopg.types.string import StrBinaryDumper
+
+        BaseDumper = StrBinaryDumper  # type: ignore
 
     class MyStrDumper(BaseDumper):
         def dump(self, obj):
@@ -323,7 +320,7 @@ async def test_subclass_adapter(aconn, format):
     aconn.adapters.register_dumper(str, MyStrDumper)
 
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
 
     async with cur.copy(
         f"copy copy_in (data) from stdin (format {format.name})"
@@ -338,7 +335,7 @@ async def test_subclass_adapter(aconn, format):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_error_empty(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(ZeroDivisionError, match="mannaggiamiseria"):
         async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
             raise ZeroDivisionError("mannaggiamiseria")
@@ -348,7 +345,7 @@ async def test_copy_in_error_empty(aconn, format):
 
 async def test_copy_in_buffers_with_pg_error(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(e.UniqueViolation):
         async with cur.copy("copy copy_in from stdin (format text)") as copy:
             await copy.write(sample_text)
@@ -359,7 +356,7 @@ async def test_copy_in_buffers_with_pg_error(aconn):
 
 async def test_copy_in_buffers_with_py_error(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     with pytest.raises(ZeroDivisionError, match="nuttengoggenio"):
         async with cur.copy("copy copy_in from stdin (format text)") as copy:
             await copy.write(sample_text)
@@ -405,14 +402,13 @@ async def test_copy_out_server_error(aconn):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_records(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
 
     async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         for row in sample_records:
             if format == Format.BINARY:
-                row = tuple(
-                    Int4(i) if isinstance(i, int) else i for i in row
-                )  # type: ignore[assignment]
+                row2 = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+                row = row2  # type: ignore[assignment]
             await copy.write_row(row)
 
     await cur.execute("select * from copy_in order by 1")
@@ -423,7 +419,7 @@ async def test_copy_in_records(aconn, format):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_records_set_types(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
 
     async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         copy.set_types(["int4", "int4", "text"])
@@ -438,7 +434,7 @@ async def test_copy_in_records_set_types(aconn, format):
 @pytest.mark.parametrize("format", Format)
 async def test_copy_in_records_binary(aconn, format):
     cur = aconn.cursor()
-    await ensure_table(cur, "col1 serial primary key, col2 int, data text")
+    await ensure_table_async(cur, "col1 serial primary key, col2 int, data text")
 
     async with cur.copy(
         f"copy copy_in (col2, data) from stdin (format {format.name})"
@@ -453,7 +449,7 @@ async def test_copy_in_records_binary(aconn, format):
 
 async def test_copy_in_allchars(aconn):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
 
     await aconn.execute("set client_encoding to utf8")
     async with cur.copy("copy copy_in from stdin (format text)") as copy:
@@ -642,7 +638,7 @@ async def test_description(aconn):
 )
 async def test_worker_life(aconn, format, buffer):
     cur = aconn.cursor()
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     async with cur.copy(
         f"copy copy_in from stdin (format {format.name})",
         writer=AsyncQueuedLibpqWriter(cur),
@@ -679,7 +675,7 @@ async def test_connection_writer(aconn, format, buffer):
     cur = aconn.cursor()
     writer = AsyncLibpqWriter(cur)
 
-    await ensure_table(cur, sample_tabledef)
+    await ensure_table_async(cur, sample_tabledef)
     async with cur.copy(
         f"copy copy_in from stdin (format {format.name})", writer=writer
     ) as copy:
@@ -823,11 +819,6 @@ async def test_copy_table_across(aconn_cls, dsn, faker, mode):
             faker.assert_record(got, want)
 
 
-async def ensure_table(cur, tabledef, name="copy_in"):
-    await cur.execute(f"drop table if exists {name}")
-    await cur.execute(f"create table {name} ({tabledef})")
-
-
 class DataGenerator:
     def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
         self.conn = conn
@@ -838,7 +829,7 @@ class DataGenerator:
 
     async def ensure_table(self):
         cur = self.conn.cursor()
-        await ensure_table(cur, "id integer primary key, data text")
+        await ensure_table_async(cur, "id integer primary key, data text")
 
     def records(self):
         for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
@@ -879,11 +870,3 @@ class DataGenerator:
                 block = block.encode()
             m.update(block)
         return m.hexdigest()
-
-
-class AsyncFileWriter(AsyncWriter):
-    def __init__(self, file):
-        self.file = file
-
-    async def write(self, data):
-        self.file.write(data)
index 706dc66b88cf6955a759572340e60c9f89e5246c..fce12d5c82fa8837559185a2601be6292e2637b6 100755 (executable)
@@ -106,7 +106,11 @@ class RenameAsyncToSync(ast.NodeTransformer):
     names_map = {
         "AsyncClientCursor": "ClientCursor",
         "AsyncConnection": "Connection",
+        "AsyncCopy": "Copy",
         "AsyncCursor": "Cursor",
+        "AsyncFileWriter": "FileWriter",
+        "AsyncLibpqWriter": "LibpqWriter",
+        "AsyncQueuedLibpqWriter": "QueuedLibpqWriter",
         "AsyncRawCursor": "RawCursor",
         "AsyncServerCursor": "ServerCursor",
         "aclose": "close",