From: Daniele Varrazzo Date: Sun, 21 Jun 2020 10:15:15 +0000 (+1200) Subject: Implemented rows and fields splitting on text copy X-Git-Tag: 3.0.dev0~483 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=561be60770c62faa5a63dd8afaed91fff856eb5b;p=thirdparty%2Fpsycopg.git Implemented rows and fields splitting on text copy --- diff --git a/psycopg3/copy.py b/psycopg3/copy.py index 39b010ffc..2dccc8262 100644 --- a/psycopg3/copy.py +++ b/psycopg3/copy.py @@ -4,9 +4,12 @@ psycopg3 copy support # Copyright (C) 2020 The Psycopg Team -from typing import Optional +import re +from typing import Any, Deque, Dict, List, Match, Optional, Tuple +from collections import deque from .proto import AdaptContext +from . import errors as e from . import pq @@ -14,14 +17,22 @@ class BaseCopy: def __init__( self, context: AdaptContext, - result: pq.proto.PGresult, + result: Optional[pq.proto.PGresult], format: pq.Format = pq.Format.TEXT, ): - from .transform import Transformer + from .adapt import Transformer self._transformer = Transformer(context) - self.format = format # TODO: maybe not needed + self.format = format self.pgresult = result + self._finished = False + + self._partial: Deque[bytes] = deque() + self._header_seen = False + + @property + def finished(self) -> bool: + return self._finished @property def pgresult(self) -> Optional[pq.proto.PGresult]: @@ -32,6 +43,72 @@ class BaseCopy: self._pgresult = result self._transformer.pgresult = result + def load(self, buffer: bytes) -> List[Tuple[Any, ...]]: + if self._finished: + raise e.ProgrammingError("copy already finished") + + if self.format == pq.Format.TEXT: + return self._load_text(buffer) + else: + return self._load_binary(buffer) + + def _load_text(self, buffer: bytes) -> List[Tuple[Any, ...]]: + rows = buffer.split(b"\n") + last_row = rows.pop(-1) + + if self._partial and rows: + self._partial.append(rows[0]) + rows[0] = b"".join(self._partial) + self._partial.clear() + + if last_row: + self._partial.append(last_row) + + # If there is no result then the transformer has no info about types + load_sequence = ( + self._transformer.load_sequence + if self.pgresult is not None + else None + ) + + rv = [] + for row in rows: + if row == b"\\.": + self._finished = True + break + + values = row.split(b"\t") + prow = tuple( + _bsrepl_re.sub(_bsrepl_sub, v) if v != b"\\N" else None + for v in values + ) + rv.append( + load_sequence(prow) if load_sequence is not None else prow + ) + + return rv + + def _load_binary(self, buffer: bytes) -> List[Tuple[Any, ...]]: + raise NotImplementedError + + +def _bsrepl_sub( + m: Match[bytes], + __map: Dict[bytes, bytes] = { + b"b": b"\b", + b"t": b"\t", + b"n": b"\n", + b"v": b"\v", + b"f": b"\f", + b"r": b"\r", + }, +) -> bytes: + g = m.group(0) + return __map.get(g, g) + + +_bsrepl_re = re.compile(rb"\\(.)") + class Copy(BaseCopy): pass diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index d85f23966..8ee944658 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -363,7 +363,7 @@ class Cursor(BaseCursor): with self.connection.lock: self._start_query() self.connection._start_query() - # Make sure to avoid PQexec to avoid sending a mix of COPY and + # Make sure to avoid PQexec to avoid receiving a mix of COPY and # other operations. self._execute_send(statement, vars, no_pqexec=True) gen = execute(self.connection.pgconn) @@ -371,7 +371,9 @@ class Cursor(BaseCursor): tx = self._transformer self._check_copy_results(results) - return Copy(context=tx, result=results[0], format=self.format) + return Copy( + context=tx, result=results[0], format=results[0].binary_tuples + ) class AsyncCursor(BaseCursor): @@ -475,7 +477,7 @@ class AsyncCursor(BaseCursor): async with self.connection.lock: self._start_query() await self.connection._start_query() - # Make sure to avoid PQexec to avoid sending a mix of COPY and + # Make sure to avoid PQexec to avoid receiving a mix of COPY and # other operations. self._execute_send(statement, vars, no_pqexec=True) gen = execute(self.connection.pgconn) @@ -483,7 +485,9 @@ class AsyncCursor(BaseCursor): tx = self._transformer self._check_copy_results(results) - return AsyncCopy(context=tx, result=results[0], format=self.format) + return AsyncCopy( + context=tx, result=results[0], format=results[0].binary_tuples + ) class NamedCursorMixin: diff --git a/tests/test_copy.py b/tests/test_copy.py index 69b431efa..1e68b61d0 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,12 +1,17 @@ import pytest +from psycopg3 import pq +from psycopg3.adapt import Format +from psycopg3.types import builtins + + sample_records = [(10, 20, "hello"), (40, None, "world")] sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')" -sample_tabledef = "col1 int primary key, col2 int, date text" +sample_tabledef = "col1 int primary key, col2 int, data text" -sample_text = b""" +sample_text = b"""\ 10\t20\thello 40\t\\N\tworld """ @@ -20,93 +25,126 @@ sample_binary = """ """ +def set_sample_attributes(res, format): + attrs = [ + pq.PGresAttDesc(b"col1", 0, 0, format, builtins["int4"].oid, 0, 0), + pq.PGresAttDesc(b"col2", 0, 0, format, builtins["int4"].oid, 0, 0), + pq.PGresAttDesc(b"data", 0, 0, format, builtins["text"].oid, 0, 0), + ] + res.set_attributes(attrs) + + @pytest.mark.parametrize( - "format, block", [("text", sample_text), ("binary", sample_binary)] + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) -def test_load(format, block): +def test_load_noinfo(conn, format, buffer): from psycopg3.copy import Copy - copy = Copy(format=format) - records = copy.load(block) + copy = Copy(context=None, result=None, format=format) + records = copy.load(globals()[buffer]) + assert records == as_bytes(sample_records) + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +def test_load(conn, format, buffer): + from psycopg3.copy import Copy + + res = conn.pgconn.make_empty_result(pq.ExecStatus.COPY_OUT) + set_sample_attributes(res, format) + + copy = Copy(context=None, result=res, format=format) + records = copy.load(globals()[buffer]) assert records == sample_records @pytest.mark.parametrize( - "format, block", [("text", sample_text), ("binary", sample_binary)] + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) -def test_dump(format, block): +def test_dump(conn, format, buffer): from psycopg3.copy import Copy - copy = Copy(format=format) + res = conn.pgconn.make_empty_result(pq.ExecStatus.COPY_OUT) + set_sample_attributes(res, format) + + copy = Copy(context=None, result=res, format=format) assert copy.get_buffer() is None for row in sample_records: copy.dump(row) - assert copy.get_buffer() == block + assert copy.get_buffer() == globals()[buffer] assert copy.get_buffer() is None @pytest.mark.parametrize( - "format, block", [("text", sample_text), ("binary", sample_binary)] + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) -def test_buffers(format, block): +def test_buffers(format, buffer): from psycopg3.copy import Copy copy = Copy(format=format) - assert list(copy.buffers(sample_records)) == [block] + assert list(copy.buffers(sample_records)) == [globals()[buffer]] @pytest.mark.parametrize( - "format, want", [("text", sample_text), ("binary", sample_binary)] + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) -def test_copy_out_read(conn, format, want): +def test_copy_out_read(conn, format, buffer): cur = conn.cursor() - copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})") - assert copy.read() == want + copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") + assert copy.read() == globals()[buffer] assert copy.read() is None assert copy.read() is None -@pytest.mark.parametrize("format", ["text", "binary"]) +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) def test_iter(conn, format): cur = conn.cursor() - copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})") + copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") assert list(copy) == sample_records @pytest.mark.parametrize( - "format, buffer", [("text", sample_text), ("binary", sample_binary)] + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) def test_copy_in_buffers(conn, format, buffer): cur = conn.cursor() ensure_table(cur, sample_tabledef) - copy = cur.copy(f"copy copy_in from stdin (format {format})") - copy.write(buffer) + copy = cur.copy(f"copy copy_in from stdin (format {format.name})") + copy.write(globals()[buffer]) copy.end() data = cur.execute("select * from copy_in order by 1").fetchall() assert data == sample_records @pytest.mark.parametrize( - "format, buffer", [("text", sample_text), ("binary", sample_binary)] + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) def test_copy_in_buffers_with(conn, format, buffer): cur = conn.cursor() ensure_table(cur, sample_tabledef) - with cur.copy(f"copy copy_in from stdin (format {format})") as copy: - copy.write(buffer) + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.write(globals()[buffer]) data = cur.execute("select * from copy_in order by 1").fetchall() assert data == sample_records @pytest.mark.parametrize( - "format, buffer", [("text", sample_text), ("binary", sample_binary)] + "format", [(Format.TEXT,), (Format.BINARY,)], ) -def test_copy_in_records(conn, format, buffer): +def test_copy_in_records(conn, format): cur = conn.cursor() ensure_table(cur, sample_tabledef) - with cur.copy(f"copy copy_in from stdin (format {format})") as copy: + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: for row in sample_records: copy.write(row) @@ -117,3 +155,20 @@ def test_copy_in_records(conn, format, buffer): def ensure_table(cur, tabledef, name="copy_in"): cur.execute(f"drop table if exists {name}") cur.execute(f"create table {name} ({tabledef})") + + +def as_bytes(records): + out = [] + for rin in records: + rout = [] + for v in rin: + if v is None or isinstance(v, bytes): + rout.append(v) + continue + if not isinstance(v, str): + v = str(v) + if isinstance(v, str): + v = v.encode("utf8") + rout.append(v) + out.append(tuple(rout)) + return out