]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(copy): add 'binary' param to copy object
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 24 Jul 2022 02:22:32 +0000 (03:22 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 26 Jul 2022 12:23:46 +0000 (13:23 +0100)
This allows to write binary format even when the query has executed no
COPY operation.

Notice that now we have introduced tests that convert sample_records
into sample_binary: because small numbers would be dumped to int2, use
values large enough to require int4.

psycopg/psycopg/copy.py
tests/test_copy.py
tests/test_copy_async.py

index ee9973de687bc46587039ae75cadb38929c2eddb..08f1e2b213e2633cadeef237850867811d4ae369 100644 (file)
@@ -78,7 +78,12 @@ class BaseCopy(Generic[ConnectionType]):
 
     formatter: "Formatter"
 
-    def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
+    def __init__(
+        self,
+        cursor: "BaseCursor[ConnectionType, Any]",
+        *,
+        binary: Optional[bool] = None,
+    ):
         self.cursor = cursor
         self.connection = cursor.connection
         self._pgconn = self.connection.pgconn
@@ -94,8 +99,11 @@ class BaseCopy(Generic[ConnectionType]):
         else:
             self._direction = COPY_IN
 
+        if binary is None:
+            binary = bool(result and result.binary_tuples)
+
         tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
-        if result and result.binary_tuples:
+        if binary:
             self.formatter = BinaryFormatter(tx)
         else:
             self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
@@ -196,8 +204,14 @@ class Copy(BaseCopy["Connection[Any]"]):
 
     writer: "Writer"
 
-    def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None):
-        super().__init__(cursor)
+    def __init__(
+        self,
+        cursor: "Cursor[Any]",
+        *,
+        binary: Optional[bool] = None,
+        writer: Optional["Writer"] = None,
+    ):
+        super().__init__(cursor, binary=binary)
         if not writer:
             writer = LibpqWriter(cursor)
 
@@ -438,9 +452,13 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     writer: "AsyncWriter"
 
     def __init__(
-        self, cursor: "AsyncCursor[Any]", *, writer: Optional["AsyncWriter"] = None
+        self,
+        cursor: "AsyncCursor[Any]",
+        *,
+        binary: Optional[bool] = None,
+        writer: Optional["AsyncWriter"] = None,
     ):
-        super().__init__(cursor)
+        super().__init__(cursor, binary=binary)
 
         if not writer:
             writer = AsyncLibpqWriter(cursor)
index 540b903dbaf414c5fec0cfbe34625b30ff87555c..7844c3c4ca608bb4f8b09e26819c1189f6c31485 100644 (file)
@@ -1,5 +1,6 @@
 import gc
 import string
+import struct
 import hashlib
 from io import BytesIO, StringIO
 from random import choice, randrange
@@ -22,21 +23,21 @@ from .utils import eur, gc_collect
 
 pytestmark = pytest.mark.crdb_skip("copy")
 
-sample_records = [(10, 20, "hello"), (40, None, "world")]
-sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
+sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
+sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
 sample_tabledef = "col1 serial primary key, col2 int, data text"
 
 sample_text = b"""\
-10\t20\thello
-40\t\\N\tworld
+40010\t40020\thello
+40040\t\\N\tworld
 """
 
 sample_binary_str = """
 5047 434f 5059 0aff 0d0a 00
 00 0000 0000 0000 00
-00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f
+00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
 
-0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64
+0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
 
 ff ff
 """
@@ -44,7 +45,6 @@ ff ff
 sample_binary_rows = [
     bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
 ]
-
 sample_binary = b"".join(sample_binary_rows)
 
 special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
@@ -484,6 +484,23 @@ def test_copy_in_format(conn):
             assert fields[1].decode() == chr(i)
 
 
+@pytest.mark.parametrize(
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_file_writer(conn, format, buffer):
+    file = BytesIO()
+    conn.execute("set client_encoding to utf8")
+    cur = conn.cursor()
+    with Copy(cur, binary=format, writer=FileWriter(file)) as copy:
+        for record in sample_records:
+            copy.write_row(record)
+
+    file.seek(0)
+    want = globals()[buffer]
+    got = file.read()
+    assert got == want
+
+
 @pytest.mark.slow
 def test_copy_from_to(conn):
     # Roundtrip from file to database to file blockwise
@@ -611,8 +628,7 @@ def test_description(conn):
 
 
 @pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
 )
 def test_worker_life(conn, format, buffer):
     cur = conn.cursor()
@@ -643,8 +659,7 @@ def test_worker_error_propagated(conn, monkeypatch):
 
 
 @pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
 )
 def test_connection_writer(conn, format, buffer):
     cur = conn.cursor()
@@ -798,7 +813,8 @@ def py_to_raw(item, fmt):
             return str(item)
     else:
         if isinstance(item, int):
-            return bytes([0, 0, 0, item])
+            # Assume int4
+            return struct.pack("!i", item)
         elif isinstance(item, str):
             return item.encode()
     return item
index e841d21c3aedaff67e3d34b042c41313301262f7..24aa9cabfa328ef9eeafc46fe9fa061600755b4b 100644 (file)
@@ -190,8 +190,7 @@ async def test_copy_out_badntypes(aconn, format, err):
 
 
 @pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
 )
 async def test_copy_in_buffers(aconn, format, buffer):
     cur = aconn.cursor()
@@ -465,15 +464,15 @@ from copy_in group by 1, 2, 3
 
 
 async def test_copy_in_format(aconn):
-    writer = AsyncBytesWriter()
+    file = BytesIO()
     await aconn.execute("set client_encoding to utf8")
     cur = aconn.cursor()
-    async with AsyncCopy(cur, writer=writer) as copy:
+    async with AsyncCopy(cur, writer=AsyncFileWriter(file)) as copy:
         for i in range(1, 256):
             await copy.write_row((i, chr(i)))
 
-    writer.file.seek(0)
-    rows = writer.file.read().split(b"\n")
+    file.seek(0)
+    rows = file.read().split(b"\n")
     assert not rows[-1]
     del rows[-1]
 
@@ -487,6 +486,23 @@ async def test_copy_in_format(aconn):
             assert fields[1].decode() == chr(i)
 
 
+@pytest.mark.parametrize(
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_file_writer(aconn, format, buffer):
+    file = BytesIO()
+    await aconn.execute("set client_encoding to utf8")
+    cur = aconn.cursor()
+    async with AsyncCopy(cur, binary=format, writer=AsyncFileWriter(file)) as copy:
+        for record in sample_records:
+            await copy.write_row(record)
+
+    file.seek(0)
+    want = globals()[buffer]
+    got = file.read()
+    assert got == want
+
+
 @pytest.mark.slow
 async def test_copy_from_to(aconn):
     # Roundtrip from file to database to file blockwise
@@ -614,8 +630,7 @@ async def test_description(aconn):
 
 
 @pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
 )
 async def test_worker_life(aconn, format, buffer):
     cur = aconn.cursor()
@@ -650,8 +665,7 @@ async def test_worker_error_propagated(aconn, monkeypatch):
 
 
 @pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
 )
 async def test_connection_writer(aconn, format, buffer):
     cur = aconn.cursor()
@@ -859,9 +873,9 @@ class DataGenerator:
         return m.hexdigest()
 
 
-class AsyncBytesWriter(AsyncWriter):
-    def __init__(self):
-        self.file = BytesIO()
+class AsyncFileWriter(AsyncWriter):
+    def __init__(self, file):
+        self.file = file
 
     async def write(self, data):
         self.file.write(data)