]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added basic copy to server in blocks
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 22 Jun 2020 07:32:34 +0000 (19:32 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 22 Jun 2020 09:43:41 +0000 (21:43 +1200)
12 files changed:
psycopg3/copy.py
psycopg3/generators.py
psycopg3/pq/_pq_ctypes.py
psycopg3/pq/_pq_ctypes.pyi
psycopg3/pq/libpq.pxd
psycopg3/pq/pq_ctypes.py
psycopg3/pq/pq_cython.pyx
psycopg3/pq/proto.py
psycopg3/waiting.py
tests/pq/test_copy.py [new file with mode: 0644]
tests/test_copy.py
tests/test_copy_async.py [new file with mode: 0644]

index 2dccc8262bfcc5130542bebc6e8311495bf50f95..8daf4158f246571499798250dc7dec687035d171 100644 (file)
@@ -5,12 +5,17 @@ psycopg3 copy support
 # Copyright (C) 2020 The Psycopg Team
 
 import re
+from typing import cast, TYPE_CHECKING
 from typing import Any, Deque, Dict, List, Match, Optional, Tuple
 from collections import deque
 
-from .proto import AdaptContext
-from . import errors as e
 from . import pq
+from . import errors as e
+from .proto import AdaptContext
+from .generators import copy_to, copy_end
+
+if TYPE_CHECKING:
+    from .connection import Connection, AsyncConnection
 
 
 class BaseCopy:
@@ -111,8 +116,72 @@ _bsrepl_re = re.compile(rb"\\(.)")
 
 
 class Copy(BaseCopy):
-    pass
+    def __init__(
+        self,
+        context: AdaptContext,
+        result: Optional[pq.proto.PGresult],
+        format: pq.Format = pq.Format.TEXT,
+    ):
+        super().__init__(context=context, result=result, format=format)
+        self._connection: Optional["Connection"] = None
+
+    @property
+    def connection(self) -> "Connection":
+        if self._connection is None:
+            conn = self._transformer.connection
+            if conn is None:
+                raise ValueError("no connection available")
+            self._connection = cast("Connection", conn)
+
+        return self._connection
+
+    def write(self, buffer: bytes) -> None:
+        conn = self.connection
+        conn.wait(copy_to(conn.pgconn, buffer))
+
+    def finish(self, error: Optional[str] = None) -> None:
+        conn = self.connection
+        berr = (
+            conn.codec.encode(error, "replace")[0]
+            if error is not None
+            else None
+        )
+        result = conn.wait(copy_end(conn.pgconn, berr))
+        if result.status != pq.ExecStatus.COMMAND_OK:
+            raise e.error_from_result(
+                result, encoding=self.connection.codec.name
+            )
 
 
 class AsyncCopy(BaseCopy):
-    pass
+    def __init__(
+        self,
+        context: AdaptContext,
+        result: Optional[pq.proto.PGresult],
+        format: pq.Format = pq.Format.TEXT,
+    ):
+        super().__init__(context=context, result=result, format=format)
+        self._connection: Optional["AsyncConnection"] = None
+
+    @property
+    def connection(self) -> "AsyncConnection":
+        if self._connection is None:
+            conn = self._transformer.connection
+            if conn is None:
+                raise ValueError("no connection available")
+            self._connection = cast("AsyncConnection", conn)
+
+        return self._connection
+
+    async def write(self, buffer: bytes) -> None:
+        conn = self.connection
+        await conn.wait(copy_to(conn.pgconn, buffer))
+
+    async def finish(self, error: Optional[str] = None) -> None:
+        conn = self.connection
+        berr = (
+            conn.codec.encode(error, "replace")[0]
+            if error is not None
+            else None
+        )
+        await conn.wait(copy_end(conn.pgconn, berr))
index 3fd82e6beda0c0127004ec59e284fcf19bc476c2..b4269a84afc59b2e2851152921da8adde8e28e0b 100644 (file)
@@ -16,7 +16,7 @@ when the file descriptor is ready.
 # Copyright (C) 2020 The Psycopg Team
 
 import logging
-from typing import List
+from typing import List, Optional
 
 from . import pq
 from . import errors as e
@@ -149,3 +149,33 @@ def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]:
             break
 
     return ns
