# Copyright (C) 2020 The Psycopg Team
import re
+from typing import cast, TYPE_CHECKING
from typing import Any, Deque, Dict, List, Match, Optional, Tuple
from collections import deque
-from .proto import AdaptContext
-from . import errors as e
from . import pq
+from . import errors as e
+from .proto import AdaptContext
+from .generators import copy_to, copy_end
+
+if TYPE_CHECKING:
+ from .connection import Connection, AsyncConnection
class BaseCopy:
class Copy(BaseCopy):
- pass
+ def __init__(
+ self,
+ context: AdaptContext,
+ result: Optional[pq.proto.PGresult],
+ format: pq.Format = pq.Format.TEXT,
+ ):
+ super().__init__(context=context, result=result, format=format)
+ self._connection: Optional["Connection"] = None
+
+ @property
+ def connection(self) -> "Connection":
+ if self._connection is None:
+ conn = self._transformer.connection
+ if conn is None:
+ raise ValueError("no connection available")
+ self._connection = cast("Connection", conn)
+
+ return self._connection
+
+ def write(self, buffer: bytes) -> None:
+ conn = self.connection
+ conn.wait(copy_to(conn.pgconn, buffer))
+
+ def finish(self, error: Optional[str] = None) -> None:
+ conn = self.connection
+ berr = (
+ conn.codec.encode(error, "replace")[0]
+ if error is not None
+ else None
+ )
+ result = conn.wait(copy_end(conn.pgconn, berr))
+ if result.status != pq.ExecStatus.COMMAND_OK:
+ raise e.error_from_result(
+ result, encoding=self.connection.codec.name
+ )
class AsyncCopy(BaseCopy):
- pass
+ def __init__(
+ self,
+ context: AdaptContext,
+ result: Optional[pq.proto.PGresult],
+ format: pq.Format = pq.Format.TEXT,
+ ):
+ super().__init__(context=context, result=result, format=format)
+ self._connection: Optional["AsyncConnection"] = None
+
+ @property
+ def connection(self) -> "AsyncConnection":
+ if self._connection is None:
+ conn = self._transformer.connection
+ if conn is None:
+ raise ValueError("no connection available")
+ self._connection = cast("AsyncConnection", conn)
+
+ return self._connection
+
+ async def write(self, buffer: bytes) -> None:
+ conn = self.connection
+ await conn.wait(copy_to(conn.pgconn, buffer))
+
+ async def finish(self, error: Optional[str] = None) -> None:
+ conn = self.connection
+ berr = (
+ conn.codec.encode(error, "replace")[0]
+ if error is not None
+ else None
+ )
+ await conn.wait(copy_end(conn.pgconn, berr))
# Copyright (C) 2020 The Psycopg Team
import logging
-from typing import List
+from typing import List, Optional
from . import pq
from . import errors as e
break
return ns
+
+
+def copy_to(pgconn: pq.proto.PGconn, buffer: bytes) -> PQGen[None]:
+ # Retry enqueuing data until successful
+ while pgconn.put_copy_data(buffer) == 0:
+ yield pgconn.socket, Wait.W
+
+
+def copy_end(
+ pgconn: pq.proto.PGconn, error: Optional[bytes]
+) -> PQGen[pq.proto.PGresult]:
+ # Retry enqueuing end copy message until successful
+ while pgconn.put_copy_end(error) == 0:
+ yield pgconn.socket, Wait.W
+
+ # Repeat until it the message is flushed to the server
+ while 1:
+ yield pgconn.socket, Wait.W
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ # Retrieve the final result of copy
+ results = yield from fetch(pgconn)
+ if len(results) == 1:
+ return results[0]
+ else:
+ raise e.InternalError(
+ f"1 result expected from copy end, got {len(results)}"
+ )
PQnotifies.restype = PGnotify_ptr
+# 33.9. Functions Associated with the COPY Command
+
+PQputCopyData = pq.PQputCopyData
+PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int]
+PQputCopyData.restype = c_int
+
+PQputCopyEnd = pq.PQputCopyEnd
+PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
+PQputCopyEnd.restype = c_int
+
+
# 33.11. Miscellaneous Functions
PQfreemem = pq.PQfreemem
def PQnotifies(
arg1: Optional[PGconn_struct],
) -> Optional[pointer[PGnotify_struct]]: ... # type: ignore
+def PQputCopyEnd(
+ arg1: Optional[PGconn_struct], arg2: Optional[bytes]
+) -> int: ...
def PQsetResultAttrs(
arg1: Optional[PGresult_struct], arg2: int, arg3: Array[PGresAttDesc_struct] # type: ignore
) -> int: ...
def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
+def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ...
def PQfreemem(arg1: Any) -> None: ...
def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
# autogenerated: end
# 33.8. Asynchronous Notification
PGnotify *PQnotifies(PGconn *conn)
+ # 33.9. Functions Associated with the COPY Command
+ int PQputCopyData(PGconn *conn, const char *buffer, int nbytes)
+ int PQputCopyEnd(PGconn *conn, const char *errormsg)
+
# 33.11. Miscellaneous Functions
void PQfreemem(void *ptr)
void PQconninfoFree(PQconninfoOption *connOptions)
else:
return None
+ def put_copy_data(self, buffer: bytes) -> int:
+ rv = impl.PQputCopyData(self.pgconn_ptr, buffer, len(buffer))
+ if rv < 0:
+ raise PQerror(f"sending copy data failed: {error_message(self)}")
+ return rv
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ rv = impl.PQputCopyEnd(self.pgconn_ptr, error)
+ if rv < 0:
+ raise PQerror(f"sending copy end failed: {error_message(self)}")
+ return rv
+
def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status)
if not rv:
from posix.unistd cimport getpid
from cpython.mem cimport PyMem_Malloc, PyMem_Free
+from cpython.bytes cimport PyBytes_AsString
import logging
from typing import List, Optional, Sequence
def send_query(self, command: bytes) -> None:
self._ensure_pgconn()
if not impl.PQsendQuery(self.pgconn_ptr, command):
- raise PQerror(
- "sending query failed:"
- f" {error_message(self)}"
- )
+ raise PQerror(f"sending query failed: {error_message(self)}")
def exec_params(
self,
_clear_query_params(ctypes, cvalues, clengths, cformats)
if not rv:
raise PQerror(
- "sending query and params failed:"
- f" {error_message(self)}"
+ f"sending query and params failed: {error_message(self)}"
)
def send_prepare(
PyMem_Free(atypes)
if not rv:
raise PQerror(
- "sending query and params failed:"
- f" {error_message(self)}"
+ f"sending query and params failed: {error_message(self)}"
)
def send_query_prepared(
_clear_query_params(ctypes, cvalues, clengths, cformats)
if not rv:
raise PQerror(
- "sending prepared query failed:"
- f" {error_message(self)}"
+ f"sending prepared query failed: {error_message(self)}"
)
def prepare(
def consume_input(self) -> None:
if 1 != impl.PQconsumeInput(self.pgconn_ptr):
- raise PQerror(
- "consuming input failed:"
- f" {error_message(self)}"
- )
+ raise PQerror(f"consuming input failed: {error_message(self)}")
def is_busy(self) -> int:
return impl.PQisBusy(self.pgconn_ptr)
@nonblocking.setter
def nonblocking(self, arg: int) -> None:
if 0 > impl.PQsetnonblocking(self.pgconn_ptr, arg):
- raise PQerror(
- f"setting nonblocking failed:"
- f" {error_message(self)}"
- )
+ raise PQerror(f"setting nonblocking failed: {error_message(self)}")
def flush(self) -> int:
cdef int rv = impl.PQflush(self.pgconn_ptr)
if rv < 0:
- raise PQerror(
- f"flushing failed:{error_message(self)}"
- )
+ raise PQerror(f"flushing failed:{error_message(self)}")
return rv
def get_cancel(self) -> PGcancel:
else:
return None
+ def put_copy_data(self, buffer: bytes) -> int:
+ cdef int rv
+ cdef const char *cbuffer = PyBytes_AsString(buffer)
+ cdef int length = len(buffer)
+ rv = impl.PQputCopyData(self.pgconn_ptr, cbuffer, length)
+ if rv < 0:
+ raise PQerror(f"sending copy data failed: {error_message(self)}")
+ return rv
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ cdef int rv
+ cdef const char *cerr = NULL
+ if error is not None:
+ cerr = PyBytes_AsString(error)
+ rv = impl.PQputCopyEnd(self.pgconn_ptr, cerr)
+ if rv < 0:
+ raise PQerror(f"sending copy end failed: {error_message(self)}")
+ return rv
+
def make_empty_result(self, exec_status: ExecStatus) -> PGresult:
cdef impl.PGresult *rv = impl.PQmakeEmptyPGresult(
self.pgconn_ptr, exec_status)
def notifies(self) -> Optional["PGnotify"]:
...
+ def put_copy_data(self, buffer: bytes) -> int:
+ ...
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ ...
+
def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
...
W = EVENT_WRITE
-def wait(gen: "PQGen[RV]", timeout: Optional[float] = None) -> "RV":
+def wait(gen: PQGen[RV], timeout: Optional[float] = None) -> RV:
"""
Wait for a generator using the best option available on the platform.
fd, s = gen.send(ready[0][1])
except StopIteration as ex:
- rv: "RV" = ex.args[0]
+ rv: RV = ex.args[0] if ex.args else None
return rv
-async def wait_async(gen: "PQGen[RV]") -> "RV":
+async def wait_async(gen: PQGen[RV]) -> RV:
"""
Coroutine waiting for a generator to complete.
fd, s = gen.send(ready)
except StopIteration as ex:
- rv: "RV" = ex.args[0]
+ rv: RV = ex.args[0] if ex.args else None
return rv
--- /dev/null
+import pytest
+
+from psycopg3 import pq
+
+sample_tabledef = "col1 int primary key, col2 int, data text"
+
+
+def test_put_data_no_copy(pgconn):
+ with pytest.raises(pq.PQerror):
+ pgconn.put_copy_data(b"wat")
+
+ pgconn.finish()
+ with pytest.raises(pq.PQerror):
+ pgconn.put_copy_data(b"wat")
+
+
+def test_put_end_no_copy(pgconn):
+ with pytest.raises(pq.PQerror):
+ pgconn.put_copy_end()
+
+ pgconn.finish()
+ with pytest.raises(pq.PQerror):
+ pgconn.put_copy_end()
+
+
+def test_copy_out(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)}
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end()
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_(
+ b"select min(col1), max(col1), count(*), max(length(data)) from copy_in"
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+ assert res.get_value(0, 1) == b"199"
+ assert res.get_value(0, 2) == b"200"
+ assert res.get_value(0, 3) == b"199"
+
+
+def test_copy_out_err(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\thardly a number\tnope
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end()
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert b"hardly a number" in res.error_message
+
+ res = pgconn.exec_(b"select count(*) from copy_in")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+
+
+def test_copy_out_error_end(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)}
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end(b"nuttengoggenio")
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert b"nuttengoggenio" in res.error_message
+
+ res = pgconn.exec_(b"select count(*) from copy_in")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+
+
+def ensure_table(pgconn, tabledef, name="copy_in"):
+ pgconn.exec_(f"drop table if exists {name}".encode("ascii"))
+ pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii"))
0000 0004 0000 0028 ffff ffff 0000 0005
776f 726c 64ff ff
"""
+sample_binary = bytes.fromhex("".join(sample_binary.split()))
def set_sample_attributes(res, format):
res.set_attributes(attrs)
+@pytest.mark.xfail
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
assert records == as_bytes(sample_records)
+@pytest.mark.xfail
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
assert records == sample_records
+@pytest.mark.xfail
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
assert copy.get_buffer() is None
+@pytest.mark.xfail
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
assert list(copy.buffers(sample_records)) == [globals()[buffer]]
+@pytest.mark.xfail
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
assert copy.read() is None
+@pytest.mark.xfail
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
def test_iter(conn, format):
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
copy = cur.copy(f"copy copy_in from stdin (format {format.name})")
copy.write(globals()[buffer])
- copy.end()
+ copy.finish()
data = cur.execute("select * from copy_in order by 1").fetchall()
assert data == sample_records
+@pytest.mark.xfail
@pytest.mark.parametrize(
"format, buffer",
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
assert data == sample_records
+@pytest.mark.xfail
@pytest.mark.parametrize(
"format", [(Format.TEXT,), (Format.BINARY,)],
)
--- /dev/null
+import pytest
+
+from psycopg3.adapt import Format
+
+from .test_copy import sample_text, sample_binary # noqa
+from .test_copy import sample_values, sample_records, sample_tabledef
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.xfail
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_out_read(aconn, format, buffer):
+ cur = aconn.cursor()
+ copy = await cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ )
+ assert await copy.read() == globals()[buffer]
+ assert await copy.read() is None
+ assert await copy.read() is None
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_in_buffers(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ copy = await cur.copy(f"copy copy_in from stdin (format {format.name})")
+ await copy.write(globals()[buffer])
+ await copy.finish()
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def ensure_table(cur, tabledef, name="copy_in"):
+ await cur.execute(f"drop table if exists {name}")
+ await cur.execute(f"create table {name} ({tabledef})")