]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Implemented rows and fields splitting on text copy
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 21 Jun 2020 10:15:15 +0000 (22:15 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 22 Jun 2020 05:09:07 +0000 (17:09 +1200)
psycopg3/copy.py
psycopg3/cursor.py
tests/test_copy.py

index 39b010ffc9c16fbb926ba8947085126cc2d9d595..2dccc8262bfcc5130542bebc6e8311495bf50f95 100644 (file)
@@ -4,9 +4,12 @@ psycopg3 copy support
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Optional
+import re
+from typing import Any, Deque, Dict, List, Match, Optional, Tuple
+from collections import deque
 
 from .proto import AdaptContext
+from . import errors as e
 from . import pq
 
 
@@ -14,14 +17,22 @@ class BaseCopy:
     def __init__(
         self,
         context: AdaptContext,
-        result: pq.proto.PGresult,
+        result: Optional[pq.proto.PGresult],
         format: pq.Format = pq.Format.TEXT,
     ):
-        from .transform import Transformer
+        from .adapt import Transformer
 
         self._transformer = Transformer(context)
-        self.format = format  # TODO: maybe not needed
+        self.format = format
         self.pgresult = result
+        self._finished = False
+
+        self._partial: Deque[bytes] = deque()
+        self._header_seen = False
+
+    @property
+    def finished(self) -> bool:
+        return self._finished
 
     @property
     def pgresult(self) -> Optional[pq.proto.PGresult]:
@@ -32,6 +43,72 @@ class BaseCopy:
         self._pgresult = result
         self._transformer.pgresult = result
 
+    def load(self, buffer: bytes) -> List[Tuple[Any, ...]]:
+        if self._finished:
+            raise e.ProgrammingError("copy already finished")
+
+        if self.format == pq.Format.TEXT:
+            return self._load_text(buffer)
+        else:
+            return self._load_binary(buffer)
+
+    def _load_text(self, buffer: bytes) -> List[Tuple[Any, ...]]:
+        rows = buffer.split(b"\n")
+        last_row = rows.pop(-1)
+
+        if self._partial and rows:
+            self._partial.append(rows[0])
+            rows[0] = b"".join(self._partial)
+            self._partial.clear()
+
+        if last_row:
+            self._partial.append(last_row)
+
+        # If there is no result then the transformer has no info about types
+        load_sequence = (
+            self._transformer.load_sequence
+            if self.pgresult is not None
+            else None
+        )
+
+        rv = []
+        for row in rows:
+            if row == b"\\.":
+                self._finished = True
+                break
+
+            values = row.split(b"\t")
+            prow = tuple(
+                _bsrepl_re.sub(_bsrepl_sub, v) if v != b"\\N" else None
+                for v in values
+            )
+            rv.append(
+                load_sequence(prow) if load_sequence is not None else prow
+            )
+
+        return rv
+
+    def _load_binary(self, buffer: bytes) -> List[Tuple[Any, ...]]:
+        raise NotImplementedError
+
+
+def _bsrepl_sub(
+    m: Match[bytes],
+    __map: Dict[bytes, bytes] = {
+        b"b": b"\b",
+        b"t": b"\t",
+        b"n": b"\n",
+        b"v": b"\v",
+        b"f": b"\f",
+        b"r": b"\r",
+    },
+) -> bytes:
+    g = m.group(0)
+    return __map.get(g, g)
+
+
+_bsrepl_re = re.compile(rb"\\(.)")
+
 
 class Copy(BaseCopy):
     pass
index d85f23966de0eb4867d09e0f4d908a70318bcefc..8ee944658c157eefc50dd2c049e3233f66a5d4f7 100644 (file)
@@ -363,7 +363,7 @@ class Cursor(BaseCursor):
         with self.connection.lock:
             self._start_query()
             self.connection._start_query()
-            # Make sure to avoid PQexec to avoid sending a mix of COPY and
+            # Make sure to avoid PQexec to avoid receiving a mix of COPY and
             # other operations.
             self._execute_send(statement, vars, no_pqexec=True)
             gen = execute(self.connection.pgconn)
@@ -371,7 +371,9 @@ class Cursor(BaseCursor):
             tx = self._transformer
 
         self._check_copy_results(results)
-        return Copy(context=tx, result=results[0], format=self.format)
+        return Copy(
+            context=tx, result=results[0], format=results[0].binary_tuples
+        )
 
 
 class AsyncCursor(BaseCursor):
@@ -475,7 +477,7 @@ class AsyncCursor(BaseCursor):
         async with self.connection.lock:
             self._start_query()
             await self.connection._start_query()
-            # Make sure to avoid PQexec to avoid sending a mix of COPY and
+            # Make sure to avoid PQexec to avoid receiving a mix of COPY and
             # other operations.
             self._execute_send(statement, vars, no_pqexec=True)
             gen = execute(self.connection.pgconn)
@@ -483,7 +485,9 @@ class AsyncCursor(BaseCursor):
             tx = self._transformer
 
         self._check_copy_results(results)
-        return AsyncCopy(context=tx, result=results[0], format=self.format)
+        return AsyncCopy(
+            context=tx, result=results[0], format=results[0].binary_tuples
+        )
 
 
 class NamedCursorMixin:
index 69b431efaad627ae48a3066b84ec9473c11d1c52..1e68b61d0a60682603716d8a5b403152d37d0e83 100644 (file)
@@ -1,12 +1,17 @@
 import pytest
 
+from psycopg3 import pq
+from psycopg3.adapt import Format
+from psycopg3.types import builtins
+
+
 sample_records = [(10, 20, "hello"), (40, None, "world")]
 
 sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
 
-sample_tabledef = "col1 int primary key, col2 int, date text"
+sample_tabledef = "col1 int primary key, col2 int, data text"
 
-sample_text = b"""
+sample_text = b"""\
 10\t20\thello
 40\t\\N\tworld
 """
@@ -20,93 +25,126 @@ sample_binary = """
 """
 
 
+def set_sample_attributes(res, format):
+    attrs = [
+        pq.PGresAttDesc(b"col1", 0, 0, format, builtins["int4"].oid, 0, 0),
+        pq.PGresAttDesc(b"col2", 0, 0, format, builtins["int4"].oid, 0, 0),
+        pq.PGresAttDesc(b"data", 0, 0, format, builtins["text"].oid, 0, 0),
+    ]
+    res.set_attributes(attrs)
+
+
 @pytest.mark.parametrize(
-    "format, block", [("text", sample_text), ("binary", sample_binary)]
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
 )
-def test_load(format, block):
+def test_load_noinfo(conn, format, buffer):
     from psycopg3.copy import Copy
 
-    copy = Copy(format=format)
-    records = copy.load(block)
+    copy = Copy(context=None, result=None, format=format)
+    records = copy.load(globals()[buffer])
+    assert records == as_bytes(sample_records)
+
+
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_load(conn, format, buffer):
+    from psycopg3.copy import Copy
+
+    res = conn.pgconn.make_empty_result(pq.ExecStatus.COPY_OUT)
+    set_sample_attributes(res, format)
+
+    copy = Copy(context=None, result=res, format=format)
+    records = copy.load(globals()[buffer])
     assert records == sample_records
 
 
 @pytest.mark.parametrize(
-    "format, block", [("text", sample_text), ("binary", sample_binary)]
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
 )
-def test_dump(format, block):
+def test_dump(conn, format, buffer):
     from psycopg3.copy import Copy
 
-    copy = Copy(format=format)
+    res = conn.pgconn.make_empty_result(pq.ExecStatus.COPY_OUT)
+    set_sample_attributes(res, format)
+
+    copy = Copy(context=None, result=res, format=format)
     assert copy.get_buffer() is None
     for row in sample_records:
         copy.dump(row)
-    assert copy.get_buffer() == block
+    assert copy.get_buffer() == globals()[buffer]
     assert copy.get_buffer() is None
 
 
 @pytest.mark.parametrize(
-    "format, block", [("text", sample_text), ("binary", sample_binary)]
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
 )
-def test_buffers(format, block):
+def test_buffers(format, buffer):
     from psycopg3.copy import Copy
 
     copy = Copy(format=format)
-    assert list(copy.buffers(sample_records)) == [block]
+    assert list(copy.buffers(sample_records)) == [globals()[buffer]]
 
 
 @pytest.mark.parametrize(
-    "format, want", [("text", sample_text), ("binary", sample_binary)]
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
 )
-def test_copy_out_read(conn, format, want):
+def test_copy_out_read(conn, format, buffer):
     cur = conn.cursor()
-    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
-    assert copy.read() == want
+    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
+    assert copy.read() == globals()[buffer]
     assert copy.read() is None
     assert copy.read() is None
 
 
-@pytest.mark.parametrize("format", ["text", "binary"])
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 def test_iter(conn, format):
     cur = conn.cursor()
-    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
+    copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
     assert list(copy) == sample_records
 
 
 @pytest.mark.parametrize(
-    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
 )
 def test_copy_in_buffers(conn, format, buffer):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
-    copy = cur.copy(f"copy copy_in from stdin (format {format})")
-    copy.write(buffer)
+    copy = cur.copy(f"copy copy_in from stdin (format {format.name})")
+    copy.write(globals()[buffer])
     copy.end()
     data = cur.execute("select * from copy_in order by 1").fetchall()
     assert data == sample_records
 
 
 @pytest.mark.parametrize(
-    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
 )
 def test_copy_in_buffers_with(conn, format, buffer):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
-    with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
-        copy.write(buffer)
+    with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+        copy.write(globals()[buffer])
 
     data = cur.execute("select * from copy_in order by 1").fetchall()
     assert data == sample_records
 
 
 @pytest.mark.parametrize(
-    "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+    "format", [(Format.TEXT,), (Format.BINARY,)],
 )
-def test_copy_in_records(conn, format, buffer):
+def test_copy_in_records(conn, format):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
 
-    with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
+    with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         for row in sample_records:
             copy.write(row)
 
@@ -117,3 +155,20 @@ def test_copy_in_records(conn, format, buffer):
 def ensure_table(cur, tabledef, name="copy_in"):
     cur.execute(f"drop table if exists {name}")
     cur.execute(f"create table {name} ({tabledef})")
+
+
+def as_bytes(records):
+    out = []
+    for rin in records:
+        rout = []
+        for v in rin:
+            if v is None or isinstance(v, bytes):
+                rout.append(v)
+                continue
+            if not isinstance(v, str):
+                v = str(v)
+            if isinstance(v, str):
+                v = v.encode("utf8")
+            rout.append(v)
+        out.append(tuple(rout))
+    return out