]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added row-by-row copy in
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 30 Jun 2020 18:41:50 +0000 (06:41 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 1 Jul 2020 13:51:07 +0000 (01:51 +1200)
psycopg3/copy.py
tests/test_copy.py
tests/test_copy_async.py

index 38c9e88f23720d16974ca8ad8626f95aa290ef07..f8931decab375fa249bae7b1a06712caddaf4e50 100644 (file)
@@ -6,8 +6,9 @@ psycopg3 copy support
 
 import re
 import codecs
+import struct
 from typing import TYPE_CHECKING, AsyncGenerator, Generator
-from typing import Dict, Match, Optional, Type, Union
+from typing import Any, Dict, Match, Optional, Sequence, Type, Union
 from types import TracebackType
 
 from . import pq
@@ -31,9 +32,15 @@ class BaseCopy:
         self._transformer = Transformer(context)
         self.format = format
         self.pgresult = result
+        self._first_row = True
         self._finished = False
         self._codec: Optional[codecs.CodecInfo] = None
 
+        if format == pq.Format.TEXT:
+            self._format_row = self._format_row_text
+        else:
+            self._format_row = self._format_row_binary
+
     @property
     def finished(self) -> bool:
         return self._finished
@@ -76,23 +83,69 @@ class BaseCopy:
             self._codec = self.connection.codec
             return self._codec.encode(data)[0]
 
+        else:
+            raise TypeError(f"can't write {type(data).__name__}")
+
+    def format_row(self, row: Sequence[Any]) -> bytes:
+        # TODO: cache this, or pass just a single format
+        formats = [self.format] * len(row)
+        out, _ = self._transformer.dump_sequence(row, formats)
+        return self._format_row(out)
+
+    def _format_row_text(self, row: Sequence[Optional[bytes]],) -> bytes:
+        return (
+            b"\t".join(
+                _bsrepl_re.sub(_bsrepl_sub, item)
+                if item is not None
+                else br"\N"
+                for item in row
+            )
+            + b"\n"
+        )
+
+    def _format_row_binary(
+        self,
+        row: Sequence[Optional[bytes]],
+        __int2_struct: struct.Struct = struct.Struct("!h"),
+        __int4_struct: struct.Struct = struct.Struct("!i"),
+    ) -> bytes:
+        out = []
+        if self._first_row:
+            out.append(
+                # Signature, flags, extra length
+                b"PGCOPY\n\xff\r\n\0"
+                b"\x00\x00\x00\x00"
+                b"\x00\x00\x00\x00"
+            )
+            self._first_row = False
+
+        out.append(__int2_struct.pack(len(row)))
+        for item in row:
+            if item is not None:
+                out.append(__int4_struct.pack(len(item)))
+                out.append(item)
+            else:
+                out.append(b"\xff\xff\xff\xff")
+
+        return b"".join(out)
+
 
 def _bsrepl_sub(
     m: Match[bytes],
     __map: Dict[bytes, bytes] = {
-        b"b": b"\b",
-        b"t": b"\t",
-        b"n": b"\n",
-        b"v": b"\v",
-        b"f": b"\f",
-        b"r": b"\r",
+        b"\b": b"\\b",
+        b"\t": b"\\t",
+        b"\n": b"\\n",
+        b"\v": b"\\v",
+        b"\f": b"\\f",
+        b"\r": b"\\r",
+        b"\\": b"\\\\",
     },
 ) -> bytes:
-    g = m.group(0)
-    return __map.get(g, g)
+    return __map[m.group(0)]
 
 
-_bsrepl_re = re.compile(rb"\\(.)")
+_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
 
 
 class Copy(BaseCopy):
@@ -119,6 +172,10 @@ class Copy(BaseCopy):
         conn = self.connection
         conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
 
+    def write_row(self, row: Sequence[Any]) -> None:
+        data = self.format_row(row)
+        self.write(data)
+
     def finish(self, error: Optional[str] = None) -> None:
         conn = self.connection
         berr = (
@@ -139,6 +196,9 @@ class Copy(BaseCopy):
         exc_tb: Optional[TracebackType],
     ) -> None:
         if exc_val is None:
+            if self.format == pq.Format.BINARY and not self._first_row:
+                # send EOF only if we copied binary rows (_first_row is False)
+                self.write(b"\xff\xff")
             self.finish()
         else:
             self.finish(str(exc_val))
@@ -173,6 +233,10 @@ class AsyncCopy(BaseCopy):
         conn = self.connection
         await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
 
+    async def write_row(self, row: Sequence[Any]) -> None:
+        data = self.format_row(row)
+        await self.write(data)
+
     async def finish(self, error: Optional[str] = None) -> None:
         conn = self.connection
         berr = (
@@ -193,6 +257,9 @@ class AsyncCopy(BaseCopy):
         exc_tb: Optional[TracebackType],
     ) -> None:
         if exc_val is None:
+            if self.format == pq.Format.BINARY and not self._first_row:
+                # send EOF only if we copied binary rows (_first_row is False)
+                await self.write(b"\xff\xff")
             await self.finish()
         else:
             await self.finish(str(exc_val))
index e739f1985b60be872cbfc2df52a8f565877c4a6a..993a1bb694af7fa5b09deee3b38d78a5cbf77f72 100644 (file)
@@ -4,6 +4,7 @@ from psycopg3 import pq
 from psycopg3 import errors as e
 from psycopg3.adapt import Format
 
+eur = "\u20ac"
 
 sample_records = [(10, 20, "hello"), (40, None, "world")]
 
@@ -141,22 +142,59 @@ def test_copy_in_buffers_with_py_error(conn):
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
 
 
-@pytest.mark.xfail
-@pytest.mark.parametrize(
-    "format", [(Format.TEXT,), (Format.BINARY,)],
-)
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 def test_copy_in_records(conn, format):
+    if format == Format.BINARY:
+        pytest.skip("TODO: implement int binary adapter")
+
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
 
     with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         for row in sample_records:
-            copy.write(row)
+            copy.write_row(row)
 
     data = cur.execute("select * from copy_in order by 1").fetchall()
     assert data == sample_records
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_records_binary(conn, format):
+    if format == Format.TEXT:
+        pytest.skip("TODO: remove after implementing int binary adapter")
+
+    cur = conn.cursor()
+    ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+    with cur.copy(
+        f"copy copy_in (col2, data) from stdin (format {format.name})"
+    ) as copy:
+        for row in sample_records:
+            copy.write_row((None, row[2]))
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+def test_copy_in_allchars(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+
+    conn.set_client_encoding("utf8")
+    with cur.copy("copy copy_in from stdin (format text)") as copy:
+        for i in range(1, 256):
+            copy.write_row((i, None, chr(i)))
+        copy.write_row((ord(eur), None, eur))
+
+    data = cur.execute(
+        """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+    ).fetchall()
+    assert data == [(True, True, 1, 256)]
+
+
 def ensure_table(cur, tabledef, name="copy_in"):
     cur.execute(f"drop table if exists {name}")
     cur.execute(f"create table {name} ({tabledef})")
index 289a6c428156dbbd85e2cb4f8669fde2e915245c..a48d82bb5427d8d07f4ac759e422de83a8762c78 100644 (file)
@@ -5,7 +5,7 @@ from psycopg3 import errors as e
 from psycopg3.adapt import Format
 
 from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
-from .test_copy import sample_values, sample_records, sample_tabledef
+from .test_copy import eur, sample_values, sample_records, sample_tabledef
 
 pytestmark = pytest.mark.asyncio
 
@@ -156,6 +156,68 @@ async def test_copy_in_buffers_with_py_error(aconn):
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_records(aconn, format):
+    if format == Format.BINARY:
+        pytest.skip("TODO: implement int binary adapter")
+
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+
+    async with (
+        await cur.copy(f"copy copy_in from stdin (format {format.name})")
+    ) as copy:
+        for row in sample_records:
+            await copy.write_row(row)
+
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_records_binary(aconn, format):
+    if format == Format.TEXT:
+        pytest.skip("TODO: remove after implementing int binary adapter")
+
+    cur = aconn.cursor()
+    await ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+    async with (
+        await cur.copy(
+            f"copy copy_in (col2, data) from stdin (format {format.name})"
+        )
+    ) as copy:
+        for row in sample_records:
+            await copy.write_row((None, row[2]))
+
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+async def test_copy_in_allchars(aconn):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+
+    await aconn.set_client_encoding("utf8")
+    async with (
+        await cur.copy("copy copy_in from stdin (format text)")
+    ) as copy:
+        for i in range(1, 256):
+            await copy.write_row((i, None, chr(i)))
+        await copy.write_row((ord(eur), None, eur))
+
+    await cur.execute(
+        """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+    )
+    data = await cur.fetchall()
+    assert data == [(True, True, 1, 256)]
+
+
 async def ensure_table(cur, tabledef, name="copy_in"):
     await cur.execute(f"drop table if exists {name}")
     await cur.execute(f"create table {name} ({tabledef})")