]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added wrapper for libpq function PQgetCopyData
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 23 Jun 2020 08:28:27 +0000 (20:28 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 23 Jun 2020 10:24:12 +0000 (22:24 +1200)
psycopg3/pq/_pq_ctypes.py
psycopg3/pq/_pq_ctypes.pyi
psycopg3/pq/libpq.pxd
psycopg3/pq/pq_ctypes.py
psycopg3/pq/pq_cython.pyx
psycopg3/pq/proto.py
tests/pq/test_copy.py

index 86cf35ccb34adbb963f8f19d461446f9436f715f..79e857cf9c9853a21b22780e942b200fd18c91e8 100644 (file)
@@ -510,6 +510,10 @@ PQputCopyEnd = pq.PQputCopyEnd
 PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
 PQputCopyEnd.restype = c_int
 
+PQgetCopyData = pq.PQgetCopyData
+PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int]
+PQgetCopyData.restype = c_int
+
 
 # 33.11. Miscellaneous Functions
 
index d10d64e23bb398b734993e91a3e7c7d0eacfcf67..e3d223a5a8ceba725c297b340081f182e071e32d 100644 (file)
@@ -92,8 +92,15 @@ def PQnotifies(
 def PQputCopyEnd(
     arg1: Optional[PGconn_struct], arg2: Optional[bytes]
 ) -> int: ...
+
+# Arg 2 is a pointer, reported as _CArgObject by mypy
+def PQgetCopyData(
+    arg1: Optional[PGconn_struct], arg2: Any, arg3: int
+) -> int: ...
 def PQsetResultAttrs(
-    arg1: Optional[PGresult_struct], arg2: int, arg3: Array[PGresAttDesc_struct]  # type: ignore
+    arg1: Optional[PGresult_struct],
+    arg2: int,
+    arg3: Array[PGresAttDesc_struct],  # type: ignore
 ) -> int: ...
 
 # fmt: off
index af0a753f39bc034d7e87c03b74cda3f803821f6d..7e0f4f4bc5ec60de01fe29d1859acec3a5d94225 100644 (file)
@@ -236,6 +236,7 @@ cdef extern from "libpq-fe.h":
     # 33.9. Functions Associated with the COPY Command
     int PQputCopyData(PGconn *conn, const char *buffer, int nbytes)
     int PQputCopyEnd(PGconn *conn, const char *errormsg)
+    int PQgetCopyData(PGconn *conn, char **buffer, int async)
 
     # 33.11. Miscellaneous Functions
     void PQfreemem(void *ptr)
index 35022bc6064fa4c04b72f3d6c15bfabc633fe618..b85eaa7a846f51a45f5835b192a2b5b0493c9283 100644 (file)
@@ -13,9 +13,9 @@ import logging
 from weakref import ref
 from functools import partial
 
-from ctypes import Array, pointer, string_at, create_string_buffer
+from ctypes import Array, pointer, string_at, create_string_buffer, byref
 from ctypes import c_char_p, c_int, c_size_t, c_ulong
-from typing import Any, Callable, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence, Tuple
 from typing import cast as t_cast, TYPE_CHECKING
 
 from .enums import (
@@ -512,6 +512,19 @@ class PGconn:
             raise PQerror(f"sending copy end failed: {error_message(self)}")
         return rv
 
+    def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]:
+        buffer_ptr = c_char_p()
+        nbytes = impl.PQgetCopyData(self.pgconn_ptr, byref(buffer_ptr), async_)
+        if nbytes == -2:
+            raise PQerror(f"receiving copy data failed: {error_message(self)}")
+        if buffer_ptr:
+            # TODO: do it without copy
+            data = string_at(buffer_ptr, nbytes)
+            impl.PQfreemem(buffer_ptr)
+            return nbytes, data
+        else:
+            return nbytes, None
+
     def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
         rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status)
         if not rv:
