]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added C implementation of binary copy load
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 10 Jan 2021 03:46:26 +0000 (04:46 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 10 Jan 2021 03:52:32 +0000 (04:52 +0100)
Return value of the copy read function redefined as memoryview to avoid
unneeded data copying.

docs/cursor.rst
psycopg3/psycopg3/copy.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/_psycopg3/copy.pyx
psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
tests/test_copy.py
tests/test_copy_async.py

index 6f5368dc3d0c4ec1b646c488410b44f7bc6da20d..4e1e4bc7d16507ed622f890f471cb3d000cb684c 100644 (file)
@@ -195,10 +195,13 @@ Cursor support objects
     .. automethod:: write
     .. automethod:: read
 
-        Instead of using `!read()` you can even iterate on the object to read
-        its data row by row, using ``for row in copy: ...``.
+        Instead of using `!read()` you can iterate on the `!Copy` object to
+        read its data row by row, using ``for row in copy: ...``.
 
     .. automethod:: rows
+
+        Equivalent of iterating on `read_row()` until it returns `!None`
+
     .. automethod:: read_row
     .. automethod:: set_types
 
@@ -213,8 +216,8 @@ Cursor support objects
     .. automethod:: write
     .. automethod:: read
 
-        Instead of using `!read()` you can even iterate on the object to read
-        its data row by row, using ``async for row in copy: ...``.
+        Instead of using `!read()` you can iterate on the `!AsyncCopy` object
+        to read its data row by row, using ``async for row in copy: ...``.
 
     .. automethod:: rows
 
index 0d4057dbcaacd22f75a15565ebb068e7d0debd11..1dc3b6880a270580cac60992201bb40696fe0925 100644 (file)
@@ -9,7 +9,6 @@ import struct
 from types import TracebackType
 from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
 from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
-from typing_extensions import Protocol
 
 from . import pq
 from . import errors as e
@@ -23,23 +22,6 @@ if TYPE_CHECKING:
     from .connection import Connection, AsyncConnection  # noqa: F401
 
 
-class CopyFormatFunc(Protocol):
-    """The type of a function to format copy data to a bytearray."""
-
-    def __call__(
-        self,
-        row: Sequence[Any],
-        tx: Transformer,
-        out: Optional[bytearray] = None,
-    ) -> bytearray:
-        ...
-
-
-class CopyParseFunc(Protocol):
-    def __call__(self, data: bytes, tx: Transformer) -> Tuple[Any, ...]:
-        ...
-
-
 class BaseCopy(Generic[ConnectionType]):
     def __init__(self, cursor: "BaseCursor[ConnectionType]"):
         self.cursor = cursor
@@ -60,8 +42,6 @@ class BaseCopy(Generic[ConnectionType]):
         self._write_buffer_size = 32 * 1024
         self._finished = False
 
-        self._format_row: CopyFormatFunc
-        self._parse_row: CopyParseFunc
         if self.format == Format.TEXT:
             self._format_row = format_row_text
             self._parse_row = parse_row_text
@@ -85,9 +65,9 @@ class BaseCopy(Generic[ConnectionType]):
 
     # High level copy protocol generators (state change of the Copy object)
 
-    def _read_gen(self) -> PQGen[bytes]:
+    def _read_gen(self) -> PQGen[memoryview]:
         if self._finished:
-            return b""
+            return memoryview(b"")
 
         res = yield from copy_from(self._pgconn)
         if isinstance(res, memoryview):
@@ -97,7 +77,7 @@ class BaseCopy(Generic[ConnectionType]):
         self._finished = True
         nrows = res.command_tuples
         self.cursor._rowcount = nrows if nrows is not None else -1
-        return b""
+        return memoryview(b"")
 
     def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
         data = yield from self._read_gen()
@@ -208,11 +188,11 @@ class Copy(BaseCopy["Connection"]):
 
     __module__ = "psycopg3"
 
-    def read(self) -> bytes:
+    def read(self) -> memoryview:
         """
-        Read an unparsed row from a table after a :sql:`COPY TO` operation.
+        Read an unparsed row after a :sql:`COPY TO` operation.
 
-        Return an empty bytes string when the data is finished.
+        Return an empty string when the data is finished.
         """
         return self.connection.wait(self._read_gen())
 
@@ -269,7 +249,7 @@ class Copy(BaseCopy["Connection"]):
     ) -> None:
         self.connection.wait(self._exit_gen(exc_type, exc_val))
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[memoryview]:
         while True:
             data = self.read()
             if not data:
