git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(copy): allow bytearray/memoryview as copy.write() input
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 26 Mar 2022 00:56:49 +0000 (01:56 +0100)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 26 Mar 2022 14:07:57 +0000 (15:07 +0100)
The C implementation can deal with these types efficiently, and it may
save a memory copy if that's what the user already has available.

Close #254

docs/api/cursors.rst
docs/news.rst
psycopg/psycopg/copy.py
psycopg/psycopg/pq/pq_ctypes.py
tests/test_copy.py
tests/test_copy_async.py

index 74d9a16344a5ccd9034432bfdc8ece3815006b8d..6c5e58e64947389711961a9d9e5b8e1367b947ab 100644 (file)
@@ -474,6 +474,11 @@ COPY-related objects
         see :ref:`adaptation` for details.
 
     .. automethod:: write
+
+        .. versionchanged:: 3.1
+
+            accept `bytearray` and `memoryview` data as input too.
+
     .. automethod:: read
 
         Instead of using `!read()` you can iterate on the `!Copy` object to
index f6c92f49cdc7d4df8048a805fc83976c68435598..ef25cc804d11865b639fa7a4978e61e9e9b1b582 100644 (file)
@@ -19,6 +19,8 @@ Psycopg 3.1 (unreleased)
 - Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`).
 - Add *prepare_threshold* parameter to `Connection` init (:ticket:`#200`).
 - Add `Error.pgconn` and `Error.pgresult` attributes (:ticket:`#242`).
+- Allow `bytearray`/`memoryview` data too as `Copy.write()` input
+  (:ticket:`#254`).
 - Drop support for Python 3.6.
 
 
index b8c1818962d7f44a91a83b04f8a53add76c779b5..abd7addaeb219d0bfed8525b82205ccc1aaf191b 100644 (file)
@@ -17,7 +17,7 @@ from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
 from . import pq
 from . import errors as e
 from .pq import ExecStatus
-from .abc import ConnectionType, PQGen, Transformer
+from .abc import Buffer, ConnectionType, PQGen, Transformer
 from .adapt import PyFormat
 from ._compat import create_task
 from ._cmodule import _psycopg
@@ -252,7 +252,7 @@ class Copy(BaseCopy["Connection[Any]"]):
         """
         return self.connection.wait(self._read_row_gen())
 
-    def write(self, buffer: Union[str, bytes]) -> None:
+    def write(self, buffer: Union[Buffer, str]) -> None:
         """
         Write a block of data to a table after a :sql:`COPY FROM` operation.
 
@@ -300,7 +300,7 @@ class Copy(BaseCopy["Connection[Any]"]):
             # Propagate the error to the main thread.
             self._worker_error = ex
 
-    def _write(self, data: bytes) -> None:
+    def _write(self, data: Buffer) -> None:
         if not data:
             return
 
@@ -380,7 +380,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     async def read_row(self) -> Optional[Tuple[Any, ...]]:
         return await self.connection.wait(self._read_row_gen())
 
-    async def write(self, buffer: Union[str, bytes]) -> None:
+    async def write(self, buffer: Union[Buffer, str]) -> None:
         data = self.formatter.write(buffer)
         await self._write(data)
 
@@ -410,7 +410,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
                 break
             await self.connection.wait(copy_to(self._pgconn, data))
 
-    async def _write(self, data: bytes) -> None:
+    async def _write(self, data: Buffer) -> None:
         if not data:
             return
 
@@ -455,7 +455,7 @@ class Formatter(ABC):
         ...
 
     @abstractmethod
-    def write(self, buffer: Union[str, bytes]) -> bytes:
+    def write(self, buffer: Union[Buffer, str]) -> bytes:
         ...
 
     @abstractmethod
@@ -481,7 +481,7 @@ class TextFormatter(Formatter):
         else:
             return None
 
-    def write(self, buffer: Union[str, bytes]) -> bytes:
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
         data = self._ensure_bytes(buffer)
         self._signature_sent = True
         return data
@@ -502,15 +502,14 @@ class TextFormatter(Formatter):
         buffer, self._write_buffer = self._write_buffer, bytearray()
         return buffer
 
