]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added C implementation of text copy load
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 10 Jan 2021 01:56:10 +0000 (02:56 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 10 Jan 2021 02:45:17 +0000 (03:45 +0100)
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/generators.py
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3/psycopg3/pq/proto.py
psycopg3/psycopg3/types/text.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/_psycopg3/copy.pyx
psycopg3_c/psycopg3_c/pq.pxd
psycopg3_c/psycopg3_c/pq/pgconn.pyx
psycopg3_c/psycopg3_c/pq/pqbuffer.pyx
psycopg3_c/psycopg3_c/types/numeric.pyx

index 0d91961d9aebafc3686972680d1a42948896bc36..0d4057dbcaacd22f75a15565ebb068e7d0debd11 100644 (file)
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
 from typing_extensions import Protocol
 
 from . import pq
+from . import errors as e
 from .pq import Format, ExecStatus
 from .proto import ConnectionType, PQGen, Transformer
 from .generators import copy_from, copy_to, copy_end
@@ -22,7 +23,7 @@ if TYPE_CHECKING:
     from .connection import Connection, AsyncConnection  # noqa: F401
 
 
-class FormatFunc(Protocol):
+class CopyFormatFunc(Protocol):
     """The type of a function to format copy data to a bytearray."""
 
     def __call__(
@@ -34,6 +35,11 @@ class FormatFunc(Protocol):
         ...
 
 
+class CopyParseFunc(Protocol):
+    def __call__(self, data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+        ...
+
+
 class BaseCopy(Generic[ConnectionType]):
     def __init__(self, cursor: "BaseCursor[ConnectionType]"):
         self.cursor = cursor
@@ -54,7 +60,8 @@ class BaseCopy(Generic[ConnectionType]):
         self._write_buffer_size = 32 * 1024
         self._finished = False
 
-        self._format_row: FormatFunc
+        self._format_row: CopyFormatFunc
+        self._parse_row: CopyParseFunc
         if self.format == Format.TEXT:
             self._format_row = format_row_text
             self._parse_row = parse_row_text
@@ -83,7 +90,7 @@ class BaseCopy(Generic[ConnectionType]):
             return b""
 
         res = yield from copy_from(self._pgconn)
-        if isinstance(res, bytes):
+        if isinstance(res, memoryview):
             return res
 
         # res is the final PGresult
@@ -98,7 +105,10 @@ class BaseCopy(Generic[ConnectionType]):
             return None
         if self.format == Format.BINARY:
             if not self._signature_sent:
-                assert data.startswith(_binary_signature)
+                if data[: len(_binary_signature)] != _binary_signature:
+                    raise e.DataError(
+                        "binary copy doesn't start with the expected signature"
+                    )
                 self._signature_sent = True
                 data = data[len(_binary_signature) :]
             elif data == _binary_trailer:
@@ -358,7 +368,9 @@ def _format_row_binary(
     return out
 
 
-def parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+def _parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+    if not isinstance(data, bytes):
+        data = bytes(data)
     fields = data.split(b"\t")
     fields[-1] = fields[-1][:-1]  # drop \n
     row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
@@ -425,14 +437,17 @@ def _load_sub(
 
 # Override it with fast object if available
 
-format_row_binary: FormatFunc
+format_row_binary: CopyFormatFunc
+parse_row_text: CopyParseFunc
 
 if pq.__impl__ == "c":
     from psycopg3_c import _psycopg3
 
     format_row_text = _psycopg3.format_row_text
     format_row_binary = _psycopg3.format_row_binary
+    parse_row_text = _psycopg3.parse_row_text
 
 else:
     format_row_text = _format_row_text
     format_row_binary = _format_row_binary
+    parse_row_text = _parse_row_text
index e442c2df1719b1062da2f2bdb08b9f7fdd8ad2b9..e27fdc565f0dc2ab7635c19355add7065659fb7b 100644 (file)
@@ -161,7 +161,7 @@ def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
     return ns
 
 
-def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
+def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
     while 1:
         nbytes, data = pgconn.get_copy_data(1)
         if nbytes != 0:
index 2ef92fbf685dd0fe56f16185bfbe989469ace192..18289859c1c89ba1065fa35292644cf668635a4e 100644 (file)
@@ -535,7 +535,7 @@ class PGconn:
             raise PQerror(f"sending copy end failed: {error_message(self)}")
         return rv
 
-    def get_copy_data(self, async_: int) -> Tuple[int, bytes]:
+    def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
         buffer_ptr = c_char_p()
         nbytes = impl.PQgetCopyData(self.pgconn_ptr, byref(buffer_ptr), async_)
         if nbytes == -2:
@@ -544,9 +544,9 @@ class PGconn:
             # TODO: do it without copy
             data = string_at(buffer_ptr, nbytes)
             impl.PQfreemem(buffer_ptr)
-            return nbytes, data
+            return nbytes, memoryview(data)
         else:
-            return nbytes, b""
+            return nbytes, memoryview(b"")
 
     def make_empty_result(self, exec_status: int) -> "PGresult":
         rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status)
index f0afeb4767aea102f4bfe62040940265219890b7..6cef8fe2478a7b38b7cc5f7237e37a71c0f07c9a 100644 (file)
@@ -226,7 +226,7 @@ class PGconn(Protocol):
     def put_copy_end(self, error: Optional[bytes] = None) -> int:
         ...
 
-    def get_copy_data(self, async_: int) -> Tuple[int, bytes]:
+    def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
         ...
 
     def make_empty_result(self, exec_status: int) -> "PGresult":
index d671bc47492631ace288965e4406a548066e820b..06917330fa28a1a32c6b2b614653ccafbb285004 100644 (file)
@@ -66,7 +66,10 @@ class TextLoader(Loader):
 
     def load(self, data: bytes) -> Union[bytes, str]:
         if self._encoding:
-            return data.decode(self._encoding)
+            if isinstance(data, memoryview):
+                return bytes(data).decode(self._encoding)
+            else:
+                return data.decode(self._encoding)
         else:
             # return bytes for SQL_ASCII db
             return data
index b1043a528963304fd88314e408559f12b1821ea5..ae4c6419f590e3fed7f1a5641b5df90a2e87c93c 100644 (file)
@@ -50,5 +50,6 @@ def format_row_text(
 def format_row_binary(
     row: Sequence[Any], tx: proto.Transformer, out: Optional[bytearray] = None
 ) -> bytearray: ...
+def parse_row_text(data: bytes, tx: proto.Transformer) -> Tuple[Any, ...]: ...
 
 # vim: set syntax=python:
index 009e6ea298e865a47381d63b6cf8e67e7e4ee95e..16171b6f1d3196368380e1da14f04c293bb0b0e1 100644 (file)
@@ -9,8 +9,10 @@ from libc.string cimport memcpy
 from libc.stdint cimport uint16_t, uint32_t, int32_t
 from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
 from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE
+from cpython.memoryview cimport PyMemoryView_FromObject
 
 from psycopg3_c._psycopg3.endian cimport htobe16, htobe32
+from psycopg3_c.pq cimport ViewBuffer
 
 cdef int32_t _binary_null = -1
 
@@ -138,7 +140,7 @@ def format_row_text(
         # Now from pos to pos + size there is a textual representation: it may
         # contain chars to escape. Scan to find how many such chars there are.
         for j in range(size):
-            nesc += copy_escape_char[target[j]]
+            nesc += copy_escape_lut[target[j]]
 
         # If there is any char to escape, walk backwards pushing the chars
         # forward and interspersing backslashes.
@@ -147,7 +149,7 @@ def format_row_text(
             target = <unsigned char *>CDumper.ensure_size(out, pos, tmpsize)
             for j in range(size - 1, -1, -1):
                 target[j + nesc] = target[j]
-                if copy_escape_char[target[j]] != 0:
+                if copy_escape_lut[target[j]] != 0:
                     nesc -= 1
                     target[j + nesc] = b"\\"
                     if nesc <= 0:
@@ -162,11 +164,84 @@ def format_row_text(
     return out
 
 
+def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]:
+    cdef unsigned char *fstart
+    cdef Py_ssize_t size
+    _buffer_as_string_and_size(data, <char **>&fstart, &size)
+
+    # politely assume that the number of fields will be what in the result
+    cdef int nfields = tx._nfields
+    cdef list row = PyList_New(nfields)
+
+    cdef unsigned char *fend
+    cdef unsigned char *rowend = fstart + size
+    cdef unsigned char *src
+    cdef unsigned char *tgt
+    cdef int col = 0
+    cdef int num_bs
+    for col in range(nfields):
+        fend = fstart
+        num_bs = 0
+        # Scan to the end of the field, remember if you see any backslash
+        while fend[0] != b'\t' and fend[0] != b'\n' and fend < rowend:
+            if fend[0] == b'\\':
+                num_bs += 1
+                # skip the next char to avoid counting escaped backslashes twice
+                fend += 1
+            fend += 1
+
+        # Check if we stopped for the right reason
+        if fend >= rowend:
+            raise ValueError("bad copy format, field delimiter not found")
+        elif fend[0] == b'\t' and col == nfields - 1:
+            raise ValueError("bad copy format, got a tab at the end of the row")
+        elif fend[0] == b'\n' and col != nfields - 1:
+            raise ValueError(
+                "bad copy format, got a newline before the end of the row")
+
+        # Is this a NULL?
+        if fend - fstart == 2 and fstart[0] == b'\\' and fstart[1] == b'N':
+            field = None
+
+        # Is this a field with no backslash?
+        elif num_bs == 0:
+            # Nothing to unescape: we don't need a copy
+            field = PyMemoryView_FromObject(
+                ViewBuffer._from_buffer(fstart, fend - fstart))
+
+        # This is a field containing backslashes
+        else:
+            # We need a copy of the buffer to unescape
+            field = PyByteArray_FromStringAndSize("", 0)
+            PyByteArray_Resize(field, fend - fstart - num_bs)
+            tgt = <unsigned char *>PyByteArray_AS_STRING(field)
+            src = fstart
+            while (src < fend):
+                if src[0] != b'\\':
+                    tgt[0] = src[0]
+                else:
+                    src += 1
+                    tgt[0] = copy_unescape_lut[src[0]]
+                src += 1
+                tgt += 1
+
+        Py_INCREF(field)
+        PyList_SET_ITEM(row, col, field)
+
+        # Start of the field
+        fstart = fend + 1
+
+    # Convert the array of buffers into Python objects
+    return tx.load_sequence(row)
+
+
 cdef extern from *:
     """
-/* The characters to escape in textual copy */
+/* handle chars to (un)escape in text copy representation */
 /* '\b', '\t', '\n', '\v', '\f', '\r', '\\' */
-static const char copy_escape_char[] = {
+
+/* Which char to prepend a backslash when escaping */
+static const char copy_escape_lut[] = {
     0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0,
     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -183,6 +258,27 @@ static const char copy_escape_char[] = {
     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+};
+
+/* Conversion of escaped to unescaped chars */
+static const char copy_unescape_lut[] = {
+  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,  15,
+ 16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,
+ 32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
+ 48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,
+ 64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
+ 80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
+ 96,  97,   8,  99, 100, 101,  12, 103, 104, 105, 106, 107, 108, 109,  10, 111,
+112, 113,  13, 115,   9, 117,  11, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
+144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
+160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
+176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
+192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
+208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
+224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
+240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255,
 };
     """
-    const char[256] copy_escape_char
+    const char[256] copy_escape_lut
+    const char[256] copy_unescape_lut
index 3f2a1ae27a6350ac6bc4a30e0b20713a30c89f98..b19350f575aa14280251d605dc36138979d66cdc 100644 (file)
@@ -43,6 +43,14 @@ cdef class PQBuffer:
     cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t len)
 
 
+cdef class ViewBuffer:
+    cdef unsigned char *buf
+    cdef Py_ssize_t len
+
+    @staticmethod
+    cdef ViewBuffer _from_buffer(unsigned char *buf, Py_ssize_t len)
+
+
 cdef int _buffer_as_string_and_size(
     data: "Buffer", char **ptr, Py_ssize_t *length
 ) except -1
index d2560259cb30ec6aa5573416b0a857bfec4797ae..acd4ac2ccb855bca5c3270fcbf8615b2bef0306d 100644 (file)
@@ -7,10 +7,12 @@ psycopg3_c.pq.PGconn object implementation.
 from posix.unistd cimport getpid
 from cpython.mem cimport PyMem_Malloc, PyMem_Free
 from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize
+from cpython.memoryview cimport PyMemoryView_FromObject
 
 import logging
 
 from psycopg3.pq.misc import PGnotify, connection_summary
+from psycopg3_c.pq cimport PQBuffer
 
 logger = logging.getLogger('psycopg3')
 
@@ -445,19 +447,18 @@ cdef class PGconn:
             raise PQerror(f"sending copy end failed: {error_message(self)}")
         return rv
 
-    def get_copy_data(self, int async_) -> Tuple[int, bytes]:
+    def get_copy_data(self, int async_) -> Tuple[int, memoryview]:
         cdef char *buffer_ptr = NULL
         cdef int nbytes
         nbytes = libpq.PQgetCopyData(self.pgconn_ptr, &buffer_ptr, async_)
         if nbytes == -2:
             raise PQerror(f"receiving copy data failed: {error_message(self)}")
         if buffer_ptr is not NULL:
-            # TODO: do it without copy
-            data = buffer_ptr[:nbytes]
-            libpq.PQfreemem(buffer_ptr)
+            data = PyMemoryView_FromObject(
+                PQBuffer._from_buffer(<unsigned char *>buffer_ptr, nbytes))
             return nbytes, data
         else:
-            return nbytes, b""
+            return nbytes, b""  # won't parse it, doesn't really be memoryview
 
     def make_empty_result(self, int exec_status) -> PGresult:
         cdef libpq.PGresult *rv = libpq.PQmakeEmptyPGresult(
index 228201dfebb1f95ba8fa1b3e93ea4500a4a390a4..ad3f9a5e636d57bda6c7325c693e4cce6679131a 100644 (file)
@@ -51,6 +51,44 @@ cdef class PQBuffer:
         pass
 
 
+cdef class ViewBuffer:
+    """
+    Wrap a chunk of memory for view only.
+    """
+    @staticmethod
+    cdef ViewBuffer _from_buffer(unsigned char *buf, Py_ssize_t len):
+        cdef ViewBuffer rv = ViewBuffer.__new__(ViewBuffer)
+        rv.buf = buf
+        rv.len = len
+        return rv
+
+    def __cinit__(self):
+        self.buf = NULL
+        self.len = 0
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+            f"({bytes(self)})"
+        )
+
+    def __getbuffer__(self, Py_buffer *buffer, int flags):
+        buffer.buf = self.buf
+        buffer.obj = self
+        buffer.len = self.len
+        buffer.itemsize = sizeof(unsigned char)
+        buffer.readonly = 1
+        buffer.ndim = 1
+        buffer.format = NULL  # unsigned char
+        buffer.shape = &self.len
+        buffer.strides = NULL
+        buffer.suboffsets = NULL
+        buffer.internal = NULL
+
+    def __releasebuffer__(self, Py_buffer *buffer):
+        pass
+
+
 cdef int _buffer_as_string_and_size(
     data: "Buffer", char **ptr, Py_ssize_t *length
 ) except -1:
@@ -65,4 +103,3 @@ cdef int _buffer_as_string_and_size(
         PyBuffer_Release(&buf)
     else:
         raise TypeError(f"bytes or buffer expected, got {type(data)}")
-
index 41761b976fc9c291479742ea3ce2907b7f69dfef..bd830ec812acad6a0c104d520c66903555bc6a6c 100644 (file)
@@ -108,7 +108,19 @@ cdef class IntLoader(CLoader):
     format = Format.TEXT
 
     cdef object cload(self, const char *data, size_t length):
-        return PyLong_FromString(data, NULL, 10)
+        # if the number ends with a 0 we don't need a copy
+        if data[length] == b'\0':
+            return PyLong_FromString(data, NULL, 10)
+
+        # Otherwise we have to copy it aside
+        if length > MAXINT8LEN:
+            raise ValueError("string too big for an int")
+
+        cdef char[21] buf   # MAXINT8LEN + 1
+        memcpy(buf, data, length)
+        buf[length] = 0
+        return PyLong_FromString(buf, NULL, 10)
+
 
 
 @cython.final