From: Daniele Varrazzo Date: Mon, 22 Jun 2020 07:32:34 +0000 (+1200) Subject: Added basic copy to server in blocks X-Git-Tag: 3.0.dev0~482 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ff66ca3b29f4c929c9673f7bde1c5635ffae860e;p=thirdparty%2Fpsycopg.git Added basic copy to server in blocks --- diff --git a/psycopg3/copy.py b/psycopg3/copy.py index 2dccc8262..8daf4158f 100644 --- a/psycopg3/copy.py +++ b/psycopg3/copy.py @@ -5,12 +5,17 @@ psycopg3 copy support # Copyright (C) 2020 The Psycopg Team import re +from typing import cast, TYPE_CHECKING from typing import Any, Deque, Dict, List, Match, Optional, Tuple from collections import deque -from .proto import AdaptContext -from . import errors as e from . import pq +from . import errors as e +from .proto import AdaptContext +from .generators import copy_to, copy_end + +if TYPE_CHECKING: + from .connection import Connection, AsyncConnection class BaseCopy: @@ -111,8 +116,72 @@ _bsrepl_re = re.compile(rb"\\(.)") class Copy(BaseCopy): - pass + def __init__( + self, + context: AdaptContext, + result: Optional[pq.proto.PGresult], + format: pq.Format = pq.Format.TEXT, + ): + super().__init__(context=context, result=result, format=format) + self._connection: Optional["Connection"] = None + + @property + def connection(self) -> "Connection": + if self._connection is None: + conn = self._transformer.connection + if conn is None: + raise ValueError("no connection available") + self._connection = cast("Connection", conn) + + return self._connection + + def write(self, buffer: bytes) -> None: + conn = self.connection + conn.wait(copy_to(conn.pgconn, buffer)) + + def finish(self, error: Optional[str] = None) -> None: + conn = self.connection + berr = ( + conn.codec.encode(error, "replace")[0] + if error is not None + else None + ) + result = conn.wait(copy_end(conn.pgconn, berr)) + if result.status != pq.ExecStatus.COMMAND_OK: + raise e.error_from_result( + result, encoding=self.connection.codec.name + ) class AsyncCopy(BaseCopy): - pass + def __init__( + self, + context: AdaptContext, + result: Optional[pq.proto.PGresult], + format: pq.Format = pq.Format.TEXT, + ): + super().__init__(context=context, result=result, format=format) + self._connection: Optional["AsyncConnection"] = None + + @property + def connection(self) -> "AsyncConnection": + if self._connection is None: + conn = self._transformer.connection + if conn is None: + raise ValueError("no connection available") + self._connection = cast("AsyncConnection", conn) + + return self._connection + + async def write(self, buffer: bytes) -> None: + conn = self.connection + await conn.wait(copy_to(conn.pgconn, buffer)) + + async def finish(self, error: Optional[str] = None) -> None: + conn = self.connection + berr = ( + conn.codec.encode(error, "replace")[0] + if error is not None + else None + ) + await conn.wait(copy_end(conn.pgconn, berr)) diff --git a/psycopg3/generators.py b/psycopg3/generators.py index 3fd82e6be..b4269a84a 100644 --- a/psycopg3/generators.py +++ b/psycopg3/generators.py @@ -16,7 +16,7 @@ when the file descriptor is ready. # Copyright (C) 2020 The Psycopg Team import logging -from typing import List +from typing import List, Optional from . import pq from . import errors as e @@ -149,3 +149,33 @@ def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]: break return ns + + +def copy_to(pgconn: pq.proto.PGconn, buffer: bytes) -> PQGen[None]: + # Retry enqueuing data until successful + while pgconn.put_copy_data(buffer) == 0: + yield pgconn.socket, Wait.W + + +def copy_end( + pgconn: pq.proto.PGconn, error: Optional[bytes] +) -> PQGen[pq.proto.PGresult]: + # Retry enqueuing end copy message until successful + while pgconn.put_copy_end(error) == 0: + yield pgconn.socket, Wait.W + + # Repeat until it the message is flushed to the server + while 1: + yield pgconn.socket, Wait.W + f = pgconn.flush() + if f == 0: + break + + # Retrieve the final result of copy + results = yield from fetch(pgconn) + if len(results) == 1: + return results[0] + else: + raise e.InternalError( + f"1 result expected from copy end, got {len(results)}" + ) diff --git a/psycopg3/pq/_pq_ctypes.py b/psycopg3/pq/_pq_ctypes.py index 1ab7ed4c2..86cf35ccb 100644 --- a/psycopg3/pq/_pq_ctypes.py +++ b/psycopg3/pq/_pq_ctypes.py @@ -500,6 +500,17 @@ PQnotifies.argtypes = [PGconn_ptr] PQnotifies.restype = PGnotify_ptr +# 33.9. Functions Associated with the COPY Command + +PQputCopyData = pq.PQputCopyData +PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int] +PQputCopyData.restype = c_int + +PQputCopyEnd = pq.PQputCopyEnd +PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p] +PQputCopyEnd.restype = c_int + + # 33.11. Miscellaneous Functions PQfreemem = pq.PQfreemem diff --git a/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/pq/_pq_ctypes.pyi index 5a26f3f30..d10d64e23 100644 --- a/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/pq/_pq_ctypes.pyi @@ -89,6 +89,9 @@ def PQsetNoticeReceiver( def PQnotifies( arg1: Optional[PGconn_struct], ) -> Optional[pointer[PGnotify_struct]]: ... # type: ignore +def PQputCopyEnd( + arg1: Optional[PGconn_struct], arg2: Optional[bytes] +) -> int: ... def PQsetResultAttrs( arg1: Optional[PGresult_struct], arg2: int, arg3: Array[PGresAttDesc_struct] # type: ignore ) -> int: ... @@ -162,6 +165,7 @@ def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ... def PQflush(arg1: Optional[PGconn_struct]) -> int: ... def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ... def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ... +def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ... def PQfreemem(arg1: Any) -> None: ... def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ... # autogenerated: end diff --git a/psycopg3/pq/libpq.pxd b/psycopg3/pq/libpq.pxd index 1ab1595e2..af0a753f3 100644 --- a/psycopg3/pq/libpq.pxd +++ b/psycopg3/pq/libpq.pxd @@ -233,6 +233,10 @@ cdef extern from "libpq-fe.h": # 33.8. Asynchronous Notification PGnotify *PQnotifies(PGconn *conn) + # 33.9. Functions Associated with the COPY Command + int PQputCopyData(PGconn *conn, const char *buffer, int nbytes) + int PQputCopyEnd(PGconn *conn, const char *errormsg) + # 33.11. Miscellaneous Functions void PQfreemem(void *ptr) void PQconninfoFree(PQconninfoOption *connOptions) diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index 6e6cd3ca0..35022bc60 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -500,6 +500,18 @@ class PGconn: else: return None + def put_copy_data(self, buffer: bytes) -> int: + rv = impl.PQputCopyData(self.pgconn_ptr, buffer, len(buffer)) + if rv < 0: + raise PQerror(f"sending copy data failed: {error_message(self)}") + return rv + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + rv = impl.PQputCopyEnd(self.pgconn_ptr, error) + if rv < 0: + raise PQerror(f"sending copy end failed: {error_message(self)}") + return rv + def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status) if not rv: diff --git a/psycopg3/pq/pq_cython.pyx b/psycopg3/pq/pq_cython.pyx index f228c1064..0bebb044a 100644 --- a/psycopg3/pq/pq_cython.pyx +++ b/psycopg3/pq/pq_cython.pyx @@ -6,6 +6,7 @@ libpq Python wrapper using cython bindings. from posix.unistd cimport getpid from cpython.mem cimport PyMem_Malloc, PyMem_Free +from cpython.bytes cimport PyBytes_AsString import logging from typing import List, Optional, Sequence @@ -213,10 +214,7 @@ cdef class PGconn: def send_query(self, command: bytes) -> None: self._ensure_pgconn() if not impl.PQsendQuery(self.pgconn_ptr, command): - raise PQerror( - "sending query failed:" - f" {error_message(self)}" - ) + raise PQerror(f"sending query failed: {error_message(self)}") def exec_params( self, @@ -269,8 +267,7 @@ cdef class PGconn: _clear_query_params(ctypes, cvalues, clengths, cformats) if not rv: raise PQerror( - "sending query and params failed:" - f" {error_message(self)}" + f"sending query and params failed: {error_message(self)}" ) def send_prepare( @@ -295,8 +292,7 @@ cdef class PGconn: PyMem_Free(atypes) if not rv: raise PQerror( - "sending query and params failed:" - f" {error_message(self)}" + f"sending query and params failed: {error_message(self)}" ) def send_query_prepared( @@ -323,8 +319,7 @@ cdef class PGconn: _clear_query_params(ctypes, cvalues, clengths, cformats) if not rv: raise PQerror( - "sending prepared query failed:" - f" {error_message(self)}" + f"sending prepared query failed: {error_message(self)}" ) def prepare( @@ -399,10 +394,7 @@ cdef class PGconn: def consume_input(self) -> None: if 1 != impl.PQconsumeInput(self.pgconn_ptr): - raise PQerror( - "consuming input failed:" - f" {error_message(self)}" - ) + raise PQerror(f"consuming input failed: {error_message(self)}") def is_busy(self) -> int: return impl.PQisBusy(self.pgconn_ptr) @@ -414,17 +406,12 @@ cdef class PGconn: @nonblocking.setter def nonblocking(self, arg: int) -> None: if 0 > impl.PQsetnonblocking(self.pgconn_ptr, arg): - raise PQerror( - f"setting nonblocking failed:" - f" {error_message(self)}" - ) + raise PQerror(f"setting nonblocking failed: {error_message(self)}") def flush(self) -> int: cdef int rv = impl.PQflush(self.pgconn_ptr) if rv < 0: - raise PQerror( - f"flushing failed:{error_message(self)}" - ) + raise PQerror(f"flushing failed:{error_message(self)}") return rv def get_cancel(self) -> PGcancel: @@ -442,6 +429,25 @@ cdef class PGconn: else: return None + def put_copy_data(self, buffer: bytes) -> int: + cdef int rv + cdef const char *cbuffer = PyBytes_AsString(buffer) + cdef int length = len(buffer) + rv = impl.PQputCopyData(self.pgconn_ptr, cbuffer, length) + if rv < 0: + raise PQerror(f"sending copy data failed: {error_message(self)}") + return rv + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + cdef int rv + cdef const char *cerr = NULL + if error is not None: + cerr = PyBytes_AsString(error) + rv = impl.PQputCopyEnd(self.pgconn_ptr, cerr) + if rv < 0: + raise PQerror(f"sending copy end failed: {error_message(self)}") + return rv + def make_empty_result(self, exec_status: ExecStatus) -> PGresult: cdef impl.PGresult *rv = impl.PQmakeEmptyPGresult( self.pgconn_ptr, exec_status) diff --git a/psycopg3/pq/proto.py b/psycopg3/pq/proto.py index dac21ccdc..bd6aab190 100644 --- a/psycopg3/pq/proto.py +++ b/psycopg3/pq/proto.py @@ -224,6 +224,12 @@ class PGconn(Protocol): def notifies(self) -> Optional["PGnotify"]: ... + def put_copy_data(self, buffer: bytes) -> int: + ... + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + ... + def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": ... diff --git a/psycopg3/waiting.py b/psycopg3/waiting.py index 7ce2f54e1..67ac85280 100644 --- a/psycopg3/waiting.py +++ b/psycopg3/waiting.py @@ -29,7 +29,7 @@ class Ready(IntEnum): W = EVENT_WRITE -def wait(gen: "PQGen[RV]", timeout: Optional[float] = None) -> "RV": +def wait(gen: PQGen[RV], timeout: Optional[float] = None) -> RV: """ Wait for a generator using the best option available on the platform. @@ -54,11 +54,11 @@ def wait(gen: "PQGen[RV]", timeout: Optional[float] = None) -> "RV": fd, s = gen.send(ready[0][1]) except StopIteration as ex: - rv: "RV" = ex.args[0] + rv: RV = ex.args[0] if ex.args else None return rv -async def wait_async(gen: "PQGen[RV]") -> "RV": +async def wait_async(gen: PQGen[RV]) -> RV: """ Coroutine waiting for a generator to complete. @@ -102,5 +102,5 @@ async def wait_async(gen: "PQGen[RV]") -> "RV": fd, s = gen.send(ready) except StopIteration as ex: - rv: "RV" = ex.args[0] + rv: RV = ex.args[0] if ex.args else None return rv diff --git a/tests/pq/test_copy.py b/tests/pq/test_copy.py new file mode 100644 index 000000000..ddcd3534c --- /dev/null +++ b/tests/pq/test_copy.py @@ -0,0 +1,116 @@ +import pytest + +from psycopg3 import pq + +sample_tabledef = "col1 int primary key, col2 int, data text" + + +def test_put_data_no_copy(pgconn): + with pytest.raises(pq.PQerror): + pgconn.put_copy_data(b"wat") + + pgconn.finish() + with pytest.raises(pq.PQerror): + pgconn.put_copy_data(b"wat") + + +def test_put_end_no_copy(pgconn): + with pytest.raises(pq.PQerror): + pgconn.put_copy_end() + + pgconn.finish() + with pytest.raises(pq.PQerror): + pgconn.put_copy_end() + + +def test_copy_out(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)} +""" + ) + rv = pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end() + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_( + b"select min(col1), max(col1), count(*), max(length(data)) from copy_in" + ) + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + assert res.get_value(0, 1) == b"199" + assert res.get_value(0, 2) == b"200" + assert res.get_value(0, 3) == b"199" + + +def test_copy_out_err(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\thardly a number\tnope +""" + ) + rv = pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end() + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.FATAL_ERROR + assert b"hardly a number" in res.error_message + + res = pgconn.exec_(b"select count(*) from copy_in") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + + +def test_copy_out_error_end(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)} +""" + ) + rv = pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end(b"nuttengoggenio") + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.FATAL_ERROR + assert b"nuttengoggenio" in res.error_message + + res = pgconn.exec_(b"select count(*) from copy_in") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + + +def ensure_table(pgconn, tabledef, name="copy_in"): + pgconn.exec_(f"drop table if exists {name}".encode("ascii")) + pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii")) diff --git a/tests/test_copy.py b/tests/test_copy.py index 1e68b61d0..49635fc40 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -23,6 +23,7 @@ sample_binary = """ 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64ff ff """ +sample_binary = bytes.fromhex("".join(sample_binary.split())) def set_sample_attributes(res, format): @@ -34,6 +35,7 @@ def set_sample_attributes(res, format): res.set_attributes(attrs) +@pytest.mark.xfail @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], @@ -46,6 +48,7 @@ def test_load_noinfo(conn, format, buffer): assert records == as_bytes(sample_records) +@pytest.mark.xfail @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], @@ -61,6 +64,7 @@ def test_load(conn, format, buffer): assert records == sample_records +@pytest.mark.xfail @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], @@ -79,6 +83,7 @@ def test_dump(conn, format, buffer): assert copy.get_buffer() is None +@pytest.mark.xfail @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], @@ -90,6 +95,7 @@ def test_buffers(format, buffer): assert list(copy.buffers(sample_records)) == [globals()[buffer]] +@pytest.mark.xfail @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], @@ -102,6 +108,7 @@ def test_copy_out_read(conn, format, buffer): assert copy.read() is None +@pytest.mark.xfail @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) def test_iter(conn, format): cur = conn.cursor() @@ -118,11 +125,12 @@ def test_copy_in_buffers(conn, format, buffer): ensure_table(cur, sample_tabledef) copy = cur.copy(f"copy copy_in from stdin (format {format.name})") copy.write(globals()[buffer]) - copy.end() + copy.finish() data = cur.execute("select * from copy_in order by 1").fetchall() assert data == sample_records +@pytest.mark.xfail @pytest.mark.parametrize( "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], @@ -137,6 +145,7 @@ def test_copy_in_buffers_with(conn, format, buffer): assert data == sample_records +@pytest.mark.xfail @pytest.mark.parametrize( "format", [(Format.TEXT,), (Format.BINARY,)], ) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py new file mode 100644 index 000000000..f6b414c2e --- /dev/null +++ b/tests/test_copy_async.py @@ -0,0 +1,43 @@ +import pytest + +from psycopg3.adapt import Format + +from .test_copy import sample_text, sample_binary # noqa +from .test_copy import sample_values, sample_records, sample_tabledef + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.xfail +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +async def test_copy_out_read(aconn, format, buffer): + cur = aconn.cursor() + copy = await cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) + assert await copy.read() == globals()[buffer] + assert await copy.read() is None + assert await copy.read() is None + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +async def test_copy_in_buffers(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + copy = await cur.copy(f"copy copy_in from stdin (format {format.name})") + await copy.write(globals()[buffer]) + await copy.finish() + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def ensure_table(cur, tabledef, name="copy_in"): + await cur.execute(f"drop table if exists {name}") + await cur.execute(f"create table {name} ({tabledef})")