from typing_extensions import Protocol
from . import pq
+from . import errors as e
from .pq import Format, ExecStatus
from .proto import ConnectionType, PQGen, Transformer
from .generators import copy_from, copy_to, copy_end
from .connection import Connection, AsyncConnection # noqa: F401
-class FormatFunc(Protocol):
+class CopyFormatFunc(Protocol):
"""The type of a function to format copy data to a bytearray."""
def __call__(
...
+class CopyParseFunc(Protocol):
+ def __call__(self, data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+ ...
+
+
class BaseCopy(Generic[ConnectionType]):
def __init__(self, cursor: "BaseCursor[ConnectionType]"):
self.cursor = cursor
self._write_buffer_size = 32 * 1024
self._finished = False
- self._format_row: FormatFunc
+ self._format_row: CopyFormatFunc
+ self._parse_row: CopyParseFunc
if self.format == Format.TEXT:
self._format_row = format_row_text
self._parse_row = parse_row_text
return b""
res = yield from copy_from(self._pgconn)
- if isinstance(res, bytes):
+ if isinstance(res, memoryview):
return res
# res is the final PGresult
return None
if self.format == Format.BINARY:
if not self._signature_sent:
- assert data.startswith(_binary_signature)
+ if data[: len(_binary_signature)] != _binary_signature:
+ raise e.DataError(
+ "binary copy doesn't start with the expected signature"
+ )
self._signature_sent = True
data = data[len(_binary_signature) :]
elif data == _binary_trailer:
return out
-def parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+def _parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+ if not isinstance(data, bytes):
+ data = bytes(data)
fields = data.split(b"\t")
fields[-1] = fields[-1][:-1] # drop \n
row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
# Override it with fast object if available
-format_row_binary: FormatFunc
+format_row_binary: CopyFormatFunc
+parse_row_text: CopyParseFunc
if pq.__impl__ == "c":
from psycopg3_c import _psycopg3
format_row_text = _psycopg3.format_row_text
format_row_binary = _psycopg3.format_row_binary
+ parse_row_text = _psycopg3.parse_row_text
else:
format_row_text = _format_row_text
format_row_binary = _format_row_binary
+ parse_row_text = _parse_row_text
from libc.stdint cimport uint16_t, uint32_t, int32_t
from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE
+from cpython.memoryview cimport PyMemoryView_FromObject
from psycopg3_c._psycopg3.endian cimport htobe16, htobe32
+from psycopg3_c.pq cimport ViewBuffer
cdef int32_t _binary_null = -1
# Now from pos to pos + size there is a textual representation: it may
# contain chars to escape. Scan to find how many such chars there are.
for j in range(size):
- nesc += copy_escape_char[target[j]]
+ nesc += copy_escape_lut[target[j]]
# If there is any char to escape, walk backwards pushing the chars
# forward and interspersing backslashes.
target = <unsigned char *>CDumper.ensure_size(out, pos, tmpsize)
for j in range(size - 1, -1, -1):
target[j + nesc] = target[j]
- if copy_escape_char[target[j]] != 0:
+ if copy_escape_lut[target[j]] != 0:
nesc -= 1
target[j + nesc] = b"\\"
if nesc <= 0:
return out
+def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]:
+ cdef unsigned char *fstart
+ cdef Py_ssize_t size
+ _buffer_as_string_and_size(data, <char **>&fstart, &size)
+
+ # politely assume that the number of fields will be what in the result
+ cdef int nfields = tx._nfields
+ cdef list row = PyList_New(nfields)
+
+ cdef unsigned char *fend
+ cdef unsigned char *rowend = fstart + size
+ cdef unsigned char *src
+ cdef unsigned char *tgt
+ cdef int col = 0
+ cdef int num_bs
+ for col in range(nfields):
+ fend = fstart
+ num_bs = 0
+ # Scan to the end of the field, remember if you see any backslash
+ while fend[0] != b'\t' and fend[0] != b'\n' and fend < rowend:
+ if fend[0] == b'\\':
+ num_bs += 1
+ # skip the next char to avoid counting escaped backslashes twice
+ fend += 1
+ fend += 1
+
+ # Check if we stopped for the right reason
+ if fend >= rowend:
+ raise ValueError("bad copy format, field delimiter not found")
+ elif fend[0] == b'\t' and col == nfields - 1:
+ raise ValueError("bad copy format, got a tab at the end of the row")
+ elif fend[0] == b'\n' and col != nfields - 1:
+ raise ValueError(
+ "bad copy format, got a newline before the end of the row")
+
+ # Is this a NULL?
+ if fend - fstart == 2 and fstart[0] == b'\\' and fstart[1] == b'N':
+ field = None
+
+ # Is this a field with no backslash?
+ elif num_bs == 0:
+ # Nothing to unescape: we don't need a copy
+ field = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(fstart, fend - fstart))
+
+ # This is a field containing backslashes
+ else:
+ # We need a copy of the buffer to unescape
+ field = PyByteArray_FromStringAndSize("", 0)
+ PyByteArray_Resize(field, fend - fstart - num_bs)
+ tgt = <unsigned char *>PyByteArray_AS_STRING(field)
+ src = fstart
+ while (src < fend):
+ if src[0] != b'\\':
+ tgt[0] = src[0]
+ else:
+ src += 1
+ tgt[0] = copy_unescape_lut[src[0]]
+ src += 1
+ tgt += 1
+
+ Py_INCREF(field)
+ PyList_SET_ITEM(row, col, field)
+
+ # Start of the field
+ fstart = fend + 1
+
+ # Convert the array of buffers into Python objects
+ return tx.load_sequence(row)
+
+
cdef extern from *:
"""
-/* The characters to escape in textual copy */
+/* handle chars to (un)escape in text copy representation */
/* '\b', '\t', '\n', '\v', '\f', '\r', '\\' */
-static const char copy_escape_char[] = {
+
+/* Which char to prepend a backslash when escaping */
+static const char copy_escape_lut[] = {
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+};
+
+/* Conversion of escaped to unescaped chars */
+static const char copy_unescape_lut[] = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
+ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
+ 96, 97, 8, 99, 100, 101, 12, 103, 104, 105, 106, 107, 108, 109, 10, 111,
+112, 113, 13, 115, 9, 117, 11, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
+144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
+160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
+176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
+192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
+208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
+224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
+240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255,
};
"""
- const char[256] copy_escape_char
+ const char[256] copy_escape_lut
+ const char[256] copy_unescape_lut