]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Realign async copy tests to the sync ones
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 15:54:47 +0000 (17:54 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 15:54:47 +0000 (17:54 +0200)
tests/test_copy_async.py

index 2efb7a1a36440c343568f3d2b3f460834ff5ca40..428258193dcd90a59103609b11d20677ff3fe681 100644 (file)
@@ -57,32 +57,28 @@ async def test_copy_out_iter(aconn, format):
         want = sample_binary_rows
 
     cur = aconn.cursor()
-    got = []
     async with cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
-        async for row in copy:
-            got.append(row)
+        [row async for row in copy] == want
 
-    assert got == want
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-async def test_read_rows(aconn, format):
+@pytest.mark.parametrize("typetype", ["names", "oids"])
+async def test_read_rows(aconn, format, typetype):
     cur = aconn.cursor()
     async with cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
+        f"""copy (
+            select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
+        ) to stdout (format {format.name})"""
     ) as copy:
-        copy.set_types("int4 int4 text".split())
-        rows = []
-        while 1:
-            row = await copy.read_row()
-            if not row:
-                break
-            rows.append(row)
+        copy.set_types(["int4", "text", "float8[]"])
+        row = await copy.read_row()
+        assert (await copy.read_row()) is None
 
-    assert rows == sample_records
+    assert row == (10, "hello", [0.0, 1.0])
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
 
@@ -93,9 +89,7 @@ async def test_rows(aconn, format):
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
         copy.set_types("int4 int4 text".split())
-        rows = []
-        async for row in copy.rows():
-            rows.append(row)
+        rows = [row async for row in copy.rows()]
 
     assert rows == sample_records
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
@@ -165,9 +159,7 @@ async def test_rows_notypes(aconn, format):
     async with cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
-        rows = []
-        async for row in copy.rows():
-            rows.append(row)
+        rows = [row async for row in copy.rows()]
     ref = [
         tuple(py_to_raw(i, format) for i in record)
         for record in sample_records
@@ -260,8 +252,34 @@ async def test_copy_in_empty(aconn, format):
     async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
         pass
 
-    assert cur.rowcount == 0
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+    assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_subclass_adapter(aconn, format):
+    if format == Format.TEXT:
+        from psycopg.types.string import StrDumper as BaseDumper
+    else:
+        from psycopg.types.string import StrBinaryDumper as BaseDumper
+
+    class MyStrDumper(BaseDumper):
+        def dump(self, obj):
+            return super().dump(obj) * 2
+
+    aconn.adapters.register_dumper(str, MyStrDumper)
+
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+
+    async with cur.copy(
+        f"copy copy_in (data) from stdin (format {format.name})"
+    ) as copy:
+        await copy.write_row(("hello",))
+
+    await cur.execute("select data from copy_in")
+    rec = await cur.fetchone()
+    assert rec[0] == "hellohello"
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])