From: Daniele Varrazzo Date: Tue, 23 Jun 2020 08:28:27 +0000 (+1200) Subject: Added wrapper for libpq function PQgetCopyData X-Git-Tag: 3.0.dev0~480 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=066fa265359a4a279c66d956130ac91dd6ad2628;p=thirdparty%2Fpsycopg.git Added wrapper for libpq function PQgetCopyData --- diff --git a/psycopg3/pq/_pq_ctypes.py b/psycopg3/pq/_pq_ctypes.py index 86cf35ccb..79e857cf9 100644 --- a/psycopg3/pq/_pq_ctypes.py +++ b/psycopg3/pq/_pq_ctypes.py @@ -510,6 +510,10 @@ PQputCopyEnd = pq.PQputCopyEnd PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p] PQputCopyEnd.restype = c_int +PQgetCopyData = pq.PQgetCopyData +PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int] +PQgetCopyData.restype = c_int + # 33.11. Miscellaneous Functions diff --git a/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/pq/_pq_ctypes.pyi index d10d64e23..e3d223a5a 100644 --- a/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/pq/_pq_ctypes.pyi @@ -92,8 +92,15 @@ def PQnotifies( def PQputCopyEnd( arg1: Optional[PGconn_struct], arg2: Optional[bytes] ) -> int: ... + +# Arg 2 is a pointer, reported as _CArgObject by mypy +def PQgetCopyData( + arg1: Optional[PGconn_struct], arg2: Any, arg3: int +) -> int: ... def PQsetResultAttrs( - arg1: Optional[PGresult_struct], arg2: int, arg3: Array[PGresAttDesc_struct] # type: ignore + arg1: Optional[PGresult_struct], + arg2: int, + arg3: Array[PGresAttDesc_struct], # type: ignore ) -> int: ... # fmt: off diff --git a/psycopg3/pq/libpq.pxd b/psycopg3/pq/libpq.pxd index af0a753f3..7e0f4f4bc 100644 --- a/psycopg3/pq/libpq.pxd +++ b/psycopg3/pq/libpq.pxd @@ -236,6 +236,7 @@ cdef extern from "libpq-fe.h": # 33.9. Functions Associated with the COPY Command int PQputCopyData(PGconn *conn, const char *buffer, int nbytes) int PQputCopyEnd(PGconn *conn, const char *errormsg) + int PQgetCopyData(PGconn *conn, char **buffer, int async) # 33.11. Miscellaneous Functions void PQfreemem(void *ptr) diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index 35022bc60..b85eaa7a8 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -13,9 +13,9 @@ import logging from weakref import ref from functools import partial -from ctypes import Array, pointer, string_at, create_string_buffer +from ctypes import Array, pointer, string_at, create_string_buffer, byref from ctypes import c_char_p, c_int, c_size_t, c_ulong -from typing import Any, Callable, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence, Tuple from typing import cast as t_cast, TYPE_CHECKING from .enums import ( @@ -512,6 +512,19 @@ class PGconn: raise PQerror(f"sending copy end failed: {error_message(self)}") return rv + def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]: + buffer_ptr = c_char_p() + nbytes = impl.PQgetCopyData(self.pgconn_ptr, byref(buffer_ptr), async_) + if nbytes == -2: + raise PQerror(f"receiving copy data failed: {error_message(self)}") + if buffer_ptr: + # TODO: do it without copy + data = string_at(buffer_ptr, nbytes) + impl.PQfreemem(buffer_ptr) + return nbytes, data + else: + return nbytes, None + def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status) if not rv: diff --git a/psycopg3/pq/pq_cython.pyx b/psycopg3/pq/pq_cython.pyx index 0bebb044a..903e332d6 100644 --- a/psycopg3/pq/pq_cython.pyx +++ b/psycopg3/pq/pq_cython.pyx @@ -9,7 +9,7 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free from cpython.bytes cimport PyBytes_AsString import logging -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Tuple from psycopg3.pq cimport libpq as impl from psycopg3.pq.libpq cimport Oid @@ -448,6 +448,20 @@ cdef class PGconn: raise PQerror(f"sending copy end failed: {error_message(self)}") return rv + def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]: + cdef char *buffer_ptr = NULL + cdef int nbytes + nbytes = impl.PQgetCopyData(self.pgconn_ptr, &buffer_ptr, async_) + if nbytes == -2: + raise PQerror(f"receiving copy data failed: {error_message(self)}") + if buffer_ptr is not NULL: + # TODO: do it without copy + data = buffer_ptr[:nbytes] + impl.PQfreemem(buffer_ptr) + return nbytes, data + else: + return nbytes, None + def make_empty_result(self, exec_status: ExecStatus) -> PGresult: cdef impl.PGresult *rv = impl.PQmakeEmptyPGresult( self.pgconn_ptr, exec_status) diff --git a/psycopg3/pq/proto.py b/psycopg3/pq/proto.py index bd6aab190..ad8d108c7 100644 --- a/psycopg3/pq/proto.py +++ b/psycopg3/pq/proto.py @@ -4,7 +4,8 @@ Protocol objects to represent objects exposed by different pq implementations. # Copyright (C) 2020 The Psycopg Team -from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING +from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING from typing_extensions import Protocol from .enums import ( @@ -230,6 +231,9 @@ class PGconn(Protocol): def put_copy_end(self, error: Optional[bytes] = None) -> int: ... + def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]: + ... + def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": ... diff --git a/tests/pq/test_copy.py b/tests/pq/test_copy.py index ddcd3534c..db0c641d6 100644 --- a/tests/pq/test_copy.py +++ b/tests/pq/test_copy.py @@ -2,8 +2,31 @@ import pytest from psycopg3 import pq +sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')" + sample_tabledef = "col1 int primary key, col2 int, data text" +sample_text = b"""\ +10\t20\thello +40\t\\N\tworld +""" + +sample_binary = """ +5047 434f 5059 0aff 0d0a 00 +00 0000 0000 0000 00 +00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f + +0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64 + +ff ff +""" + +sample_binary_rows = [ + bytes.fromhex("".join(row.split())) for row in sample_binary.split("\n\n") +] + +sample_binary = b"".join(sample_binary_rows) + def test_put_data_no_copy(pgconn): with pytest.raises(pq.PQerror): @@ -111,6 +134,38 @@ def test_copy_out_error_end(pgconn): assert res.get_value(0, 0) == b"0" +def test_get_data_no_copy(pgconn): + with pytest.raises(pq.PQerror): + pgconn.get_copy_data(0) + + pgconn.finish() + with pytest.raises(pq.PQerror): + pgconn.get_copy_data(0) + + +@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY]) +def test_copy_out_read(pgconn, format): + stmt = f"copy ({sample_values}) to stdout (format {format.name})" + res = pgconn.exec_(stmt.encode("ascii")) + assert res.status == pq.ExecStatus.COPY_OUT + assert res.binary_tuples == format + + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + for row in want: + nbytes, data = pgconn.get_copy_data(0) + assert nbytes == len(data) + assert data == row + + assert pgconn.get_copy_data(0) == (-1, None) + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + def ensure_table(pgconn, tabledef, name="copy_in"): pgconn.exec_(f"drop table if exists {name}".encode("ascii")) pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii"))