From: Daniele Varrazzo Date: Sun, 24 Jul 2022 02:22:32 +0000 (+0100) Subject: feat(copy): add 'binary' param to copy object X-Git-Tag: 3.1~44^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=99f0bb27b3f27dce133b8947fd02bb49bb1fd509;p=thirdparty%2Fpsycopg.git feat(copy): add 'binary' param to copy object This allows to write binary format even when the query has executed no COPY operation. Notice that now we have introduced tests that convert sample_records into sample_binary: because small numbers would be dumped to int2, use values large enough to require int4. --- diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index ee9973de6..08f1e2b21 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -78,7 +78,12 @@ class BaseCopy(Generic[ConnectionType]): formatter: "Formatter" - def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"): + def __init__( + self, + cursor: "BaseCursor[ConnectionType, Any]", + *, + binary: Optional[bool] = None, + ): self.cursor = cursor self.connection = cursor.connection self._pgconn = self.connection.pgconn @@ -94,8 +99,11 @@ class BaseCopy(Generic[ConnectionType]): else: self._direction = COPY_IN + if binary is None: + binary = bool(result and result.binary_tuples) + tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) - if result and result.binary_tuples: + if binary: self.formatter = BinaryFormatter(tx) else: self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) @@ -196,8 +204,14 @@ class Copy(BaseCopy["Connection[Any]"]): writer: "Writer" - def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None): - super().__init__(cursor) + def __init__( + self, + cursor: "Cursor[Any]", + *, + binary: Optional[bool] = None, + writer: Optional["Writer"] = None, + ): + super().__init__(cursor, binary=binary) if not writer: writer = LibpqWriter(cursor) @@ -438,9 +452,13 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): writer: "AsyncWriter" def __init__( - self, cursor: "AsyncCursor[Any]", *, writer: Optional["AsyncWriter"] = None + self, + cursor: "AsyncCursor[Any]", + *, + binary: Optional[bool] = None, + writer: Optional["AsyncWriter"] = None, ): - super().__init__(cursor) + super().__init__(cursor, binary=binary) if not writer: writer = AsyncLibpqWriter(cursor) diff --git a/tests/test_copy.py b/tests/test_copy.py index 540b903db..7844c3c4c 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,5 +1,6 @@ import gc import string +import struct import hashlib from io import BytesIO, StringIO from random import choice, randrange @@ -22,21 +23,21 @@ from .utils import eur, gc_collect pytestmark = pytest.mark.crdb_skip("copy") -sample_records = [(10, 20, "hello"), (40, None, "world")] -sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')" +sample_records = [(40010, 40020, "hello"), (40040, None, "world")] +sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')" sample_tabledef = "col1 serial primary key, col2 int, data text" sample_text = b"""\ -10\t20\thello -40\t\\N\tworld +40010\t40020\thello +40040\t\\N\tworld """ sample_binary_str = """ 5047 434f 5059 0aff 0d0a 00 00 0000 0000 0000 00 -00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f +00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f -0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64 +0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64 ff ff """ @@ -44,7 +45,6 @@ ff ff sample_binary_rows = [ bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n") ] - sample_binary = b"".join(sample_binary_rows) special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"} @@ -484,6 +484,23 @@ def test_copy_in_format(conn): assert fields[1].decode() == chr(i) +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +def test_file_writer(conn, format, buffer): + file = BytesIO() + conn.execute("set client_encoding to utf8") + cur = conn.cursor() + with Copy(cur, binary=format, writer=FileWriter(file)) as copy: + for record in sample_records: + copy.write_row(record) + + file.seek(0) + want = globals()[buffer] + got = file.read() + assert got == want + + @pytest.mark.slow def test_copy_from_to(conn): # Roundtrip from file to database to file blockwise @@ -611,8 +628,7 @@ def test_description(conn): @pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] ) def test_worker_life(conn, format, buffer): cur = conn.cursor() @@ -643,8 +659,7 @@ def test_worker_error_propagated(conn, monkeypatch): @pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] ) def test_connection_writer(conn, format, buffer): cur = conn.cursor() @@ -798,7 +813,8 @@ def py_to_raw(item, fmt): return str(item) else: if isinstance(item, int): - return bytes([0, 0, 0, item]) + # Assume int4 + return struct.pack("!i", item) elif isinstance(item, str): return item.encode() return item diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index e841d21c3..24aa9cabf 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -190,8 +190,7 @@ async def test_copy_out_badntypes(aconn, format, err): @pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] ) async def test_copy_in_buffers(aconn, format, buffer): cur = aconn.cursor() @@ -465,15 +464,15 @@ from copy_in group by 1, 2, 3 async def test_copy_in_format(aconn): - writer = AsyncBytesWriter() + file = BytesIO() await aconn.execute("set client_encoding to utf8") cur = aconn.cursor() - async with AsyncCopy(cur, writer=writer) as copy: + async with AsyncCopy(cur, writer=AsyncFileWriter(file)) as copy: for i in range(1, 256): await copy.write_row((i, chr(i))) - writer.file.seek(0) - rows = writer.file.read().split(b"\n") + file.seek(0) + rows = file.read().split(b"\n") assert not rows[-1] del rows[-1] @@ -487,6 +486,23 @@ async def test_copy_in_format(aconn): assert fields[1].decode() == chr(i) +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_file_writer(aconn, format, buffer): + file = BytesIO() + await aconn.execute("set client_encoding to utf8") + cur = aconn.cursor() + async with AsyncCopy(cur, binary=format, writer=AsyncFileWriter(file)) as copy: + for record in sample_records: + await copy.write_row(record) + + file.seek(0) + want = globals()[buffer] + got = file.read() + assert got == want + + @pytest.mark.slow async def test_copy_from_to(aconn): # Roundtrip from file to database to file blockwise @@ -614,8 +630,7 @@ async def test_description(aconn): @pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] ) async def test_worker_life(aconn, format, buffer): cur = aconn.cursor() @@ -650,8 +665,7 @@ async def test_worker_error_propagated(aconn, monkeypatch): @pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] ) async def test_connection_writer(aconn, format, buffer): cur = aconn.cursor() @@ -859,9 +873,9 @@ class DataGenerator: return m.hexdigest() -class AsyncBytesWriter(AsyncWriter): - def __init__(self): - self.file = BytesIO() +class AsyncFileWriter(AsyncWriter): + def __init__(self, file): + self.file = file async def write(self, data): self.file.write(data)