From: Daniele Varrazzo Date: Sat, 26 Mar 2022 00:56:49 +0000 (+0100) Subject: feat(copy): allow bytearray/memoryview as copy.write() input X-Git-Tag: 3.1~159 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=94b92b5c453ec2ac173cf0b0b1237240285e862a;p=thirdparty%2Fpsycopg.git feat(copy): allow bytearray/memoryview as copy.write() input The C implementation can deal with these types efficiently, and accepting them directly may save a memory copy when that is the kind of data the user already has available. Close #254 --- diff --git a/docs/api/cursors.rst b/docs/api/cursors.rst index 74d9a1634..6c5e58e64 100644 --- a/docs/api/cursors.rst +++ b/docs/api/cursors.rst @@ -474,6 +474,11 @@ COPY-related objects see :ref:`adaptation` for details. .. automethod:: write + + .. versionchanged:: 3.1 + + accept `bytearray` and `memoryview` data as input too. + .. automethod:: read Instead of using `!read()` you can iterate on the `!Copy` object to diff --git a/docs/news.rst b/docs/news.rst index f6c92f49c..ef25cc804 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -19,6 +19,8 @@ Psycopg 3.1 (unreleased) - Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`). - Add *prepare_threshold* parameter to `Connection` init (:ticket:`#200`). - Add `Error.pgconn` and `Error.pgresult` attributes (:ticket:`#242`). +- Allow `bytearray`/`memoryview` data too as `Copy.write()` input + (:ticket:`#254`). - Drop support for Python 3.6. diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index b8c181896..abd7addae 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple from . import pq from . 
import errors as e from .pq import ExecStatus -from .abc import ConnectionType, PQGen, Transformer +from .abc import Buffer, ConnectionType, PQGen, Transformer from .adapt import PyFormat from ._compat import create_task from ._cmodule import _psycopg @@ -252,7 +252,7 @@ class Copy(BaseCopy["Connection[Any]"]): """ return self.connection.wait(self._read_row_gen()) - def write(self, buffer: Union[str, bytes]) -> None: + def write(self, buffer: Union[Buffer, str]) -> None: """ Write a block of data to a table after a :sql:`COPY FROM` operation. @@ -300,7 +300,7 @@ class Copy(BaseCopy["Connection[Any]"]): # Propagate the error to the main thread. self._worker_error = ex - def _write(self, data: bytes) -> None: + def _write(self, data: Buffer) -> None: if not data: return @@ -380,7 +380,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): async def read_row(self) -> Optional[Tuple[Any, ...]]: return await self.connection.wait(self._read_row_gen()) - async def write(self, buffer: Union[str, bytes]) -> None: + async def write(self, buffer: Union[Buffer, str]) -> None: data = self.formatter.write(buffer) await self._write(data) @@ -410,7 +410,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): break await self.connection.wait(copy_to(self._pgconn, data)) - async def _write(self, data: bytes) -> None: + async def _write(self, data: Buffer) -> None: if not data: return @@ -455,7 +455,7 @@ class Formatter(ABC): ... @abstractmethod - def write(self, buffer: Union[str, bytes]) -> bytes: + def write(self, buffer: Union[Buffer, str]) -> bytes: ... 
@abstractmethod @@ -481,7 +481,7 @@ class TextFormatter(Formatter): else: return None - def write(self, buffer: Union[str, bytes]) -> bytes: + def write(self, buffer: Union[Buffer, str]) -> Buffer: data = self._ensure_bytes(buffer) self._signature_sent = True return data @@ -502,15 +502,14 @@ class TextFormatter(Formatter): buffer, self._write_buffer = self._write_buffer, bytearray() return buffer - def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: - if isinstance(data, bytes): - return data - - elif isinstance(data, str): + def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: + if isinstance(data, str): return data.encode(self._encoding) - else: - raise TypeError(f"can't write {type(data).__name__}") + # Assume, for simplicity, that the user is not passing stupid + # things to the write function. If that's the case, things + # will fail downstream. + return data class BinaryFormatter(Formatter): @@ -535,7 +534,7 @@ class BinaryFormatter(Formatter): return parse_row_binary(data, self.transformer) - def write(self, buffer: Union[str, bytes]) -> bytes: + def write(self, buffer: Union[Buffer, str]) -> Buffer: data = self._ensure_bytes(buffer) self._signature_sent = True return data @@ -575,15 +574,14 @@ class BinaryFormatter(Formatter): buffer, self._write_buffer = self._write_buffer, bytearray() return buffer - def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: - if isinstance(data, bytes): - return data - - elif isinstance(data, str): + def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: + if isinstance(data, str): raise TypeError("cannot copy str data in binary mode: use bytes instead") - else: - raise TypeError(f"can't write {type(data).__name__}") + # Assume, for simplicity, that the user is not passing stupid + # things to the write function. If that's the case, things + # will fail downstream. 
+ return data def _format_row_text( diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py index a8acc3d85..33c607449 100644 --- a/psycopg/psycopg/pq/pq_ctypes.py +++ b/psycopg/psycopg/pq/pq_ctypes.py @@ -570,8 +570,7 @@ class PGconn: else: return None - def put_copy_data(self, buffer: bytes) -> int: - # TODO: should be done without copy + def put_copy_data(self, buffer: "abc.Buffer") -> int: if not isinstance(buffer, bytes): buffer = bytes(buffer) rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer)) diff --git a/tests/test_copy.py b/tests/test_copy.py index e506ad06e..52e3e968d 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -278,12 +278,14 @@ def test_copy_big_size_record(conn): @pytest.mark.slow -def test_copy_big_size_block(conn): +@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview]) +def test_copy_big_size_block(conn, pytype): cur = conn.cursor() ensure_table(cur, sample_tabledef) data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") with cur.copy("copy copy_in (data) from stdin") as copy: - copy.write(data + "\n") + copy.write(copy_data) cur.execute("select data from copy_in limit 1") assert cur.fetchone()[0] == data @@ -468,14 +470,15 @@ def test_copy_from_to(conn): @pytest.mark.slow -def test_copy_from_to_bytes(conn): +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview]) +def test_copy_from_to_bytes(conn, pytype): # Roundtrip from file to database to file blockwise gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024) gen.ensure_table() cur = conn.cursor() with cur.copy("copy copy_in from stdin") as copy: for block in gen.blocks(): - copy.write(block.encode()) + copy.write(pytype(block.encode())) gen.assert_data() diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 271e92105..6fc33b1c2 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ 
-268,12 +268,14 @@ async def test_copy_big_size_record(aconn): @pytest.mark.slow -async def test_copy_big_size_block(aconn): +@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview]) +async def test_copy_big_size_block(aconn, pytype): cur = aconn.cursor() await ensure_table(cur, sample_tabledef) data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") async with cur.copy("copy copy_in (data) from stdin") as copy: - await copy.write(data + "\n") + await copy.write(copy_data) await cur.execute("select data from copy_in limit 1") assert await cur.fetchone() == (data,) @@ -467,14 +469,15 @@ async def test_copy_from_to(aconn): @pytest.mark.slow -async def test_copy_from_to_bytes(aconn): +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview]) +async def test_copy_from_to_bytes(aconn, pytype): # Roundtrip from file to database to file blockwise gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024) await gen.ensure_table() cur = aconn.cursor() async with cur.copy("copy copy_in from stdin") as copy: for block in gen.blocks(): - await copy.write(block.encode()) + await copy.write(pytype(block.encode())) await gen.assert_data()