fmt = result.fformat(i) if format is None else format
rc.append(self.get_loader(oid, fmt).load) # type: ignore
- def set_row_types(
- self, types: Sequence[int], formats: Sequence[pq.Format]
+ def set_dumper_types(
+ self, types: Sequence[int], format: pq.Format
) -> None:
- rc: List[LoadFunc] = []
+ dumpers: List[Optional["Dumper"]] = []
for i in range(len(types)):
- rc.append(self.get_loader(types[i], formats[i]).load)
+ dumpers.append(self.get_dumper_by_oid(types[i], format))
- self._row_loaders = rc
+ self._row_dumpers = dumpers
+
+ def set_loader_types(
+ self, types: Sequence[int], format: pq.Format
+ ) -> None:
+ loaders: List[LoadFunc] = []
+ for i in range(len(types)):
+ loaders.append(self.get_loader(types[i], format).load)
+
+ self._row_loaders = loaders
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
) -> None:
...
- def set_row_types(
- self, types: Sequence[int], formats: Sequence[pq.Format]
+ def set_dumper_types(
+ self, types: Sequence[int], format: pq.Format
+ ) -> None:
+ ...
+
+ def set_loader_types(
+ self, types: Sequence[int], format: pq.Format
) -> None:
...
oids = [
t if isinstance(t, int) else registry.get_oid(t) for t in types
]
- self.formatter.transformer.set_row_types(
- oids, [self.formatter.format] * len(types)
- )
+
+ if self._pgresult.status == ExecStatus.COPY_IN:
+ self.formatter.transformer.set_dumper_types(
+ oids, self.formatter.format
+ )
+ else:
+ self.formatter.transformer.set_loader_types(
+ oids, self.formatter.format
+ )
# High level copy protocol generators (state change of the Copy object)
out = bytearray()
out += _pack_int2(len(row))
- for item in row:
- if item is not None:
- dumper = tx.get_dumper(item, PyFormat.BINARY)
- b = dumper.dump(item)
+ adapted, _, _ = tx.dump_sequence(row, [PyFormat.BINARY] * len(row))
+ for b in adapted:
+ if b is not None:
out += _pack_int4(len(b))
out += b
else:
from typing import Sequence, Tuple, Type
from .. import pq
-from .. import errors as e
from .. import postgres
from ..abc import AdaptContext, Buffer
from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
# Subclasses must set an info
info: CompositeInfo
- def dump(self, obj: Tuple[Any, ...]) -> bytearray:
-
- if len(obj) != len(self.info.field_types):
- raise e.DataError(
- f"expected a sequence of {len(self.info.field_types)} items,"
- f" got {len(obj)}"
- )
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ nfields = len(self.info.field_types)
+ self._tx.set_dumper_types(self.info.field_types, self.format)
+ self._formats = [PyFormat.from_pq(self.format)] * nfields
+ def dump(self, obj: Tuple[Any, ...]) -> bytearray:
out = bytearray(pack_len(len(obj)))
- get_dumper = self._tx.get_dumper_by_oid
+ adapted, _, _ = self._tx.dump_sequence(obj, self._formats)
for i in range(len(obj)):
- item = obj[i]
+ b = adapted[i]
oid = self.info.field_types[i]
- if item is not None:
- dumper = get_dumper(oid, self.format)
- b = dumper.dump(item)
+ if b is not None:
out += _pack_oidlen(oid, len(b))
out += b
else:
def _config_types(self, data: bytes) -> None:
oids = [r[0] for r in self._walk_record(data)]
- self._tx.set_row_types(oids, [pq.Format.BINARY] * len(oids))
+ self._tx.set_loader_types(oids, self.format)
class CompositeLoader(RecordLoader):
)
def _config_types(self, data: bytes) -> None:
- self._tx.set_row_types(
- self.fields_types, [pq.Format.TEXT] * len(self.fields_types)
- )
+ self._tx.set_loader_types(self.fields_types, self.format)
class CompositeBinaryLoader(RecordBinaryLoader):
set_loaders: bool = True,
format: Optional[pq.Format] = None,
) -> None: ...
- def set_row_types(
- self, types: Sequence[int], formats: Sequence[pq.Format]
+ def set_dumper_types(
+ self, types: Sequence[int], format: pq.Format
+ ) -> None: ...
+ def set_loader_types(
+ self, types: Sequence[int], format: pq.Format
) -> None: ...
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
cdef int i
cdef object tmp
cdef list types
- cdef list formats
+ if format is None:
+ format = libpq.PQfformat(res, 0)
+
if set_loaders:
types = PyList_New(self._nfields)
- formats = PyList_New(self._nfields)
for i in range(self._nfields):
tmp = libpq.PQftype(res, i)
Py_INCREF(tmp)
PyList_SET_ITEM(types, i, tmp)
- tmp = libpq.PQfformat(res, i) if format is None else format
- Py_INCREF(tmp)
- PyList_SET_ITEM(formats, i, tmp)
-
- self._c_set_row_types(self._nfields, types, formats)
+ self._c_loader_types(self._nfields, types, format)
- def set_row_types(self,
- types: Sequence[int], formats: Sequence[Format]) -> None:
- self._c_set_row_types(len(types), types, formats)
+ def set_loader_types(self,
+ types: Sequence[int], format: Format) -> None:
+ self._c_loader_types(len(types), types, format)
- cdef void _c_set_row_types(self, Py_ssize_t ntypes, list types, list formats):
+ cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, int format):
cdef list loaders = PyList_New(ntypes)
# these are used more as Python object than C
cdef PyObject *oid
- cdef PyObject *fmt
cdef PyObject *row_loader
for i in range(ntypes):
oid = PyList_GET_ITEM(types, i)
- fmt = PyList_GET_ITEM(formats, i)
- row_loader = self._c_get_loader(oid, fmt)
+ row_loader = self._c_get_loader(oid, <PyObject *>format)
Py_INCREF(<object>row_loader)
PyList_SET_ITEM(loaders, i, <object>row_loader)
eur = "\u20ac"
-sample_records = [(Int4(10), Int4(20), "hello"), (Int4(40), None, "world")]
+sample_records = [(10, 20, "hello"), (40, None, "world")]
sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
want = [row + b"\n" for row in sample_text.splitlines()]
else:
want = sample_binary_rows
+
cur = conn.cursor()
with cur.copy(
f"copy ({sample_values}) to stdout (format {format.name})"
ensure_table(cur, sample_tabledef)
with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_copy_in_records_set_types(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
for row in sample_records:
copy.write_row(row)
from psycopg.types import TypeInfo
from psycopg.adapt import PyFormat as PgFormat
from psycopg.types.hstore import register_hstore
+from psycopg.types.numeric import Int4
from .utils import gc_collect
from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa
async with cur.copy(
f"copy copy_in from stdin (format {format.name})"
) as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(Int4(i) if isinstance(i, int) else i for i in row)
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+async def test_copy_in_records_set_types(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(
+ f"copy copy_in from stdin (format {format.name})"
+ ) as copy:
+ copy.set_types(["int4", "int4", "text"])
for row in sample_records:
await copy.write_row(row)