From: Daniele Varrazzo
Date: Tue, 30 Jun 2020 18:41:50 +0000 (+1200)
Subject: Added row-by-row copy in
X-Git-Tag: 3.0.dev0~472
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=84a3db1e28804d483236edeb4215459b7b8c83e2;p=thirdparty%2Fpsycopg.git

Added row-by-row copy in
---
Reviewer notes (not part of the commit message; text between the "---"
separator and the first "diff --git" line is not part of the patch payload):

* psycopg3/copy.py
  * BaseCopy.__init__ gains a `_first_row` flag and binds `_format_row` to
    `_format_row_text` or `_format_row_binary` depending on `format`.
  * `_ensure_bytes` gets an `else` branch raising TypeError.
  * `format_row()` dumps a Python row through the transformer, using the
    copy format for every column, then hands the result to `_format_row`.
  * Text rows: fields joined by tab, None written as \N, newline at the
    end; each field is escaped with `_bsrepl_re` / `_bsrepl_sub`.
  * Binary rows: the first row is preceded by the 11-byte signature plus
    4 bytes of flags and 4 bytes of extension length; every row is a
    16-bit field count followed, per field, by a 32-bit length and the
    data, or ff ff ff ff for None.
  * `_bsrepl_sub` / `_bsrepl_re` are inverted: they now escape control
    characters and backslash instead of unescaping them.
  * Copy and AsyncCopy gain `write_row()`; on a clean exit in binary
    format the ff ff trailer is written, but only if at least one row was
    formatted (`_first_row` is False).
* tests/test_copy.py, tests/test_copy_async.py
  * `test_copy_in_records` is no longer xfail (binary case skipped),
    plus new `test_copy_in_records_binary` and `test_copy_in_allchars`.

diff --git a/psycopg3/copy.py b/psycopg3/copy.py
index 38c9e88f2..f8931deca 100644
--- a/psycopg3/copy.py
+++ b/psycopg3/copy.py
@@ -6,8 +6,9 @@ psycopg3 copy support
 
 import re
 import codecs
+import struct
 from typing import TYPE_CHECKING, AsyncGenerator, Generator
-from typing import Dict, Match, Optional, Type, Union
+from typing import Any, Dict, Match, Optional, Sequence, Type, Union
 from types import TracebackType
 
 from . import pq
@@ -31,9 +32,15 @@ class BaseCopy:
         self._transformer = Transformer(context)
         self.format = format
         self.pgresult = result
+        self._first_row = True
         self._finished = False
         self._codec: Optional[codecs.CodecInfo] = None
 
+        if format == pq.Format.TEXT:
+            self._format_row = self._format_row_text
+        else:
+            self._format_row = self._format_row_binary
+
     @property
     def finished(self) -> bool:
         return self._finished
@@ -76,23 +83,69 @@ class BaseCopy:
             self._codec = self.connection.codec
             return self._codec.encode(data)[0]
 
+        else:
+            raise TypeError(f"can't write {type(data).__name__}")
+
+    def format_row(self, row: Sequence[Any]) -> bytes:
+        # TODO: cache this, or pass just a single format
+        formats = [self.format] * len(row)
+        out, _ = self._transformer.dump_sequence(row, formats)
+        return self._format_row(out)
+
+    def _format_row_text(self, row: Sequence[Optional[bytes]],) -> bytes:
+        return (
+            b"\t".join(
+                _bsrepl_re.sub(_bsrepl_sub, item)
+                if item is not None
+                else br"\N"
+                for item in row
+            )
+            + b"\n"
+        )
+
+    def _format_row_binary(
+        self,
+        row: Sequence[Optional[bytes]],
+        __int2_struct: struct.Struct = struct.Struct("!h"),
+        __int4_struct: struct.Struct = struct.Struct("!i"),
+    ) -> bytes:
+        out = []
+        if self._first_row:
+            out.append(
+                # Signature, flags, extra length
+                b"PGCOPY\n\xff\r\n\0"
+                b"\x00\x00\x00\x00"
+                b"\x00\x00\x00\x00"
+            )
+            self._first_row = False
+
+        out.append(__int2_struct.pack(len(row)))
+        for item in row:
+            if item is not None:
+                out.append(__int4_struct.pack(len(item)))
+                out.append(item)
+            else:
+                out.append(b"\xff\xff\xff\xff")
+
+        return b"".join(out)
+
 
 def _bsrepl_sub(
     m: Match[bytes],
     __map: Dict[bytes, bytes] = {
-        b"b": b"\b",
-        b"t": b"\t",
-        b"n": b"\n",
-        b"v": b"\v",
-        b"f": b"\f",
-        b"r": b"\r",
+        b"\b": b"\\b",
+        b"\t": b"\\t",
+        b"\n": b"\\n",
+        b"\v": b"\\v",
+        b"\f": b"\\f",
+        b"\r": b"\\r",
+        b"\\": b"\\\\",
     },
 ) -> bytes:
-    g = m.group(0)
-    return __map.get(g, g)
+    return __map[m.group(0)]
 
 
-_bsrepl_re = re.compile(rb"\\(.)")
+_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
 
 
 class Copy(BaseCopy):
@@ -119,6 +172,10 @@ class Copy(BaseCopy):
         conn = self.connection
         conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
 