index 0bebb044abcf8a947a1f79e130b303625abadedd..903e332d68afb454882a298a2284b2fa7d04b3bf 100644 (file)
@@ -9,7 +9,7 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free
 from cpython.bytes cimport PyBytes_AsString
 
 import logging
-from typing import List, Optional, Sequence
+from typing import List, Optional, Sequence, Tuple
 
 from psycopg3.pq cimport libpq as impl
 from psycopg3.pq.libpq cimport Oid
@@ -448,6 +448,20 @@ cdef class PGconn:
             raise PQerror(f"sending copy end failed: {error_message(self)}")
         return rv
 
+    def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]:
+        cdef char *buffer_ptr = NULL
+        cdef int nbytes
+        nbytes = impl.PQgetCopyData(self.pgconn_ptr, &buffer_ptr, async_)
+        if nbytes == -2:
+            raise PQerror(f"receiving copy data failed: {error_message(self)}")
+        if buffer_ptr is not NULL:
+            # TODO: do it without copy
+            data = buffer_ptr[:nbytes]
+            impl.PQfreemem(buffer_ptr)
+            return nbytes, data
+        else:
+            return nbytes, None
+
     def make_empty_result(self, exec_status: ExecStatus) -> PGresult:
         cdef impl.PGresult *rv = impl.PQmakeEmptyPGresult(
             self.pgconn_ptr, exec_status)
index bd6aab19061557080958621f5e5c81b79dce9ceb..ad8d108c7414b3b8c82267041b475abf6a6e4545 100644 (file)
@@ -4,7 +4,8 @@ Protocol objects to represent objects exposed by different pq implementations.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING
 from typing_extensions import Protocol
 
 from .enums import (
@@ -230,6 +231,9 @@ class PGconn(Protocol):
     def put_copy_end(self, error: Optional[bytes] = None) -> int:
         ...
 
+    def get_copy_data(self, async_: int) -> Tuple[int, Optional[bytes]]:
+        ...
+
     def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
         ...
 
index ddcd3534c15f731c5642b65af2020ab11612cfc3..db0c641d6042c9cdf61ce8a1a2683c64147a8a2f 100644 (file)
@@ -2,8 +2,31 @@ import pytest
 
 from psycopg3 import pq
 
+sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
+
 sample_tabledef = "col1 int primary key, col2 int, data text"
 
+sample_text = b"""\
+10\t20\thello
+40\t\\N\tworld
+"""
+
+sample_binary = """
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
+"""
+
+sample_binary_rows = [
+    bytes.fromhex("".join(row.split())) for row in sample_binary.split("\n\n")
+]
+
+sample_binary = b"".join(sample_binary_rows)
+
 
 def test_put_data_no_copy(pgconn):
     with pytest.raises(pq.PQerror):
@@ -111,6 +134,38 @@ def test_copy_out_error_end(pgconn):
     assert res.get_value(0, 0) == b"0"
 
 
+def test_get_data_no_copy(pgconn):
+    with pytest.raises(pq.PQerror):
+        pgconn.get_copy_data(0)
+
+    pgconn.finish()
+    with pytest.raises(pq.PQerror):
+        pgconn.get_copy_data(0)
+
+
+@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY])
+def test_copy_out_read(pgconn, format):
+    stmt = f"copy ({sample_values}) to stdout (format {format.name})"
+    res = pgconn.exec_(stmt.encode("ascii"))
+    assert res.status == pq.ExecStatus.COPY_OUT
+    assert res.binary_tuples == format
+
+    if format == pq.Format.TEXT:
+        want = [row + b"\n" for row in sample_text.splitlines()]
+    else:
+        want = sample_binary_rows
+
+    for row in want:
+        nbytes, data = pgconn.get_copy_data(0)
+        assert nbytes == len(data)
+        assert data == row
+
+    assert pgconn.get_copy_data(0) == (-1, None)
+
+    res = pgconn.get_result()
+    assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+
 def ensure_table(pgconn, tabledef, name="copy_in"):
     pgconn.exec_(f"drop table if exists {name}".encode("ascii"))
     pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii"))