@@ -282,7 +262,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
 
     __module__ = "psycopg3"
 
-    async def read(self) -> bytes:
+    async def read(self) -> memoryview:
         return await self.connection.wait(self._read_gen())
 
     async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
@@ -316,7 +296,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
     ) -> None:
         await self.connection.wait(self._exit_gen(exc_type, exc_val))
 
-    async def __aiter__(self) -> AsyncIterator[bytes]:
+    async def __aiter__(self) -> AsyncIterator[memoryview]:
         while True:
             data = await self.read()
             if not data:
@@ -377,7 +357,7 @@ def _parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
     return tx.load_sequence(row)
 
 
-def parse_row_binary(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+def _parse_row_binary(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
     row: List[Optional[bytes]] = []
     nfields = _unpack_int2(data, 0)[0]
     pos = 2
@@ -435,19 +415,17 @@ def _load_sub(
     return __map[m.group(0)]
 
 
-# Override it with fast object if available
-
-format_row_binary: CopyFormatFunc
-parse_row_text: CopyParseFunc
-
+# Override functions with fast versions if available
 if pq.__impl__ == "c":
     from psycopg3_c import _psycopg3
 
     format_row_text = _psycopg3.format_row_text
     format_row_binary = _psycopg3.format_row_binary
     parse_row_text = _psycopg3.parse_row_text
+    parse_row_binary = _psycopg3.parse_row_binary
 
 else:
     format_row_text = _format_row_text
     format_row_binary = _format_row_binary
     parse_row_text = _parse_row_text
+    parse_row_binary = _parse_row_binary
index ae4c6419f590e3fed7f1a5641b5df90a2e87c93c..dad01fabb0f6dd68637fa7c052f3c69e5ab094dd 100644 (file)
@@ -51,5 +51,8 @@ def format_row_binary(
     row: Sequence[Any], tx: proto.Transformer, out: Optional[bytearray] = None
 ) -> bytearray: ...
 def parse_row_text(data: bytes, tx: proto.Transformer) -> Tuple[Any, ...]: ...
+def parse_row_binary(
+    data: bytes, tx: proto.Transformer
+) -> Tuple[Any, ...]: ...
 
 # vim: set syntax=python:
index 16171b6f1d3196368380e1da14f04c293bb0b0e1..50e595bfef1b6e1868d69ab0e89d93101cb7b5b5 100644 (file)
@@ -11,9 +11,11 @@ from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
 from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE
 from cpython.memoryview cimport PyMemoryView_FromObject
 
-from psycopg3_c._psycopg3.endian cimport htobe16, htobe32
+from psycopg3_c._psycopg3 cimport endian
 from psycopg3_c.pq cimport ViewBuffer
 
+from psycopg3 import errors as e
+
 cdef int32_t _binary_null = -1
 
 
@@ -22,7 +24,7 @@ def format_row_binary(
 ) -> bytearray:
     """Convert a row of adapted data to the data to send for binary copy"""
     cdef Py_ssize_t rowlen = len(row)
-    cdef uint16_t berowlen = htobe16(rowlen)
+    cdef uint16_t berowlen = endian.htobe16(rowlen)
 
     cdef Py_ssize_t pos  # offset in 'out' where to write
     if out is None:
@@ -54,7 +56,7 @@ def format_row_binary(
                 # A cdumper can resize if necessary and copy in place
                 size = (<CDumper>dumper).cdump(item, out, pos + sizeof(besize))
                 # Also add the size of the item, before the item
-                besize = htobe32(size)
+                besize = endian.htobe32(size)
                 target = PyByteArray_AS_STRING(out)  # might have been moved by cdump
                 memcpy(target + pos, <void *>&besize, sizeof(besize))
             else:
@@ -62,7 +64,7 @@ def format_row_binary(
                 b = PyObject_CallFunctionObjArgs(dumper.dump, <PyObject *>item, NULL)
                 _buffer_as_string_and_size(b, &buf, &size)
                 target = CDumper.ensure_size(out, pos, size + sizeof(besize))
-                besize = htobe32(size)
+                besize = endian.htobe32(size)
                 memcpy(target, <void *>&besize, sizeof(besize))
                 memcpy(target + sizeof(besize), buf, size)
 
@@ -164,6 +166,40 @@ def format_row_text(
     return out
 
 
+def parse_row_binary(data, tx: Transformer) -> Tuple[Any, ...]:
+    cdef unsigned char *ptr
+    cdef Py_ssize_t bufsize
+    _buffer_as_string_and_size(data, <char **>&ptr, &bufsize)
+    cdef unsigned char *bufend = ptr + bufsize
+
+    cdef uint16_t benfields = (<uint16_t *>ptr)[0]
+    cdef int nfields = endian.be16toh(benfields)
+    ptr += sizeof(benfields)
+    cdef list row = PyList_New(nfields)
+
+    cdef int col
+    cdef int32_t belength
+    cdef Py_ssize_t length
+
+    for col in range(nfields):
+        memcpy(&belength, ptr, sizeof(belength))
+        ptr += sizeof(belength)
+        if belength == _binary_null:
+            field = None
+        else:
+            length = endian.be32toh(belength)
+            if ptr + length > bufend:
+                raise e.DataError("bad copy data: length exceeding data")
+            field = PyMemoryView_FromObject(
+                ViewBuffer._from_buffer(ptr, length))
+            ptr += length
+
+        Py_INCREF(field)
+        PyList_SET_ITEM(row, col, field)
+
+    return tx.load_sequence(row)
+
+
 def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]:
     cdef unsigned char *fstart
     cdef Py_ssize_t size
@@ -177,8 +213,9 @@ def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]:
     cdef unsigned char *rowend = fstart + size
     cdef unsigned char *src
     cdef unsigned char *tgt
-    cdef int col = 0
+    cdef int col
     cdef int num_bs
+
     for col in range(nfields):
         fend = fstart
         num_bs = 0
@@ -192,12 +229,12 @@ def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]:
 
         # Check if we stopped for the right reason
         if fend >= rowend:
-            raise ValueError("bad copy format, field delimiter not found")
+            raise e.DataError("bad copy data: field delimiter not found")
         elif fend[0] == b'\t' and col == nfields - 1:
-            raise ValueError("bad copy format, got a tab at the end of the row")
+            raise e.DataError("bad copy data: got a tab at the end of the row")
         elif fend[0] == b'\n' and col != nfields - 1:
-            raise ValueError(
-                "bad copy format, got a newline before the end of the row")
+            raise e.DataError(
+                "bad copy format: got a newline before the end of the row")
 
         # Is this a NULL?
         if fend - fstart == 2 and fstart[0] == b'\\' and fstart[1] == b'N':
index b2fecf3253204a2c9f6c9e7b08b3af91aad46522..8b33552696fdc1868f8d9bd2b6fae74a0022027f 100644 (file)
@@ -324,7 +324,7 @@ cdef class Transformer:
 
         return record
 
-    def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
+    cpdef object load_sequence(self, record: Sequence[Optional[bytes]]):
         cdef int nfields = len(record)
         out = PyTuple_New(nfields)
         cdef PyObject *loader  # borrowed RowLoader
index 31498b74cbd46224bba8ee0ba4edd56c10dae7f4..96bf4775e96e3765ae4ea9bb24a8c58893364a82 100644 (file)
@@ -369,10 +369,10 @@ def test_copy_from_to(conn):
 
     gen.assert_data()
 
-    f = StringIO()
+    f = BytesIO()
     with cur.copy("copy copy_in to stdout") as copy:
         for block in copy:
-            f.write(block.decode("utf8"))
+            f.write(block)
 
     f.seek(0)
     assert gen.sha(f) == gen.sha(gen.file())
index f8999c719a909534097e0e4c5911c79502540e87..9c6f33c30ada969af68caae65f0190907fdd18c3 100644 (file)
@@ -337,10 +337,10 @@ async def test_copy_from_to(aconn):
 
     await gen.assert_data()
 
-    f = StringIO()
+    f = BytesIO()
     async with cur.copy("copy copy_in to stdout") as copy:
         async for block in copy:
-            f.write(block.decode("utf8"))
+            f.write(block)
 
     f.seek(0)
     assert gen.sha(f) == gen.sha(gen.file())