-    def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
-        if isinstance(data, bytes):
-            return data
-
-        elif isinstance(data, str):
+    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+        if isinstance(data, str):
             return data.encode(self._encoding)
-
         else:
-            raise TypeError(f"can't write {type(data).__name__}")
+            # Assume, for simplicity, that the user is not passing stupid
+            # things to the write function. If that's the case, things
+            # will fail downstream.
+            return data
 
 
 class BinaryFormatter(Formatter):
@@ -535,7 +534,7 @@ class BinaryFormatter(Formatter):
 
         return parse_row_binary(data, self.transformer)
 
-    def write(self, buffer: Union[str, bytes]) -> bytes:
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
         data = self._ensure_bytes(buffer)
         self._signature_sent = True
         return data
@@ -575,15 +574,14 @@ class BinaryFormatter(Formatter):
         buffer, self._write_buffer = self._write_buffer, bytearray()
         return buffer
 
-    def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
-        if isinstance(data, bytes):
-            return data
-
-        elif isinstance(data, str):
+    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+        if isinstance(data, str):
             raise TypeError("cannot copy str data in binary mode: use bytes instead")
-
         else:
-            raise TypeError(f"can't write {type(data).__name__}")
+            # Assume, for simplicity, that the user is not passing stupid
+            # things to the write function. If that's the case, things
+            # will fail downstream.
+            return data
 
 
 def _format_row_text(
index a8acc3d85f62067dbc814881ffc884b2c3125cd1..33c607449dc208e0b5392958bea017fbf124d11d 100644 (file)
@@ -570,8 +570,7 @@ class PGconn:
         else:
             return None
 
-    def put_copy_data(self, buffer: bytes) -> int:
-        # TODO: should be done without copy
+    def put_copy_data(self, buffer: "abc.Buffer") -> int:
         if not isinstance(buffer, bytes):
             buffer = bytes(buffer)
         rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))
index e506ad06e5f70d6457615500c49d3af97927588c..52e3e968d8f06dcb0bcafab68adb8cc4791602bf 100644 (file)
@@ -278,12 +278,14 @@ def test_copy_big_size_record(conn):
 
 
 @pytest.mark.slow
-def test_copy_big_size_block(conn):
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+def test_copy_big_size_block(conn, pytype):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
     data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+    copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
     with cur.copy("copy copy_in (data) from stdin") as copy:
-        copy.write(data + "\n")
+        copy.write(copy_data)
 
     cur.execute("select data from copy_in limit 1")
     assert cur.fetchone()[0] == data
@@ -468,14 +470,15 @@ def test_copy_from_to(conn):
 
 
 @pytest.mark.slow
-def test_copy_from_to_bytes(conn):
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+def test_copy_from_to_bytes(conn, pytype):
     # Roundtrip from file to database to file blockwise
     gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
     gen.ensure_table()
     cur = conn.cursor()
     with cur.copy("copy copy_in from stdin") as copy:
         for block in gen.blocks():
-            copy.write(block.encode())
+            copy.write(pytype(block.encode()))
 
     gen.assert_data()
 
index 271e92105a0021c2725c82cc4448544741a2c061..6fc33b1c20b29645a1946cd2acfbbdc24539e5ca 100644 (file)
@@ -268,12 +268,14 @@ async def test_copy_big_size_record(aconn):
 
 
 @pytest.mark.slow
-async def test_copy_big_size_block(aconn):
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+async def test_copy_big_size_block(aconn, pytype):
     cur = aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+    copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
     async with cur.copy("copy copy_in (data) from stdin") as copy:
-        await copy.write(data + "\n")
+        await copy.write(copy_data)
 
     await cur.execute("select data from copy_in limit 1")
     assert await cur.fetchone() == (data,)
@@ -467,14 +469,15 @@ async def test_copy_from_to(aconn):
 
 
 @pytest.mark.slow
-async def test_copy_from_to_bytes(aconn):
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+async def test_copy_from_to_bytes(aconn, pytype):
     # Roundtrip from file to database to file blockwise
     gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
     await gen.ensure_table()
     cur = aconn.cursor()
     async with cur.copy("copy copy_in from stdin") as copy:
         for block in gen.blocks():
-            await copy.write(block.encode())
+            await copy.write(pytype(block.encode()))
 
     await gen.assert_data()