.. automethod:: write
.. automethod:: read
- Instead of using `!read()` you can even iterate on the object to read
- its data row by row, using ``for row in copy: ...``.
+ Instead of using `!read()` you can iterate on the `!Copy` object to
+ read its data row by row, using ``for row in copy: ...``.
.. automethod:: rows
+
+        Equivalent to iterating on `read_row()` until it returns `!None`.
+
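+    For example, a minimal sketch of consuming a :sql:`COPY TO` operation
+    row by row, assuming an open cursor and an illustrative table name::
+
+        with cursor.copy("COPY mytable TO STDOUT") as copy:
+            for row in copy.rows():
+                print(row)
+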
.. automethod:: read_row
.. automethod:: set_types
.. automethod:: write
.. automethod:: read
- Instead of using `!read()` you can even iterate on the object to read
- its data row by row, using ``async for row in copy: ...``.
+ Instead of using `!read()` you can iterate on the `!AsyncCopy` object
+ to read its data row by row, using ``async for row in copy: ...``.
.. automethod:: rows
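+
+    Similarly, a minimal async sketch (the table name is again illustrative)::
+
+        async with cursor.copy("COPY mytable TO STDOUT") as copy:
+            async for row in copy.rows():
+                print(row)
+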
from types import TracebackType
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
-from typing_extensions import Protocol
from . import pq
from . import errors as e
from .connection import Connection, AsyncConnection # noqa: F401
-class CopyFormatFunc(Protocol):
- """The type of a function to format copy data to a bytearray."""
-
- def __call__(
- self,
- row: Sequence[Any],
- tx: Transformer,
- out: Optional[bytearray] = None,
- ) -> bytearray:
- ...
-
-
-class CopyParseFunc(Protocol):
- def __call__(self, data: bytes, tx: Transformer) -> Tuple[Any, ...]:
- ...
-
-
class BaseCopy(Generic[ConnectionType]):
def __init__(self, cursor: "BaseCursor[ConnectionType]"):
self.cursor = cursor
self._write_buffer_size = 32 * 1024
self._finished = False
- self._format_row: CopyFormatFunc
- self._parse_row: CopyParseFunc
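+        # Choose the row formatting/parsing functions matching the copy format.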
if self.format == Format.TEXT:
self._format_row = format_row_text
self._parse_row = parse_row_text
# High level copy protocol generators (state change of the Copy object)
- def _read_gen(self) -> PQGen[bytes]:
+ def _read_gen(self) -> PQGen[memoryview]:
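+        # A memoryview is returned to avoid an extra copy of the buffer
+        # received from libpq.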
if self._finished:
- return b""
+ return memoryview(b"")
res = yield from copy_from(self._pgconn)
if isinstance(res, memoryview):
self._finished = True
nrows = res.command_tuples
self.cursor._rowcount = nrows if nrows is not None else -1
- return b""
+ return memoryview(b"")
def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
data = yield from self._read_gen()
__module__ = "psycopg3"
- def read(self) -> bytes:
+ def read(self) -> memoryview:
"""
- Read an unparsed row from a table after a :sql:`COPY TO` operation.
+ Read an unparsed row after a :sql:`COPY TO` operation.
- Return an empty bytes string when the data is finished.
+        Return an empty `!memoryview` when the data is finished.
"""
return self.connection.wait(self._read_gen())
) -> None:
self.connection.wait(self._exit_gen(exc_type, exc_val))
- def __iter__(self) -> Iterator[bytes]:
+ def __iter__(self) -> Iterator[memoryview]:
while True:
data = self.read()
if not data:
__module__ = "psycopg3"
- async def read(self) -> bytes:
+ async def read(self) -> memoryview:
return await self.connection.wait(self._read_gen())
async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
) -> None:
await self.connection.wait(self._exit_gen(exc_type, exc_val))
- async def __aiter__(self) -> AsyncIterator[bytes]:
+ async def __aiter__(self) -> AsyncIterator[memoryview]:
while True:
data = await self.read()
if not data:
return tx.load_sequence(row)
-def parse_row_binary(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+def _parse_row_binary(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
row: List[Optional[bytes]] = []
nfields = _unpack_int2(data, 0)[0]
pos = 2
return __map[m.group(0)]
-# Override it with fast object if available
-
-format_row_binary: CopyFormatFunc
-parse_row_text: CopyParseFunc
-
+# Override functions with fast versions if available
if pq.__impl__ == "c":
from psycopg3_c import _psycopg3
format_row_text = _psycopg3.format_row_text
format_row_binary = _psycopg3.format_row_binary
parse_row_text = _psycopg3.parse_row_text
+ parse_row_binary = _psycopg3.parse_row_binary
else:
format_row_text = _format_row_text
format_row_binary = _format_row_binary
parse_row_text = _parse_row_text
+ parse_row_binary = _parse_row_binary
row: Sequence[Any], tx: proto.Transformer, out: Optional[bytearray] = None
) -> bytearray: ...
def parse_row_text(data: bytes, tx: proto.Transformer) -> Tuple[Any, ...]: ...
+def parse_row_binary(
+ data: bytes, tx: proto.Transformer
+) -> Tuple[Any, ...]: ...
# vim: set syntax=python:
from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE
from cpython.memoryview cimport PyMemoryView_FromObject
-from psycopg3_c._psycopg3.endian cimport htobe16, htobe32
+from psycopg3_c._psycopg3 cimport endian
from psycopg3_c.pq cimport ViewBuffer
+from psycopg3 import errors as e
+
cdef int32_t _binary_null = -1
) -> bytearray:
"""Convert a row of adapted data to the data to send for binary copy"""
cdef Py_ssize_t rowlen = len(row)
- cdef uint16_t berowlen = htobe16(rowlen)
+ cdef uint16_t berowlen = endian.htobe16(rowlen)
cdef Py_ssize_t pos # offset in 'out' where to write
if out is None:
# A cdumper can resize if necessary and copy in place
size = (<CDumper>dumper).cdump(item, out, pos + sizeof(besize))
# Also add the size of the item, before the item
- besize = htobe32(size)
+ besize = endian.htobe32(size)
target = PyByteArray_AS_STRING(out) # might have been moved by cdump
memcpy(target + pos, <void *>&besize, sizeof(besize))
else:
b = PyObject_CallFunctionObjArgs(dumper.dump, <PyObject *>item, NULL)
_buffer_as_string_and_size(b, &buf, &size)
target = CDumper.ensure_size(out, pos, size + sizeof(besize))
- besize = htobe32(size)
+ besize = endian.htobe32(size)
memcpy(target, <void *>&besize, sizeof(besize))
memcpy(target + sizeof(besize), buf, size)
return out
+def parse_row_binary(data, tx: Transformer) -> Tuple[Any, ...]:
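+    """Parse a row of binary copy data into a tuple of Python values.
+
+    Layout (PostgreSQL binary COPY format): a big-endian int16 field count,
+    then, for each field, a big-endian int32 length (-1 for NULL) followed
+    by that many bytes of data.
+    """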
+ cdef unsigned char *ptr
+ cdef Py_ssize_t bufsize
+ _buffer_as_string_and_size(data, <char **>&ptr, &bufsize)
+ cdef unsigned char *bufend = ptr + bufsize
+
+    cdef uint16_t benfields
+    memcpy(&benfields, ptr, sizeof(benfields))
+ cdef int nfields = endian.be16toh(benfields)
+ ptr += sizeof(benfields)
+ cdef list row = PyList_New(nfields)
+
+ cdef int col
+ cdef int32_t belength
+ cdef Py_ssize_t length
+
+ for col in range(nfields):
+ memcpy(&belength, ptr, sizeof(belength))
+ ptr += sizeof(belength)
+ if belength == _binary_null:
+ field = None
+ else:
+ length = endian.be32toh(belength)
+ if ptr + length > bufend:
+ raise e.DataError("bad copy data: length exceeding data")
+ field = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(ptr, length))
+ ptr += length
+
+ Py_INCREF(field)
+ PyList_SET_ITEM(row, col, field)
+
+ return tx.load_sequence(row)
+
+
def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]:
cdef unsigned char *fstart
cdef Py_ssize_t size
cdef unsigned char *rowend = fstart + size
cdef unsigned char *src
cdef unsigned char *tgt
- cdef int col = 0
+ cdef int col
cdef int num_bs
+
for col in range(nfields):
fend = fstart
num_bs = 0
# Check if we stopped for the right reason
if fend >= rowend:
- raise ValueError("bad copy format, field delimiter not found")
+ raise e.DataError("bad copy data: field delimiter not found")
elif fend[0] == b'\t' and col == nfields - 1:
- raise ValueError("bad copy format, got a tab at the end of the row")
+ raise e.DataError("bad copy data: got a tab at the end of the row")
elif fend[0] == b'\n' and col != nfields - 1:
- raise ValueError(
- "bad copy format, got a newline before the end of the row")
+ raise e.DataError(
+            "bad copy data: got a newline before the end of the row")
# Is this a NULL?
if fend - fstart == 2 and fstart[0] == b'\\' and fstart[1] == b'N':
return record
- def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
+ cpdef object load_sequence(self, record: Sequence[Optional[bytes]]):
cdef int nfields = len(record)
out = PyTuple_New(nfields)
cdef PyObject *loader # borrowed RowLoader
gen.assert_data()
- f = StringIO()
+ f = BytesIO()
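+    # the copy object yields bytes-like blocks (memoryview), so collect
+    # them in a binary buffer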
with cur.copy("copy copy_in to stdout") as copy:
for block in copy:
- f.write(block.decode("utf8"))
+ f.write(block)
f.seek(0)
assert gen.sha(f) == gen.sha(gen.file())
await gen.assert_data()
- f = StringIO()
+ f = BytesIO()
async with cur.copy("copy copy_in to stdout") as copy:
async for block in copy:
- f.write(block.decode("utf8"))
+ f.write(block)
f.seek(0)
assert gen.sha(f) == gen.sha(gen.file())