# Copyright (C) 2020 The Psycopg Team
-from typing import Optional
+import re
+from typing import Any, Deque, Dict, List, Match, Optional, Tuple
+from collections import deque
from .proto import AdaptContext
+from . import errors as e
from . import pq
def __init__(
self,
context: AdaptContext,
- result: pq.proto.PGresult,
+ result: Optional[pq.proto.PGresult],
format: pq.Format = pq.Format.TEXT,
):
- from .transform import Transformer
+ from .adapt import Transformer
self._transformer = Transformer(context)
- self.format = format # TODO: maybe not needed
+ self.format = format
self.pgresult = result
+ self._finished = False
+
+ self._partial: Deque[bytes] = deque()
+ self._header_seen = False
+
+ @property
+ def finished(self) -> bool:
+ return self._finished
@property
def pgresult(self) -> Optional[pq.proto.PGresult]:
self._pgresult = result
self._transformer.pgresult = result
+ def load(self, buffer: bytes) -> List[Tuple[Any, ...]]:
+ if self._finished:
+ raise e.ProgrammingError("copy already finished")
+
+ if self.format == pq.Format.TEXT:
+ return self._load_text(buffer)
+ else:
+ return self._load_binary(buffer)
+
+ def _load_text(self, buffer: bytes) -> List[Tuple[Any, ...]]:
+ rows = buffer.split(b"\n")
+ last_row = rows.pop(-1)
+
+ if self._partial and rows:
+ self._partial.append(rows[0])
+ rows[0] = b"".join(self._partial)
+ self._partial.clear()
+
+ if last_row:
+ self._partial.append(last_row)
+
+ # If there is no result then the transformer has no info about types
+ load_sequence = (
+ self._transformer.load_sequence
+ if self.pgresult is not None
+ else None
+ )
+
+ rv = []
+ for row in rows:
+ if row == b"\\.":
+ self._finished = True
+ break
+
+ values = row.split(b"\t")
+ prow = tuple(
+ _bsrepl_re.sub(_bsrepl_sub, v) if v != b"\\N" else None
+ for v in values
+ )
+ rv.append(
+ load_sequence(prow) if load_sequence is not None else prow
+ )
+
+ return rv
+
+ def _load_binary(self, buffer: bytes) -> List[Tuple[Any, ...]]:
+ raise NotImplementedError
+
+
+def _bsrepl_sub(
+ m: Match[bytes],
+ __map: Dict[bytes, bytes] = {
+ b"b": b"\b",
+ b"t": b"\t",
+ b"n": b"\n",
+ b"v": b"\v",
+ b"f": b"\f",
+ b"r": b"\r",
+ },
+) -> bytes:
+ g = m.group(0)
+ return __map.get(g, g)
+
+
+_bsrepl_re = re.compile(rb"\\(.)")
+
class Copy(BaseCopy):
pass
import pytest
+from psycopg3 import pq
+from psycopg3.adapt import Format
+from psycopg3.types import builtins
+
+
sample_records = [(10, 20, "hello"), (40, None, "world")]
sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
-sample_tabledef = "col1 int primary key, col2 int, date text"
+sample_tabledef = "col1 int primary key, col2 int, data text"
-sample_text = b"""
+sample_text = b"""\
10\t20\thello
40\t\\N\tworld
"""
"""
+def set_sample_attributes(res, format):
+ attrs = [
+ pq.PGresAttDesc(b"col1", 0, 0, format, builtins["int4"].oid, 0, 0),
+ pq.PGresAttDesc(b"col2", 0, 0, format, builtins["int4"].oid, 0, 0),
+ pq.PGresAttDesc(b"data", 0, 0, format, builtins["text"].oid, 0, 0),
+ ]
+ res.set_attributes(attrs)
+
+
@pytest.mark.parametrize(
- "format, block", [("text", sample_text), ("binary", sample_binary)]
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
)
-def test_load(format, block):
+def test_load_noinfo(conn, format, buffer):
from psycopg3.copy import Copy
- copy = Copy(format=format)
- records = copy.load(block)
+ copy = Copy(context=None, result=None, format=format)
+ records = copy.load(globals()[buffer])
+ assert records == as_bytes(sample_records)
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_load(conn, format, buffer):
+ from psycopg3.copy import Copy
+
+ res = conn.pgconn.make_empty_result(pq.ExecStatus.COPY_OUT)
+ set_sample_attributes(res, format)
+
+ copy = Copy(context=None, result=res, format=format)
+ records = copy.load(globals()[buffer])
assert records == sample_records
@pytest.mark.parametrize(
- "format, block", [("text", sample_text), ("binary", sample_binary)]
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
)
-def test_dump(format, block):
+def test_dump(conn, format, buffer):
from psycopg3.copy import Copy
- copy = Copy(format=format)
+ res = conn.pgconn.make_empty_result(pq.ExecStatus.COPY_OUT)
+ set_sample_attributes(res, format)
+
+ copy = Copy(context=None, result=res, format=format)
assert copy.get_buffer() is None
for row in sample_records:
copy.dump(row)
- assert copy.get_buffer() == block
+ assert copy.get_buffer() == globals()[buffer]
assert copy.get_buffer() is None
@pytest.mark.parametrize(
- "format, block", [("text", sample_text), ("binary", sample_binary)]
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
)
-def test_buffers(format, block):
+def test_buffers(format, buffer):
from psycopg3.copy import Copy
copy = Copy(format=format)
- assert list(copy.buffers(sample_records)) == [block]
+ assert list(copy.buffers(sample_records)) == [globals()[buffer]]
@pytest.mark.parametrize(
- "format, want", [("text", sample_text), ("binary", sample_binary)]
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
)
-def test_copy_out_read(conn, format, want):
+def test_copy_out_read(conn, format, buffer):
cur = conn.cursor()
- copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
- assert copy.read() == want
+ copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
+ assert copy.read() == globals()[buffer]
assert copy.read() is None
assert copy.read() is None
-@pytest.mark.parametrize("format", ["text", "binary"])
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
def test_iter(conn, format):
cur = conn.cursor()
- copy = cur.copy(f"copy ({sample_values}) to stdout (format {format})")
+ copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
assert list(copy) == sample_records
@pytest.mark.parametrize(
- "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
)
def test_copy_in_buffers(conn, format, buffer):
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
- copy = cur.copy(f"copy copy_in from stdin (format {format})")
- copy.write(buffer)
+ copy = cur.copy(f"copy copy_in from stdin (format {format.name})")
+ copy.write(globals()[buffer])
copy.end()
data = cur.execute("select * from copy_in order by 1").fetchall()
assert data == sample_records
@pytest.mark.parametrize(
- "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
)
def test_copy_in_buffers_with(conn, format, buffer):
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
- with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
- copy.write(buffer)
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.write(globals()[buffer])
data = cur.execute("select * from copy_in order by 1").fetchall()
assert data == sample_records
@pytest.mark.parametrize(
- "format, buffer", [("text", sample_text), ("binary", sample_binary)]
+ "format", [(Format.TEXT,), (Format.BINARY,)],
)
-def test_copy_in_records(conn, format, buffer):
+def test_copy_in_records(conn, format):
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
- with cur.copy(f"copy copy_in from stdin (format {format})") as copy:
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
for row in sample_records:
copy.write(row)
def ensure_table(cur, tabledef, name="copy_in"):
cur.execute(f"drop table if exists {name}")
cur.execute(f"create table {name} ({tabledef})")
+
+
+def as_bytes(records):
+ out = []
+ for rin in records:
+ rout = []
+ for v in rin:
+ if v is None or isinstance(v, bytes):
+ rout.append(v)
+ continue
+ if not isinstance(v, str):
+ v = str(v)
+ if isinstance(v, str):
+ v = v.encode("utf8")
+ rout.append(v)
+ out.append(tuple(rout))
+ return out