+    def write_row(self, row: Sequence[Any]) -> None:
+        data = self.format_row(row)
+        self.write(data)
+
     def finish(self, error: Optional[str] = None) -> None:
         conn = self.connection
         berr = (
@@ -139,6 +196,9 @@ class Copy(BaseCopy):
         exc_tb: Optional[TracebackType],
     ) -> None:
         if exc_val is None:
+            if self.format == pq.Format.BINARY and not self._first_row:
+                # send EOF only if we copied binary rows (_first_row is False)
+                self.write(b"\xff\xff")
             self.finish()
         else:
             self.finish(str(exc_val))
@@ -173,6 +233,10 @@ class AsyncCopy(BaseCopy):
         conn = self.connection
         await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
 
+    async def write_row(self, row: Sequence[Any]) -> None:
+        data = self.format_row(row)
+        await self.write(data)
+
     async def finish(self, error: Optional[str] = None) -> None:
         conn = self.connection
         berr = (
@@ -193,6 +257,9 @@ class AsyncCopy(BaseCopy):
         exc_tb: Optional[TracebackType],
     ) -> None:
         if exc_val is None:
+            if self.format == pq.Format.BINARY and not self._first_row:
+                # send EOF only if we copied binary rows (_first_row is False)
+                await self.write(b"\xff\xff")
             await self.finish()
         else:
             await self.finish(str(exc_val))
diff --git a/tests/test_copy.py b/tests/test_copy.py
index e739f1985..993a1bb69 100644
--- a/tests/test_copy.py
+++ b/tests/test_copy.py
@@ -4,6 +4,7 @@ from psycopg3 import pq
 from psycopg3 import errors as e
 from psycopg3.adapt import Format
 
+eur = "\u20ac"
 
 sample_records = [(10, 20, "hello"), (40, None, "world")]
 
@@ -141,22 +142,59 @@ def test_copy_in_buffers_with_py_error(conn):
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
 
 
-@pytest.mark.xfail
-@pytest.mark.parametrize(
-    "format", [(Format.TEXT,), (Format.BINARY,)],
-)
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 def test_copy_in_records(conn, format):
+    if format == Format.BINARY:
+        pytest.skip("TODO: implement int binary adapter")
+
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
 
     with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         for row in sample_records:
-            copy.write(row)
+            copy.write_row(row)
 
     data = cur.execute("select * from copy_in order by 1").fetchall()
     assert data == sample_records
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_records_binary(conn, format):
+    if format == Format.TEXT:
+        pytest.skip("TODO: remove after implementing int binary adapter")
+
+    cur = conn.cursor()
+    ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+    with cur.copy(
+        f"copy copy_in (col2, data) from stdin (format {format.name})"
+    ) as copy:
+        for row in sample_records:
+            copy.write_row((None, row[2]))
+
+    data = cur.execute("select * from copy_in order by 1").fetchall()
+    assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+def test_copy_in_allchars(conn):
+    cur = conn.cursor()
+    ensure_table(cur, sample_tabledef)
+
+    conn.set_client_encoding("utf8")
+    with cur.copy("copy copy_in from stdin (format text)") as copy:
+        for i in range(1, 256):
+            copy.write_row((i, None, chr(i)))
+        copy.write_row((ord(eur), None, eur))
+
+    data = cur.execute(
+        """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+    ).fetchall()
+    assert data == [(True, True, 1, 256)]
+
+
 def ensure_table(cur, tabledef, name="copy_in"):
     cur.execute(f"drop table if exists {name}")
     cur.execute(f"create table {name} ({tabledef})")
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py
index 289a6c428..a48d82bb5 100644
--- a/tests/test_copy_async.py
+++ b/tests/test_copy_async.py
@@ -5,7 +5,7 @@ from psycopg3 import errors as e
 from psycopg3.adapt import Format
 
 from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
-from .test_copy import sample_values, sample_records, sample_tabledef
+from .test_copy import eur, sample_values, sample_records, sample_tabledef
 
 pytestmark = pytest.mark.asyncio
 
@@ -156,6 +156,68 @@ async def test_copy_in_buffers_with_py_error(aconn):
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
 
 
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_records(aconn, format):
+    if format == Format.BINARY:
+        pytest.skip("TODO: implement int binary adapter")
+
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+
+    async with (
+        await cur.copy(f"copy copy_in from stdin (format {format.name})")
+    ) as copy:
+        for row in sample_records:
+            await copy.write_row(row)
+
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_records_binary(aconn, format):
+    if format == Format.TEXT:
+        pytest.skip("TODO: remove after implementing int binary adapter")
+
+    cur = aconn.cursor()
+    await ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+    async with (
+        await cur.copy(
+            f"copy copy_in (col2, data) from stdin (format {format.name})"
+        )
+    ) as copy:
+        for row in sample_records:
+            await copy.write_row((None, row[2]))
+
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+async def test_copy_in_allchars(aconn):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+
+    await aconn.set_client_encoding("utf8")
+    async with (
+        await cur.copy("copy copy_in from stdin (format text)")
+    ) as copy:
+        for i in range(1, 256):
+            await copy.write_row((i, None, chr(i)))
+        await copy.write_row((ord(eur), None, eur))
+
+    await cur.execute(
+        """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+    )
+    data = await cur.fetchall()
+    assert data == [(True, True, 1, 256)]
+
+
 async def ensure_table(cur, tabledef, name="copy_in"):
     await cur.execute(f"drop table if exists {name}")
     await cur.execute(f"create table {name} ({tabledef})")