From: Daniele Varrazzo Date: Mon, 11 Jan 2021 22:09:19 +0000 (+0100) Subject: Set sub-dumper on array dumper on creation in Transformer X-Git-Tag: 3.0.dev0~175 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=29d4c2777ae0cd552b2ddcc716d2e0b14961e50d;p=thirdparty%2Fpsycopg.git Set sub-dumper on array dumper on creation in Transformer Fixes binary dump of empty array and other unforeseen problems, such as the same array dumper being used to dump lists whose elements have different types. --- diff --git a/psycopg3/psycopg3/_transform.py b/psycopg3/psycopg3/_transform.py index cbc6ad00e..7bd36cadd 100644 --- a/psycopg3/psycopg3/_transform.py +++ b/psycopg3/psycopg3/_transform.py @@ -4,8 +4,8 @@ Helper object to transform values between Python and PostgreSQL # Copyright (C) 2020 The Psycopg Team -from typing import Any, Dict, List, Optional, Sequence, Tuple -from typing import TYPE_CHECKING +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import cast, TYPE_CHECKING from . 
import errors as e from .pq import Format @@ -16,6 +16,13 @@ if TYPE_CHECKING: from .pq.proto import PGresult from .adapt import Dumper, Loader, AdaptersMap from .connection import BaseConnection + from .types.array import BaseListDumper + +DumperKey = Union[type, Tuple[type, type]] +DumperCache = Dict[DumperKey, "Dumper"] + +LoaderKey = int +LoaderCache = Dict[LoaderKey, "Loader"] class Transformer(AdaptContext): @@ -49,12 +56,10 @@ class Transformer(AdaptContext): self._connection = None # mapping class, fmt -> Dumper instance - self._dumpers_cache: Tuple[Dict[type, "Dumper"], Dict[type, "Dumper"]] - self._dumpers_cache = ({}, {}) + self._dumpers_cache: Tuple[DumperCache, DumperCache] = ({}, {}) # mapping oid, fmt -> Loader instance - self._loaders_cache: Tuple[Dict[int, "Loader"], Dict[int, "Loader"]] - self._loaders_cache = ({}, {}) + self._loaders_cache: Tuple[LoaderCache, LoaderCache] = ({}, {}) # sequence of load functions from value to python # the length of the result columns @@ -117,19 +122,33 @@ class Transformer(AdaptContext): def get_dumper(self, obj: Any, format: Format) -> "Dumper": # Fast path: return a Dumper class already instantiated from the same type cls = type(obj) + if cls is not list: + key: DumperKey = cls + else: + # TODO: Can be probably generalised to handle other recursive types + subobj = self._find_list_element(obj) + if subobj is None: + subobj = "" + key = (cls, type(subobj)) + try: - return self._dumpers_cache[format][cls] + return self._dumpers_cache[format][key] except KeyError: pass - dumper_class = self._adapters.get_dumper(cls, format) - if dumper_class: - d = self._dumpers_cache[format][cls] = dumper_class(cls, self) - return d + dcls = self._adapters.get_dumper(cls, format) + if not dcls: + raise e.ProgrammingError( + f"cannot adapt type {cls.__name__}" + f" to format {Format(format).name}" + ) - raise e.ProgrammingError( - f"cannot adapt type {cls.__name__} to format {Format(format).name}" - ) + d = 
self._dumpers_cache[format][key] = dcls(cls, self) + if cls is list: + sub_dumper = self.get_dumper(subobj, format) + cast("BaseListDumper", d).set_sub_dumper(sub_dumper) + + return d def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]: res = self._pgresult @@ -196,3 +215,26 @@ class Transformer(AdaptContext): raise e.InterfaceError("unknown oid loader not found") loader = self._loaders_cache[format][oid] = loader_cls(oid, self) return loader + + def _find_list_element( + self, L: List[Any], seen: Optional[Set[int]] = None + ) -> Any: + """ + Find the first non-null element of an eventually nested list + """ + if not seen: + seen = set() + if id(L) in seen: + raise e.DataError("cannot dump a recursive list") + + seen.add(id(L)) + + for it in L: + if type(it) is list: + subit = self._find_list_element(it, seen) + if subit is not None: + return subit + elif it is not None: + return it + + return None diff --git a/psycopg3/psycopg3/types/array.py b/psycopg3/psycopg3/types/array.py index e7bda5d79..9c22bad54 100644 --- a/psycopg3/psycopg3/types/array.py +++ b/psycopg3/psycopg3/types/array.py @@ -20,7 +20,13 @@ class BaseListDumper(Dumper): def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) - self._tx = Transformer(context) + tx = Transformer(context) + self.set_sub_dumper(tx.get_dumper("", self.format)) + + def set_sub_dumper(self, dumper: Dumper) -> None: + self.sub_dumper = dumper + self.oid = self._get_array_oid(dumper.oid) + self.sub_oid = dumper.oid or TEXT_OID def _get_array_oid(self, base_oid: int) -> int: """ @@ -40,15 +46,15 @@ class BaseListDumper(Dumper): class ListDumper(BaseListDumper): + + format = Format.TEXT + # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO # # The array output routine will put double quotes around element values if # they are empty strings, contain curly braces, delimiter characters, # double quotes, backslashes, or white space, or match the word 
NULL. # TODO: recognise only , as delimiter. Should be configured - - format = Format.TEXT - _re_needs_quotes = re.compile( br"""(?xi) ^$ # the empty string @@ -63,11 +69,8 @@ class ListDumper(BaseListDumper): def dump(self, obj: List[Any]) -> bytes: tokens: List[bytes] = [] - oid = 0 def dump_list(obj: List[Any]) -> None: - nonlocal oid - if not obj: tokens.append(b"{}") return @@ -77,15 +80,12 @@ class ListDumper(BaseListDumper): if isinstance(item, list): dump_list(item) elif item is not None: - dumper = self._tx.get_dumper(item, Format.TEXT) - ad = dumper.dump(item) + ad = self.sub_dumper.dump(item) if self._re_needs_quotes.search(ad): ad = ( b'"' + self._re_esc.sub(br"\\\1", bytes(ad)) + b'"' ) tokens.append(ad) - if not oid: - oid = dumper.oid else: tokens.append(b"NULL") @@ -95,9 +95,6 @@ class ListDumper(BaseListDumper): dump_list(obj) - if oid: - self.oid = self._get_array_oid(oid) - return b"".join(tokens) @@ -112,7 +109,6 @@ class ListBinaryDumper(BaseListDumper): data: List[bytes] = [b"", b""] # placeholders to avoid a resize dims: List[int] = [] hasnull = 0 - oid = 0 def calc_dims(L: List[Any]) -> None: if isinstance(L, self.cls): @@ -124,19 +120,16 @@ class ListBinaryDumper(BaseListDumper): calc_dims(obj) def dump_list(L: List[Any], dim: int) -> None: - nonlocal oid, hasnull + nonlocal hasnull if len(L) != dims[dim]: raise e.DataError("nested lists have inconsistent lengths") if dim == len(dims) - 1: for item in L: if item is not None: - dumper = self._tx.get_dumper(item, Format.BINARY) - ad = dumper.dump(item) + ad = self.sub_dumper.dump(item) data.append(_struct_len.pack(len(ad))) data.append(ad) - if not oid: - oid = dumper.oid else: hasnull = 1 data.append(b"\xff\xff\xff\xff") @@ -150,12 +143,7 @@ class ListBinaryDumper(BaseListDumper): dump_list(obj, 0) - if not oid: - oid = TEXT_OID - - self.oid = self._get_array_oid(oid) - - data[0] = _struct_head.pack(len(dims), hasnull, oid) + data[0] = _struct_head.pack(len(dims), hasnull, self.sub_oid) 
data[1] = b"".join(_struct_dim.pack(dim, 1) for dim in dims) return b"".join(data) diff --git a/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx b/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx index c40c0e6f2..62b9b44d9 100644 --- a/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx +++ b/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx @@ -9,9 +9,11 @@ too many temporary Python objects and performing less memory copying. # Copyright (C) 2020 The Psycopg Team from cpython.ref cimport Py_INCREF +from cpython.set cimport PySet_Add, PySet_Contains from cpython.dict cimport PyDict_GetItem, PyDict_SetItem from cpython.list cimport ( - PyList_New, PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE) + PyList_New, PyList_CheckExact, + PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE) from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs @@ -169,22 +171,35 @@ cdef class Transformer: cdef PyObject *ptr cls = type(obj) + if cls is not list: + key = cls + else: + subobj = self._find_list_element(obj, set()) + if subobj is None: + subobj = "" + key = (cls, type(subobj)) + cache = self._binary_dumpers if format else self._text_dumpers - ptr = PyDict_GetItem(cache, cls) + ptr = PyDict_GetItem(cache, key) if ptr != NULL: return ptr - dumper_class = PyObject_CallFunctionObjArgs( + dcls = PyObject_CallFunctionObjArgs( self.adapters.get_dumper, cls, format, NULL) - if dumper_class is not None: - d = PyObject_CallFunctionObjArgs( - dumper_class, cls, self, NULL) - PyDict_SetItem(cache, cls, d) - return d + if dcls is None: + raise e.ProgrammingError( + f"cannot adapt type {cls.__name__}" + f" to format {Format(format).name}" + ) + + d = PyObject_CallFunctionObjArgs( + dcls, cls, self, NULL) + if cls is list: + sub_dumper = self.get_dumper(subobj, format) + d.set_sub_dumper(sub_dumper) - raise e.ProgrammingError( - f"cannot adapt type {cls.__name__} to format {Format(format).name}" - ) + PyDict_SetItem(cache, key, d) + return d 
cpdef dump_sequence(self, object params, object formats): # Verify that they are not none and that PyList_GET_ITEM won't blow up @@ -375,3 +390,26 @@ cdef class Transformer: loader_cls, oid, self, NULL) PyDict_SetItem(cache, oid, loader) return loader + + cdef object _find_list_element(self, object L, set seen): + """ + Find the first non-null element of an eventually nested list + """ + cdef object list_id = L + if PySet_Contains(seen, list_id): + raise e.DataError("cannot dump a recursive list") + + PySet_Add(seen, list_id) + + cdef int i + cdef PyObject *it + for i in range(PyList_GET_SIZE(L)): + it = PyList_GET_ITEM(L, i) + if PyList_CheckExact(it): + subit = self._find_list_element(it, seen) + if subit is not None: + return subit + elif it is not None: + return it + + return None diff --git a/tests/test_adapt.py b/tests/test_adapt.py index f549c3c5e..048da0a49 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -160,6 +160,29 @@ def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out): assert res == obj +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_array_dumper(fmt_out): + t = Transformer() + dint = t.get_dumper([0], fmt_out) + assert dint.oid == builtins["int8"].array_oid + assert dint.sub_oid == builtins["int8"].oid + + dstr = t.get_dumper([""], fmt_out) + assert dstr.oid == builtins["text"].array_oid + assert dstr.sub_oid == builtins["text"].oid + assert dstr is not dint + + assert t.get_dumper([1], fmt_out) is dint + assert t.get_dumper([], fmt_out) is dstr + assert t.get_dumper([None, [1]], fmt_out) is dint + assert t.get_dumper([None, [None]], fmt_out) is dstr + + L = [] + L.append(L) + with pytest.raises(psycopg3.DataError): + assert t.get_dumper(L, fmt_out) + + @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) def test_none_type_argument(conn, fmt_in): cur = conn.cursor() diff --git a/tests/types/test_array.py b/tests/types/test_array.py index 9de51c14a..91aa8149f 100644 --- 
a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -141,8 +141,7 @@ def test_array_mixed_numbers(array, type): assert dumper.oid == builtins[type].array_oid -@pytest.mark.xfail -@pytest.mark.parametrize("fmt_in", [Format.BINARY]) # TODO: add Format.TEXT +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) def test_empty_list_mix(conn, fmt_in): ph = "%s" if fmt_in == Format.TEXT else "%b" objs = list(range(3))