From: Daniele Varrazzo Date: Fri, 27 Aug 2021 12:50:41 +0000 (+0200) Subject: Replace set_row_types() with set_dumper_types/set_loader_types() X-Git-Tag: 3.0.beta1~29^2~2 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2ffa646c9b4333f120f5c0b627ad2e513a8fb654;p=thirdparty%2Fpsycopg.git Replace set_row_types() with set_dumper_types/set_loader_types() dump_sequence() can make use of the dumpers pre-set by set_dumper_types() and is now used in composite binary dump and in binary COPY FROM. The interface should be iterated on further, because for the latter use cases the extra info (oids, formats) is just a waste of resources. Only the Python implementation has been changed so far; the C implementation will be changed down the line. --- diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py index c4dbbaf52..97c74d6e4 100644 --- a/psycopg/psycopg/_transform.py +++ b/psycopg/psycopg/_transform.py @@ -109,14 +109,23 @@ class Transformer(AdaptContext): fmt = result.fformat(i) if format is None else format rc.append(self.get_loader(oid, fmt).load) # type: ignore - def set_row_types( - self, types: Sequence[int], formats: Sequence[pq.Format] + def set_dumper_types( + self, types: Sequence[int], format: pq.Format ) -> None: - rc: List[LoadFunc] = [] + dumpers: List[Optional["Dumper"]] = [] for i in range(len(types)): - rc.append(self.get_loader(types[i], formats[i]).load) + dumpers.append(self.get_dumper_by_oid(types[i], format)) - self._row_loaders = rc + self._row_dumpers = dumpers + + def set_loader_types( + self, types: Sequence[int], format: pq.Format + ) -> None: + loaders: List[LoadFunc] = [] + for i in range(len(types)): + loaders.append(self.get_loader(types[i], format).load) + + self._row_loaders = loaders def dump_sequence( self, params: Sequence[Any], formats: Sequence[PyFormat] diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index 584b2ae03..fc86dc441 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ 
-193,8 +193,13 @@ class Transformer(Protocol): ) -> None: ... - def set_row_types( - self, types: Sequence[int], formats: Sequence[pq.Format] + def set_dumper_types( + self, types: Sequence[int], format: pq.Format + ) -> None: + ... + + def set_loader_types( + self, types: Sequence[int], format: pq.Format ) -> None: ... diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 0e889dacb..61338f6ce 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -96,9 +96,15 @@ class BaseCopy(Generic[ConnectionType]): oids = [ t if isinstance(t, int) else registry.get_oid(t) for t in types ] - self.formatter.transformer.set_row_types( - oids, [self.formatter.format] * len(types) - ) + + if self._pgresult.status == ExecStatus.COPY_IN: + self.formatter.transformer.set_dumper_types( + oids, self.formatter.format + ) + else: + self.formatter.transformer.set_loader_types( + oids, self.formatter.format + ) # High level copy protocol generators (state change of the Copy object) @@ -556,10 +562,9 @@ def _format_row_binary( out = bytearray() out += _pack_int2(len(row)) - for item in row: - if item is not None: - dumper = tx.get_dumper(item, PyFormat.BINARY) - b = dumper.dump(item) + adapted, _, _ = tx.dump_sequence(row, [PyFormat.BINARY] * len(row)) + for b in adapted: + if b is not None: out += _pack_int4(len(b)) out += b else: diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py index 950142d76..89d793c91 100644 --- a/psycopg/psycopg/types/composite.py +++ b/psycopg/psycopg/types/composite.py @@ -11,7 +11,6 @@ from typing import Any, Callable, cast, Iterator, List, Optional from typing import Sequence, Tuple, Type from .. import pq -from .. import errors as e from .. 
import postgres from ..abc import AdaptContext, Buffer from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader @@ -77,22 +76,19 @@ class TupleBinaryDumper(RecursiveDumper): # Subclasses must set an info info: CompositeInfo - def dump(self, obj: Tuple[Any, ...]) -> bytearray: - - if len(obj) != len(self.info.field_types): - raise e.DataError( - f"expected a sequence of {len(self.info.field_types)} items," - f" got {len(obj)}" - ) + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + nfields = len(self.info.field_types) + self._tx.set_dumper_types(self.info.field_types, self.format) + self._formats = [PyFormat.from_pq(self.format)] * nfields + def dump(self, obj: Tuple[Any, ...]) -> bytearray: out = bytearray(pack_len(len(obj))) - get_dumper = self._tx.get_dumper_by_oid + adapted, _, _ = self._tx.dump_sequence(obj, self._formats) for i in range(len(obj)): - item = obj[i] + b = adapted[i] oid = self.info.field_types[i] - if item is not None: - dumper = get_dumper(oid, self.format) - b = dumper.dump(item) + if b is not None: out += _pack_oidlen(oid, len(b)) out += b else: @@ -178,7 +174,7 @@ class RecordBinaryLoader(RecursiveLoader): def _config_types(self, data: bytes) -> None: oids = [r[0] for r in self._walk_record(data)] - self._tx.set_row_types(oids, [pq.Format.BINARY] * len(oids)) + self._tx.set_loader_types(oids, self.format) class CompositeLoader(RecordLoader): @@ -201,9 +197,7 @@ class CompositeLoader(RecordLoader): ) def _config_types(self, data: bytes) -> None: - self._tx.set_row_types( - self.fields_types, [pq.Format.TEXT] * len(self.fields_types) - ) + self._tx.set_loader_types(self.fields_types, self.format) class CompositeBinaryLoader(RecordBinaryLoader): diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index 5c9fcef68..a6e07e88d 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -31,8 +31,11 @@ class 
Transformer(abc.AdaptContext): set_loaders: bool = True, format: Optional[pq.Format] = None, ) -> None: ... - def set_row_types( - self, types: Sequence[int], formats: Sequence[pq.Format] + def set_dumper_types( + self, types: Sequence[int], format: pq.Format + ) -> None: ... + def set_loader_types( + self, types: Sequence[int], format: pq.Format ) -> None: ... def dump_sequence( self, params: Sequence[Any], formats: Sequence[PyFormat] diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx index a68886fc9..ab1fa453c 100644 --- a/psycopg_c/psycopg_c/_psycopg/transform.pyx +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -125,36 +125,31 @@ cdef class Transformer: cdef int i cdef object tmp cdef list types - cdef list formats + if format is None: + format = libpq.PQfformat(res, 0) + if set_loaders: types = PyList_New(self._nfields) - formats = PyList_New(self._nfields) for i in range(self._nfields): tmp = libpq.PQftype(res, i) Py_INCREF(tmp) PyList_SET_ITEM(types, i, tmp) - tmp = libpq.PQfformat(res, i) if format is None else format - Py_INCREF(tmp) - PyList_SET_ITEM(formats, i, tmp) - - self._c_set_row_types(self._nfields, types, formats) + self._c_loader_types(self._nfields, types, format) - def set_row_types(self, - types: Sequence[int], formats: Sequence[Format]) -> None: - self._c_set_row_types(len(types), types, formats) + def set_loader_types(self, + types: Sequence[int], format: Format) -> None: + self._c_loader_types(len(types), types, format) - cdef void _c_set_row_types(self, Py_ssize_t ntypes, list types, list formats): + cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, int format): cdef list loaders = PyList_New(ntypes) # these are used more as Python object than C cdef PyObject *oid - cdef PyObject *fmt cdef PyObject *row_loader for i in range(ntypes): oid = PyList_GET_ITEM(types, i) - fmt = PyList_GET_ITEM(formats, i) - row_loader = self._c_get_loader(oid, fmt) + row_loader = 
self._c_get_loader(oid, format) Py_INCREF(row_loader) PyList_SET_ITEM(loaders, i, row_loader) diff --git a/tests/test_copy.py b/tests/test_copy.py index 6994ab4b1..91b558756 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -20,7 +20,7 @@ from .utils import gc_collect eur = "\u20ac" -sample_records = [(Int4(10), Int4(20), "hello"), (Int4(40), None, "world")] +sample_records = [(10, 20, "hello"), (40, None, "world")] sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')" @@ -79,6 +79,7 @@ def test_copy_out_iter(conn, format): want = [row + b"\n" for row in sample_text.splitlines()] else: want = sample_binary_rows + cur = conn.cursor() with cur.copy( f"copy ({sample_values}) to stdout (format {format.name})" @@ -341,6 +342,22 @@ def test_copy_in_records(conn, format): ensure_table(cur, sample_tabledef) with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple(Int4(i) if isinstance(i, int) else i for i in row) + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_copy_in_records_set_types(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.set_types(["int4", "int4", "text"]) for row in sample_records: copy.write_row(row) diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index b118f6463..3fbe7941d 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -14,6 +14,7 @@ from psycopg.pq import Format from psycopg.types import TypeInfo from psycopg.adapt import PyFormat as PgFormat from psycopg.types.hstore import register_hstore +from psycopg.types.numeric import Int4 from .utils import gc_collect from .test_copy import sample_text, sample_binary, sample_binary_rows # 
noqa @@ -325,6 +326,25 @@ async def test_copy_in_records(aconn, format): async with cur.copy( f"copy copy_in from stdin (format {format.name})" ) as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple(Int4(i) if isinstance(i, int) else i for i in row) + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +async def test_copy_in_records_set_types(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy( + f"copy copy_in from stdin (format {format.name})" + ) as copy: + copy.set_types(["int4", "int4", "text"]) for row in sample_records: await copy.write_row(row)