+
+
+def copy_to(pgconn: pq.proto.PGconn, buffer: bytes) -> PQGen[None]:
+    # Retry enqueuing data until successful
+    while pgconn.put_copy_data(buffer) == 0:
+        yield pgconn.socket, Wait.W
+
+
+def copy_end(
+    pgconn: pq.proto.PGconn, error: Optional[bytes]
+) -> PQGen[pq.proto.PGresult]:
+    # Retry enqueuing end copy message until successful
+    while pgconn.put_copy_end(error) == 0:
+        yield pgconn.socket, Wait.W
+
+    # Repeat until it the message is flushed to the server
+    while 1:
+        yield pgconn.socket, Wait.W
+        f = pgconn.flush()
+        if f == 0:
+            break
+
+    # Retrieve the final result of copy
+    results = yield from fetch(pgconn)
+    if len(results) == 1:
+        return results[0]
+    else:
+        raise e.InternalError(
+            f"1 result expected from copy end, got {len(results)}"
+        )
index 1ab7ed4c24a85d933b71924a8d845c69fcc8bf7a..86cf35ccb34adbb963f8f19d461446f9436f715f 100644 (file)
@@ -500,6 +500,17 @@ PQnotifies.argtypes = [PGconn_ptr]
 PQnotifies.restype = PGnotify_ptr
 
 
+# 33.9. Functions Associated with the COPY Command
+
+PQputCopyData = pq.PQputCopyData
+PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int]
+PQputCopyData.restype = c_int
+
+PQputCopyEnd = pq.PQputCopyEnd
+PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
+PQputCopyEnd.restype = c_int
+
+
 # 33.11. Miscellaneous Functions
 
 PQfreemem = pq.PQfreemem
index 5a26f3f30d2ec980b154da59ba6a4981d44d3bb5..d10d64e23bb398b734993e91a3e7c7d0eacfcf67 100644 (file)
@@ -89,6 +89,9 @@ def PQsetNoticeReceiver(
 def PQnotifies(
     arg1: Optional[PGconn_struct],
 ) -> Optional[pointer[PGnotify_struct]]: ...  # type: ignore
+def PQputCopyEnd(
+    arg1: Optional[PGconn_struct], arg2: Optional[bytes]
+) -> int: ...
 def PQsetResultAttrs(
     arg1: Optional[PGresult_struct], arg2: int, arg3: Array[PGresAttDesc_struct]  # type: ignore
 ) -> int: ...
@@ -162,6 +165,7 @@ def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
 def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
 def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
 def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
+def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ...
 def PQfreemem(arg1: Any) -> None: ...
 def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
 # autogenerated: end
index 1ab1595e29854d04266c158bc732f7d086648efa..af0a753f39bc034d7e87c03b74cda3f803821f6d 100644 (file)
@@ -233,6 +233,10 @@ cdef extern from "libpq-fe.h":
     # 33.8. Asynchronous Notification
     PGnotify *PQnotifies(PGconn *conn)
 
+    # 33.9. Functions Associated with the COPY Command
+    int PQputCopyData(PGconn *conn, const char *buffer, int nbytes)
+    int PQputCopyEnd(PGconn *conn, const char *errormsg)
+
     # 33.11. Miscellaneous Functions
     void PQfreemem(void *ptr)
     void PQconninfoFree(PQconninfoOption *connOptions)
index 6e6cd3ca0b3303eb78cd508734d1d960cf66a201..35022bc6064fa4c04b72f3d6c15bfabc633fe618 100644 (file)
@@ -500,6 +500,18 @@ class PGconn:
         else:
             return None
 
+    def put_copy_data(self, buffer: bytes) -> int:
+        rv = impl.PQputCopyData(self.pgconn_ptr, buffer, len(buffer))
+        if rv < 0:
+            raise PQerror(f"sending copy data failed: {error_message(self)}")
+        return rv
+
+    def put_copy_end(self, error: Optional[bytes] = None) -> int:
+        rv = impl.PQputCopyEnd(self.pgconn_ptr, error)
+        if rv < 0:
+            raise PQerror(f"sending copy end failed: {error_message(self)}")
+        return rv
+
     def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
         rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status)
         if not rv:
index f228c106464352c734dec173eae242c53a464363..0bebb044abcf8a947a1f79e130b303625abadedd 100644 (file)
@@ -6,6 +6,7 @@ libpq Python wrapper using cython bindings.
 
 from posix.unistd cimport getpid
 from cpython.mem cimport PyMem_Malloc, PyMem_Free
