want = sample_binary_rows
cur = aconn.cursor()
- got = []
async with cur.copy(
f"copy ({sample_values}) to stdout (format {format.name})"
) as copy:
- async for row in copy:
- got.append(row)
+ [row async for row in copy] == want
- assert got == want
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-async def test_read_rows(aconn, format):
+@pytest.mark.parametrize("typetype", ["names", "oids"])
+async def test_read_rows(aconn, format, typetype):
cur = aconn.cursor()
async with cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
+ f"""copy (
+ select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
+ ) to stdout (format {format.name})"""
) as copy:
- copy.set_types("int4 int4 text".split())
- rows = []
- while 1:
- row = await copy.read_row()
- if not row:
- break
- rows.append(row)
+ copy.set_types(["int4", "text", "float8[]"])
+ row = await copy.read_row()
+ assert (await copy.read_row()) is None
- assert rows == sample_records
+ assert row == (10, "hello", [0.0, 1.0])
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
f"copy ({sample_values}) to stdout (format {format.name})"
) as copy:
copy.set_types("int4 int4 text".split())
- rows = []
- async for row in copy.rows():
- rows.append(row)
+ rows = [row async for row in copy.rows()]
assert rows == sample_records
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
async with cur.copy(
f"copy ({sample_values}) to stdout (format {format.name})"
) as copy:
- rows = []
- async for row in copy.rows():
- rows.append(row)
+ rows = [row async for row in copy.rows()]
ref = [
tuple(py_to_raw(i, format) for i in record)
for record in sample_records
async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
pass
- assert cur.rowcount == 0
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_subclass_adapter(aconn, format):
+ if format == Format.TEXT:
+ from psycopg.types.string import StrDumper as BaseDumper
+ else:
+ from psycopg.types.string import StrBinaryDumper as BaseDumper
+
+ class MyStrDumper(BaseDumper):
+ def dump(self, obj):
+ return super().dump(obj) * 2
+
+ aconn.adapters.register_dumper(str, MyStrDumper)
+
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(
+ f"copy copy_in (data) from stdin (format {format.name})"
+ ) as copy:
+ await copy.write_row(("hello",))
+
+ await cur.execute("select data from copy_in")
+ rec = await cur.fetchone()
+ assert rec[0] == "hellohello"
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])