import re
import codecs
+import struct
from typing import TYPE_CHECKING, AsyncGenerator, Generator
-from typing import Dict, Match, Optional, Type, Union
+from typing import Any, Dict, Match, Optional, Sequence, Type, Union
from types import TracebackType
from . import pq
self._transformer = Transformer(context)
self.format = format
self.pgresult = result
+ self._first_row = True
self._finished = False
self._codec: Optional[codecs.CodecInfo] = None
+ if format == pq.Format.TEXT:
+ self._format_row = self._format_row_text
+ else:
+ self._format_row = self._format_row_binary
+
@property
def finished(self) -> bool:
return self._finished
self._codec = self.connection.codec
return self._codec.encode(data)[0]
+ else:
+ raise TypeError(f"can't write {type(data).__name__}")
+
+ def format_row(self, row: Sequence[Any]) -> bytes:
+ # TODO: cache this, or pass just a single format
+ formats = [self.format] * len(row)
+ out, _ = self._transformer.dump_sequence(row, formats)
+ return self._format_row(out)
+
+ def _format_row_text(self, row: Sequence[Optional[bytes]],) -> bytes:
+ return (
+ b"\t".join(
+ _bsrepl_re.sub(_bsrepl_sub, item)
+ if item is not None
+ else br"\N"
+ for item in row
+ )
+ + b"\n"
+ )
+
+ def _format_row_binary(
+ self,
+ row: Sequence[Optional[bytes]],
+ __int2_struct: struct.Struct = struct.Struct("!h"),
+ __int4_struct: struct.Struct = struct.Struct("!i"),
+ ) -> bytes:
+ out = []
+ if self._first_row:
+ out.append(
+ # Signature, flags, extra length
+ b"PGCOPY\n\xff\r\n\0"
+ b"\x00\x00\x00\x00"
+ b"\x00\x00\x00\x00"
+ )
+ self._first_row = False
+
+ out.append(__int2_struct.pack(len(row)))
+ for item in row:
+ if item is not None:
+ out.append(__int4_struct.pack(len(item)))
+ out.append(item)
+ else:
+ out.append(b"\xff\xff\xff\xff")
+
+ return b"".join(out)
+
def _bsrepl_sub(
m: Match[bytes],
__map: Dict[bytes, bytes] = {
- b"b": b"\b",
- b"t": b"\t",
- b"n": b"\n",
- b"v": b"\v",
- b"f": b"\f",
- b"r": b"\r",
+ b"\b": b"\\b",
+ b"\t": b"\\t",
+ b"\n": b"\\n",
+ b"\v": b"\\v",
+ b"\f": b"\\f",
+ b"\r": b"\\r",
+ b"\\": b"\\\\",
},
) -> bytes:
- g = m.group(0)
- return __map.get(g, g)
+ return __map[m.group(0)]
-_bsrepl_re = re.compile(rb"\\(.)")
+_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
class Copy(BaseCopy):
conn = self.connection
conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
+ def write_row(self, row: Sequence[Any]) -> None:
+ data = self.format_row(row)
+ self.write(data)
+
def finish(self, error: Optional[str] = None) -> None:
conn = self.connection
berr = (
exc_tb: Optional[TracebackType],
) -> None:
if exc_val is None:
+ if self.format == pq.Format.BINARY and not self._first_row:
+ # send EOF only if we copied binary rows (_first_row is False)
+ self.write(b"\xff\xff")
self.finish()
else:
self.finish(str(exc_val))
conn = self.connection
await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
+ async def write_row(self, row: Sequence[Any]) -> None:
+ data = self.format_row(row)
+ await self.write(data)
+
async def finish(self, error: Optional[str] = None) -> None:
conn = self.connection
berr = (
exc_tb: Optional[TracebackType],
) -> None:
if exc_val is None:
+ if self.format == pq.Format.BINARY and not self._first_row:
+ # send EOF only if we copied binary rows (_first_row is False)
+ await self.write(b"\xff\xff")
await self.finish()
else:
await self.finish(str(exc_val))
from psycopg3 import errors as e
from psycopg3.adapt import Format
+eur = "\u20ac"
sample_records = [(10, 20, "hello"), (40, None, "world")]
assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
-@pytest.mark.xfail
-@pytest.mark.parametrize(
- "format", [(Format.TEXT,), (Format.BINARY,)],
-)
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
def test_copy_in_records(conn, format):
+ if format == Format.BINARY:
+ pytest.skip("TODO: implement int binary adapter")
+
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
for row in sample_records:
- copy.write(row)
+ copy.write_row(row)
data = cur.execute("select * from copy_in order by 1").fetchall()
assert data == sample_records
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_records_binary(conn, format):
+ if format == Format.TEXT:
+ pytest.skip("TODO: remove after implementing int binary adapter")
+
+ cur = conn.cursor()
+ ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+ with cur.copy(
+ f"copy copy_in (col2, data) from stdin (format {format.name})"
+ ) as copy:
+ for row in sample_records:
+ copy.write_row((None, row[2]))
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+def test_copy_in_allchars(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ conn.set_client_encoding("utf8")
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ for i in range(1, 256):
+ copy.write_row((i, None, chr(i)))
+ copy.write_row((ord(eur), None, eur))
+
+ data = cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ ).fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
def ensure_table(cur, tabledef, name="copy_in"):
cur.execute(f"drop table if exists {name}")
cur.execute(f"create table {name} ({tabledef})")
from psycopg3.adapt import Format
from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa
-from .test_copy import sample_values, sample_records, sample_tabledef
+from .test_copy import eur, sample_values, sample_records, sample_tabledef
pytestmark = pytest.mark.asyncio
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_records(aconn, format):
+ if format == Format.BINARY:
+ pytest.skip("TODO: implement int binary adapter")
+
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with (
+ await cur.copy(f"copy copy_in from stdin (format {format.name})")
+ ) as copy:
+ for row in sample_records:
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_records_binary(aconn, format):
+ if format == Format.TEXT:
+ pytest.skip("TODO: remove after implementing int binary adapter")
+
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+ async with (
+ await cur.copy(
+ f"copy copy_in (col2, data) from stdin (format {format.name})"
+ )
+ ) as copy:
+ for row in sample_records:
+ await copy.write_row((None, row[2]))
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+async def test_copy_in_allchars(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ await aconn.set_client_encoding("utf8")
+ async with (
+ await cur.copy("copy copy_in from stdin (format text)")
+ ) as copy:
+ for i in range(1, 256):
+ await copy.write_row((i, None, chr(i)))
+ await copy.write_row((ord(eur), None, eur))
+
+ await cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ )
+ data = await cur.fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
async def ensure_table(cur, tabledef, name="copy_in"):
await cur.execute(f"drop table if exists {name}")
await cur.execute(f"create table {name} ({tabledef})")