+from cpython.bytes cimport PyBytes_AsString
 
 import logging
 from typing import List, Optional, Sequence
@@ -213,10 +214,7 @@ cdef class PGconn:
     def send_query(self, command: bytes) -> None:
         self._ensure_pgconn()
         if not impl.PQsendQuery(self.pgconn_ptr, command):
-            raise PQerror(
-                "sending query failed:"
-                f" {error_message(self)}"
-            )
+            raise PQerror(f"sending query failed: {error_message(self)}")
 
     def exec_params(
         self,
@@ -269,8 +267,7 @@ cdef class PGconn:
         _clear_query_params(ctypes, cvalues, clengths, cformats)
         if not rv:
             raise PQerror(
-                "sending query and params failed:"
-                f" {error_message(self)}"
+                f"sending query and params failed: {error_message(self)}"
             )
 
     def send_prepare(
@@ -295,8 +292,7 @@ cdef class PGconn:
         PyMem_Free(atypes)
         if not rv:
             raise PQerror(
-                "sending query and params failed:"
-                f" {error_message(self)}"
+                f"sending query and params failed: {error_message(self)}"
             )
 
     def send_query_prepared(
@@ -323,8 +319,7 @@ cdef class PGconn:
         _clear_query_params(ctypes, cvalues, clengths, cformats)
         if not rv:
             raise PQerror(
-                "sending prepared query failed:"
-                f" {error_message(self)}"
+                f"sending prepared query failed: {error_message(self)}"
             )
 
     def prepare(
@@ -399,10 +394,7 @@ cdef class PGconn:
 
     def consume_input(self) -> None:
         if 1 != impl.PQconsumeInput(self.pgconn_ptr):
-            raise PQerror(
-                "consuming input failed:"
-                f" {error_message(self)}"
-            )
+            raise PQerror(f"consuming input failed: {error_message(self)}")
 
     def is_busy(self) -> int:
         return impl.PQisBusy(self.pgconn_ptr)
@@ -414,17 +406,12 @@ cdef class PGconn:
     @nonblocking.setter
     def nonblocking(self, arg: int) -> None:
         if 0 > impl.PQsetnonblocking(self.pgconn_ptr, arg):
-            raise PQerror(
-                f"setting nonblocking failed:"
-                f" {error_message(self)}"
-            )
+            raise PQerror(f"setting nonblocking failed: {error_message(self)}")
 
     def flush(self) -> int:
         cdef int rv = impl.PQflush(self.pgconn_ptr)
         if rv < 0:
-            raise PQerror(
-                f"flushing failed:{error_message(self)}"
-            )
+            raise PQerror(f"flushing failed:{error_message(self)}")
         return rv
 
     def get_cancel(self) -> PGcancel:
@@ -442,6 +429,25 @@ cdef class PGconn:
         else:
             return None
 
+    def put_copy_data(self, buffer: bytes) -> int:
+        cdef int rv
+        cdef const char *cbuffer = PyBytes_AsString(buffer)
+        cdef int length = len(buffer)
+        rv = impl.PQputCopyData(self.pgconn_ptr, cbuffer, length)
+        if rv < 0:
+            raise PQerror(f"sending copy data failed: {error_message(self)}")
+        return rv
+
+    def put_copy_end(self, error: Optional[bytes] = None) -> int:
+        cdef int rv
+        cdef const char *cerr = NULL
+        if error is not None:
+            cerr = PyBytes_AsString(error)
+        rv = impl.PQputCopyEnd(self.pgconn_ptr, cerr)
+        if rv < 0:
+            raise PQerror(f"sending copy end failed: {error_message(self)}")
+        return rv
+
     def make_empty_result(self, exec_status: ExecStatus) -> PGresult:
         cdef impl.PGresult *rv = impl.PQmakeEmptyPGresult(
             self.pgconn_ptr, exec_status)
index dac21ccdc1e9d4c9b5405d35314eb1ed67e4e9cf..bd6aab19061557080958621f5e5c81b79dce9ceb 100644 (file)
@@ -224,6 +224,12 @@ class PGconn(Protocol):
     def notifies(self) -> Optional["PGnotify"]:
         ...
 
+    def put_copy_data(self, buffer: bytes) -> int:
+        ...
+
+    def put_copy_end(self, error: Optional[bytes] = None) -> int:
+        ...
+
     def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
         ...
 
index 7ce2f54e152e425ee395bc001948367f1b8acca7..67ac85280779e772afbfe813c7a3e15303180e4f 100644 (file)
@@ -29,7 +29,7 @@ class Ready(IntEnum):
     W = EVENT_WRITE
 
 
-def wait(gen: "PQGen[RV]", timeout: Optional[float] = None) -> "RV":
+def wait(gen: PQGen[RV], timeout: Optional[float] = None) -> RV:
     """
     Wait for a generator using the best option available on the platform.
 
@@ -54,11 +54,11 @@ def wait(gen: "PQGen[RV]", timeout: Optional[float] = None) -> "RV":
             fd, s = gen.send(ready[0][1])
 
     except StopIteration as ex:
-        rv: "RV" = ex.args[0]
+        rv: RV = ex.args[0] if ex.args else None
         return rv
 
 
-async def wait_async(gen: "PQGen[RV]") -> "RV":
+async def wait_async(gen: PQGen[RV]) -> RV:
     """
     Coroutine waiting for a generator to complete.
 
@@ -102,5 +102,5 @@ async def wait_async(gen: "PQGen[RV]") -> "RV":
             fd, s = gen.send(ready)
 
     except StopIteration as ex:
-        rv: "RV" = ex.args[0]
+        rv: RV = ex.args[0] if ex.args else None
         return rv
diff --git a/tests/pq/test_copy.py b/tests/pq/test_copy.py
new file mode 100644 (file)
index 0000000..ddcd353
--- /dev/null
@@ -0,0 +1,116 @@
+import pytest
+
+from psycopg3 import pq
+
+sample_tabledef = "col1 int primary key, col2 int, data text"
+
+
+def test_put_data_no_copy(pgconn):
+    with pytest.raises(pq.PQerror):
+        pgconn.put_copy_data(b"wat")
+
+    pgconn.finish()
+    with pytest.raises(pq.PQerror):
+        pgconn.put_copy_data(b"wat")
+
+
+def test_put_end_no_copy(pgconn):
+    with pytest.raises(pq.PQerror):
+        pgconn.put_copy_end()
+
+    pgconn.finish()
+    with pytest.raises(pq.PQerror):
+        pgconn.put_copy_end()
+
+
+def test_copy_out(pgconn):
+    ensure_table(pgconn, sample_tabledef)
+    res = pgconn.exec_(b"copy copy_in from stdin")
+    assert res.status == pq.ExecStatus.COPY_IN
+
+    for i in range(10):
+        data = []
+        for j in range(20):
+            data.append(
+                f"""\
+{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)}
+"""
+            )
+        rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+        assert rv > 0
+
+    rv = pgconn.put_copy_end()
+    assert rv > 0
+
+    res = pgconn.get_result()
+    assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+    res = pgconn.exec_(
+        b"select min(col1), max(col1), count(*), max(length(data)) from copy_in"
+    )
+    assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+    assert res.get_value(0, 0) == b"0"
+    assert res.get_value(0, 1) == b"199"
+    assert res.get_value(0, 2) == b"200"
+    assert res.get_value(0, 3) == b"199"
+
+
+def test_copy_out_err(pgconn):
+    ensure_table(pgconn, sample_tabledef)
+    res = pgconn.exec_(b"copy copy_in from stdin")
+    assert res.status == pq.ExecStatus.COPY_IN
+
+    for i in range(10):
+        data = []
+        for j in range(20):
+            data.append(
+                f"""\
+{i * 20 + j}\thardly a number\tnope
+"""
+            )
+        rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+        assert rv > 0
+
+    rv = pgconn.put_copy_end()
+    assert rv > 0
+
+    res = pgconn.get_result()
+    assert res.status == pq.ExecStatus.FATAL_ERROR
+    assert b"hardly a number" in res.error_message
+
+    res = pgconn.exec_(b"select count(*) from copy_in")
+    assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+    assert res.get_value(0, 0) == b"0"
+
+
+def test_copy_out_error_end(pgconn):
+    ensure_table(pgconn, sample_tabledef)
+    res = pgconn.exec_(b"copy copy_in from stdin")
+    assert res.status == pq.ExecStatus.COPY_IN
+
+    for i in range(10):
+        data = []
+        for j in range(20):
+            data.append(
+                f"""\
+{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)}
+"""
+            )
+        rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+        assert rv > 0
+
+    rv = pgconn.put_copy_end(b"nuttengoggenio")
+    assert rv > 0
+
+    res = pgconn.get_result()
+    assert res.status == pq.ExecStatus.FATAL_ERROR
+    assert b"nuttengoggenio" in res.error_message
+
+    res = pgconn.exec_(b"select count(*) from copy_in")
+    assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+    assert res.get_value(0, 0) == b"0"
+
+
+def ensure_table(pgconn, tabledef, name="copy_in"):
+    pgconn.exec_(f"drop table if exists {name}".encode("ascii"))
+    pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii"))
index 1e68b61d0a60682603716d8a5b403152d37d0e83..49635fc4084fc5497e9b5226535d2bd3a1024465 100644 (file)
@@ -23,6 +23,7 @@ sample_binary = """
 0000 0004 0000 0028 ffff ffff 0000 0005
 776f 726c 64ff ff
 """
+sample_binary = bytes.fromhex("".join(sample_binary.split()))
 
 
 def set_sample_attributes(res, format):
@@ -34,6 +35,7 @@ def set_sample_attributes(res, format):
     res.set_attributes(attrs)
 
 
+@pytest.mark.xfail
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -46,6 +48,7 @@ def test_load_noinfo(conn, format, buffer):
     assert records == as_bytes(sample_records)
 
 
+@pytest.mark.xfail
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -61,6 +64,7 @@ def test_load(conn, format, buffer):
     assert records == sample_records
 
 
+@pytest.mark.xfail
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -79,6 +83,7 @@ def test_dump(conn, format, buffer):
     assert copy.get_buffer() is None
 
 
+@pytest.mark.xfail
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -90,6 +95,7 @@ def test_buffers(format, buffer):
     assert list(copy.buffers(sample_records)) == [globals()[buffer]]
 
 
+@pytest.mark.xfail
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -102,6 +108,7 @@ def test_copy_out_read(conn, format, buffer):
     assert copy.read() is None
 
 
+@pytest.mark.xfail
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 def test_iter(conn, format):
     cur = conn.cursor()
@@ -118,11 +125,12 @@ def test_copy_in_buffers(conn, format, buffer):
     ensure_table(cur, sample_tabledef)
     copy = cur.copy(f"copy copy_in from stdin (format {format.name})")
     copy.write(globals()[buffer])
-    copy.end()
+    copy.finish()
     data = cur.execute("select * from copy_in order by 1").fetchall()
     assert data == sample_records
 
 
+@pytest.mark.xfail
 @pytest.mark.parametrize(
     "format, buffer",
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
@@ -137,6 +145,7 @@ def test_copy_in_buffers_with(conn, format, buffer):
     assert data == sample_records
 
 
+@pytest.mark.xfail
 @pytest.mark.parametrize(
     "format", [(Format.TEXT,), (Format.BINARY,)],
 )
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py
new file mode 100644 (file)
index 0000000..f6b414c
--- /dev/null
@@ -0,0 +1,43 @@
+import pytest
+
+from psycopg3.adapt import Format
+
+from .test_copy import sample_text, sample_binary  # noqa
+from .test_copy import sample_values, sample_records, sample_tabledef
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.xfail
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_out_read(aconn, format, buffer):
+    cur = aconn.cursor()
+    copy = await cur.copy(
+        f"copy ({sample_values}) to stdout (format {format.name})"
+    )
+    assert await copy.read() == globals()[buffer]
+    assert await copy.read() is None
+    assert await copy.read() is None
+
+
+@pytest.mark.parametrize(
+    "format, buffer",
+    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_in_buffers(aconn, format, buffer):
+    cur = aconn.cursor()
+    await ensure_table(cur, sample_tabledef)
+    copy = await cur.copy(f"copy copy_in from stdin (format {format.name})")
+    await copy.write(globals()[buffer])
+    await copy.finish()
+    await cur.execute("select * from copy_in order by 1")
+    data = await cur.fetchall()
+    assert data == sample_records
+
+
+async def ensure_table(cur, tabledef, name="copy_in"):
+    await cur.execute(f"drop table if exists {name}")
+    await cur.execute(f"create table {name} ({tabledef})")