From: Daniele Varrazzo Date: Tue, 24 Aug 2021 15:54:47 +0000 (+0200) Subject: Realign async copy tests to the sync ones X-Git-Tag: 3.0.beta1~48 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=c4ba8255db3c41f7b6b2c4a9ea0aa1fd893a73bc;p=thirdparty%2Fpsycopg.git Realign async copy tests to the sync ones --- diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 2efb7a1a3..428258193 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -57,32 +57,28 @@ async def test_copy_out_iter(aconn, format): want = sample_binary_rows cur = aconn.cursor() - got = [] async with cur.copy( f"copy ({sample_values}) to stdout (format {format.name})" ) as copy: - async for row in copy: - got.append(row) + [row async for row in copy] == want - assert got == want assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -async def test_read_rows(aconn, format): +@pytest.mark.parametrize("typetype", ["names", "oids"]) +async def test_read_rows(aconn, format, typetype): cur = aconn.cursor() async with cur.copy( - f"copy ({sample_values}) to stdout (format {format.name})" + f"""copy ( + select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] + ) to stdout (format {format.name})""" ) as copy: - copy.set_types("int4 int4 text".split()) - rows = [] - while 1: - row = await copy.read_row() - if not row: - break - rows.append(row) + copy.set_types(["int4", "text", "float8[]"]) + row = await copy.read_row() + assert (await copy.read_row()) is None - assert rows == sample_records + assert row == (10, "hello", [0.0, 1.0]) assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS @@ -93,9 +89,7 @@ async def test_rows(aconn, format): f"copy ({sample_values}) to stdout (format {format.name})" ) as copy: copy.set_types("int4 int4 text".split()) - rows = [] - async for row in copy.rows(): - rows.append(row) + rows = [row async for row in copy.rows()] assert rows == sample_records assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS @@ -165,9 +159,7 @@ async def test_rows_notypes(aconn, format): async with cur.copy( f"copy ({sample_values}) to stdout (format {format.name})" ) as copy: - rows = [] - async for row in copy.rows(): - rows.append(row) + rows = [row async for row in copy.rows()] ref = [ tuple(py_to_raw(i, format) for i in record) for record in sample_records @@ -260,8 +252,34 @@ async def test_copy_in_empty(aconn, format): async with cur.copy(f"copy copy_in from stdin (format {format.name})"): pass - assert cur.rowcount == 0 assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +async def test_subclass_adapter(aconn, format): + if format == Format.TEXT: + from psycopg.types.string import StrDumper as BaseDumper + else: + from psycopg.types.string import StrBinaryDumper as BaseDumper + + class MyStrDumper(BaseDumper): + def dump(self, obj): + return super().dump(obj) * 2 + + aconn.adapters.register_dumper(str, MyStrDumper) + + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy( + f"copy copy_in (data) from stdin (format {format.name})" + ) as copy: + await copy.write_row(("hello",)) + + await cur.execute("select data from copy_in") + rec = await cur.fetchone() + assert rec[0] == "hellohello" @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])