git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Set sub-dumper on array dumper on creation in Transformer
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 11 Jan 2021 22:09:19 +0000 (23:09 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jan 2021 15:16:34 +0000 (16:16 +0100)
Fixes the binary dump of an empty array, and other unforeseen problems such as
the same cached array dumper being used for lists whose elements have different types.

psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/types/array.py
psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
tests/test_adapt.py
tests/types/test_array.py
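
For context, a minimal sketch of the behaviour this change enables, mirroring the
test_array_dumper test added below (the import paths are assumptions; the test
module's imports are not part of this diff):

    from psycopg3.adapt import Transformer   # assumed import location
    from psycopg3.pq import Format           # as imported in _transform.py

    t = Transformer()
    dint = t.get_dumper([0], Format.BINARY)    # cached under (list, int): int8[] dumper
    dstr = t.get_dumper([""], Format.BINARY)   # cached under (list, str): text[] dumper
    assert dint is not dstr

    assert t.get_dumper([1], Format.BINARY) is dint          # same element type, same dumper
    assert t.get_dumper([], Format.BINARY) is dstr           # empty list falls back to text
    assert t.get_dumper([None, [1]], Format.BINARY) is dint  # first non-null element decides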

diff --git a/psycopg3/psycopg3/_transform.py b/psycopg3/psycopg3/_transform.py
index cbc6ad00e340350c1c48e3da1f2341d8912957c5..7bd36caddf243393f2364b90ae9a570d182655c2 100644
@@ -4,8 +4,8 @@ Helper object to transform values between Python and PostgreSQL
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Dict, List, Optional, Sequence, Tuple
-from typing import TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import cast, TYPE_CHECKING
 
 from . import errors as e
 from .pq import Format
@@ -16,6 +16,13 @@ if TYPE_CHECKING:
     from .pq.proto import PGresult
     from .adapt import Dumper, Loader, AdaptersMap
     from .connection import BaseConnection
+    from .types.array import BaseListDumper
+
+DumperKey = Union[type, Tuple[type, type]]
+DumperCache = Dict[DumperKey, "Dumper"]
+
+LoaderKey = int
+LoaderCache = Dict[LoaderKey, "Loader"]
 
 
 class Transformer(AdaptContext):
@@ -49,12 +56,10 @@ class Transformer(AdaptContext):
             self._connection = None
 
         # mapping class, fmt -> Dumper instance
-        self._dumpers_cache: Tuple[Dict[type, "Dumper"], Dict[type, "Dumper"]]
-        self._dumpers_cache = ({}, {})
+        self._dumpers_cache: Tuple[DumperCache, DumperCache] = ({}, {})
 
         # mapping oid, fmt -> Loader instance
-        self._loaders_cache: Tuple[Dict[int, "Loader"], Dict[int, "Loader"]]
-        self._loaders_cache = ({}, {})
+        self._loaders_cache: Tuple[LoaderCache, LoaderCache] = ({}, {})
 
         # sequence of load functions from value to python
         # the length of the result columns
@@ -117,19 +122,33 @@ class Transformer(AdaptContext):
     def get_dumper(self, obj: Any, format: Format) -> "Dumper":
         # Fast path: return a Dumper class already instantiated from the same type
         cls = type(obj)
+        if cls is not list:
+            key: DumperKey = cls
+        else:
+            # TODO: Can probably be generalised to handle other recursive types
+            subobj = self._find_list_element(obj)
+            if subobj is None:
+                subobj = ""
+            key = (cls, type(subobj))
+
         try:
-            return self._dumpers_cache[format][cls]
+            return self._dumpers_cache[format][key]
         except KeyError:
             pass
 
-        dumper_class = self._adapters.get_dumper(cls, format)
-        if dumper_class:
-            d = self._dumpers_cache[format][cls] = dumper_class(cls, self)
-            return d
+        dcls = self._adapters.get_dumper(cls, format)
+        if not dcls:
+            raise e.ProgrammingError(
+                f"cannot adapt type {cls.__name__}"
+                f" to format {Format(format).name}"
+            )
 
-        raise e.ProgrammingError(
-            f"cannot adapt type {cls.__name__} to format {Format(format).name}"
-        )
+        d = self._dumpers_cache[format][key] = dcls(cls, self)
+        if cls is list:
+            sub_dumper = self.get_dumper(subobj, format)
+            cast("BaseListDumper", d).set_sub_dumper(sub_dumper)
+
+        return d
 
     def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]:
         res = self._pgresult
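
A rough, standalone illustration of the cache key computed in the hunk above (not
the actual psycopg3 code; find_first_non_null stands in for _find_list_element):

    def find_first_non_null(L):
        # simplified stand-in for Transformer._find_list_element (no recursion guard)
        for it in L:
            if type(it) is list:
                sub = find_first_non_null(it)
                if sub is not None:
                    return sub
            elif it is not None:
                return it
        return None

    def cache_key(obj):
        if type(obj) is not list:
            return type(obj)                  # e.g. 42 -> int
        first = find_first_non_null(obj)
        if first is None:
            first = ""                        # empty or all-NULL lists share the str key
        return (list, type(first))

    assert cache_key(42) is int
    assert cache_key([1, [2, 3]]) == (list, int)
    assert cache_key([None, [None]]) == (list, str)
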
@@ -196,3 +215,26 @@ class Transformer(AdaptContext):
                 raise e.InterfaceError("unknown oid loader not found")
         loader = self._loaders_cache[format][oid] = loader_cls(oid, self)
         return loader
+
+    def _find_list_element(
+        self, L: List[Any], seen: Optional[Set[int]] = None
+    ) -> Any:
+        """
+        Find the first non-null element of a possibly nested list
+        """
+        if not seen:
+            seen = set()
+        if id(L) in seen:
+            raise e.DataError("cannot dump a recursive list")
+
+        seen.add(id(L))
+
+        for it in L:
+            if type(it) is list:
+                subit = self._find_list_element(it, seen)
+                if subit is not None:
+                    return subit
+            elif it is not None:
+                return it
+
+        return None
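
The seen set of list ids is what turns a self-referencing list into a DataError
rather than a RecursionError; a tiny sketch of the input it guards against (the
same case is exercised by the test added to test_adapt.py below):

    L = []
    L.append(L)   # a list that contains itself
    # Passing L to get_dumper() reaches _find_list_element(), which sees id(L)
    # a second time and raises DataError("cannot dump a recursive list").
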
diff --git a/psycopg3/psycopg3/types/array.py b/psycopg3/psycopg3/types/array.py
index e7bda5d797d77a41c5b3da436281a3ca3b6c7358..9c22bad5408536bd08cbcbccbf439d934df53a6a 100644
@@ -20,7 +20,13 @@ class BaseListDumper(Dumper):
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
-        self._tx = Transformer(context)
+        tx = Transformer(context)
+        self.set_sub_dumper(tx.get_dumper("", self.format))
+
+    def set_sub_dumper(self, dumper: Dumper) -> None:
+        self.sub_dumper = dumper
+        self.oid = self._get_array_oid(dumper.oid)
+        self.sub_oid = dumper.oid or TEXT_OID
 
     def _get_array_oid(self, base_oid: int) -> int:
         """
@@ -40,15 +46,15 @@ class BaseListDumper(Dumper):
 
 
 class ListDumper(BaseListDumper):
+
+    format = Format.TEXT
+
     # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
     #
     # The array output routine will put double quotes around element values if
     # they are empty strings, contain curly braces, delimiter characters,
     # double quotes, backslashes, or white space, or match the word NULL.
     # TODO: recognise only , as delimiter. Should be configured
-
-    format = Format.TEXT
-
     _re_needs_quotes = re.compile(
         br"""(?xi)
           ^$              # the empty string
@@ -63,11 +69,8 @@ class ListDumper(BaseListDumper):
 
     def dump(self, obj: List[Any]) -> bytes:
         tokens: List[bytes] = []
-        oid = 0
 
         def dump_list(obj: List[Any]) -> None:
-            nonlocal oid
-
             if not obj:
                 tokens.append(b"{}")
                 return
@@ -77,15 +80,12 @@ class ListDumper(BaseListDumper):
                 if isinstance(item, list):
                     dump_list(item)
                 elif item is not None:
-                    dumper = self._tx.get_dumper(item, Format.TEXT)
-                    ad = dumper.dump(item)
+                    ad = self.sub_dumper.dump(item)
                     if self._re_needs_quotes.search(ad):
                         ad = (
                             b'"' + self._re_esc.sub(br"\\\1", bytes(ad)) + b'"'
                         )
                     tokens.append(ad)
-                    if not oid:
-                        oid = dumper.oid
                 else:
                     tokens.append(b"NULL")
 
@@ -95,9 +95,6 @@ class ListDumper(BaseListDumper):
 
         dump_list(obj)
 
-        if oid:
-            self.oid = self._get_array_oid(oid)
-
         return b"".join(tokens)
 
 
@@ -112,7 +109,6 @@ class ListBinaryDumper(BaseListDumper):
         data: List[bytes] = [b"", b""]  # placeholders to avoid a resize
         dims: List[int] = []
         hasnull = 0
-        oid = 0
 
         def calc_dims(L: List[Any]) -> None:
             if isinstance(L, self.cls):
@@ -124,19 +120,16 @@ class ListBinaryDumper(BaseListDumper):
         calc_dims(obj)
 
         def dump_list(L: List[Any], dim: int) -> None:
-            nonlocal oid, hasnull
+            nonlocal hasnull
             if len(L) != dims[dim]:
                 raise e.DataError("nested lists have inconsistent lengths")
 
             if dim == len(dims) - 1:
                 for item in L:
                     if item is not None:
-                        dumper = self._tx.get_dumper(item, Format.BINARY)
-                        ad = dumper.dump(item)
+                        ad = self.sub_dumper.dump(item)
                         data.append(_struct_len.pack(len(ad)))
                         data.append(ad)
-                        if not oid:
-                            oid = dumper.oid
                     else:
                         hasnull = 1
                         data.append(b"\xff\xff\xff\xff")
@@ -150,12 +143,7 @@ class ListBinaryDumper(BaseListDumper):
 
         dump_list(obj, 0)
 
-        if not oid:
-            oid = TEXT_OID
-
-        self.oid = self._get_array_oid(oid)
-
-        data[0] = _struct_head.pack(len(dims), hasnull, oid)
+        data[0] = _struct_head.pack(len(dims), hasnull, self.sub_oid)
         data[1] = b"".join(_struct_dim.pack(dim, 1) for dim in dims)
         return b"".join(data)
 
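For reference, a minimal sketch of the array header that data[0] and data[1] encode,
assuming _struct_head and _struct_dim are the network-order "!iii" and "!ii" packers
their usage here suggests (25 is the text OID):

    import struct

    def array_header(dims, hasnull, sub_oid):
        # ndims, hasnull flag, element OID, then (length, lower bound) per dimension
        head = struct.pack("!iii", len(dims), hasnull, sub_oid)
        return head + b"".join(struct.pack("!ii", d, 1) for d in dims)

    array_header([2, 3], 0, 25)   # header of a 2x3 text[] with no NULLs
    array_header([], 0, 25)       # empty array: sub_oid provides the fallback text OID
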
diff --git a/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx b/psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
index c40c0e6f256123dc3d836db1eaf455b2355441a7..62b9b44d948a961df44d342e3ea7e73f98948e2e 100644
@@ -9,9 +9,11 @@ too many temporary Python objects and performing less memory copying.
 # Copyright (C) 2020 The Psycopg Team
 
 from cpython.ref cimport Py_INCREF
+from cpython.set cimport PySet_Add, PySet_Contains
 from cpython.dict cimport PyDict_GetItem, PyDict_SetItem
 from cpython.list cimport (
-    PyList_New, PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE)
+    PyList_New, PyList_CheckExact,
+    PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE)
 from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
 from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
 
@@ -169,22 +171,35 @@ cdef class Transformer:
         cdef PyObject *ptr
 
         cls = type(obj)
+        if cls is not list:
+            key = cls
+        else:
+            subobj = self._find_list_element(obj, set())
+            if subobj is None:
+                subobj = ""
+            key = (cls, type(subobj))
+
         cache = self._binary_dumpers if format else self._text_dumpers
-        ptr = PyDict_GetItem(cache, cls)
+        ptr = PyDict_GetItem(cache, key)
         if ptr != NULL:
             return <object>ptr
 
-        dumper_class = PyObject_CallFunctionObjArgs(
+        dcls = PyObject_CallFunctionObjArgs(
             self.adapters.get_dumper, <PyObject *>cls, <PyObject *>format, NULL)
-        if dumper_class is not None:
-            d = PyObject_CallFunctionObjArgs(
-                dumper_class, <PyObject *>cls, <PyObject *>self, NULL)
-            PyDict_SetItem(cache, cls, d)
-            return d
+        if dcls is None:
+            raise e.ProgrammingError(
+                f"cannot adapt type {cls.__name__}"
+                f" to format {Format(format).name}"
+            )
+
+        d = PyObject_CallFunctionObjArgs(
+            dcls, <PyObject *>cls, <PyObject *>self, NULL)
+        if cls is list:
+            sub_dumper = self.get_dumper(subobj, format)
+            d.set_sub_dumper(sub_dumper)
 
-        raise e.ProgrammingError(
-            f"cannot adapt type {cls.__name__} to format {Format(format).name}"
-        )
+        PyDict_SetItem(cache, key, d)
+        return d
 
     cpdef dump_sequence(self, object params, object formats):
         # Verify that they are not none and that PyList_GET_ITEM won't blow up
@@ -375,3 +390,26 @@ cdef class Transformer:
             loader_cls, oid, <PyObject *>self, NULL)
         PyDict_SetItem(<object>cache, <object>oid, loader)
         return loader
+
+    cdef object _find_list_element(self, object L, set seen):
+        """
+        Find the first non-null element of a possibly nested list
+        """
+        cdef object list_id = <long><PyObject *>L
+        if PySet_Contains(seen, list_id):
+            raise e.DataError("cannot dump a recursive list")
+
+        PySet_Add(seen, list_id)
+
+        cdef int i
+        cdef PyObject *it
+        for i in range(PyList_GET_SIZE(L)):
+            it = PyList_GET_ITEM(L, i)
+            if PyList_CheckExact(<object>it):
+                subit = self._find_list_element(<object>it, seen)
+                if subit is not None:
+                    return subit
+            elif <object>it is not None:
+                return <object>it
+
+        return None
diff --git a/tests/test_adapt.py b/tests/test_adapt.py
index f549c3c5e6c7d90f5c0212f8db2341d239b7af6e..048da0a49f948439c41ed315a75782dfae15809b 100644
@@ -160,6 +160,29 @@ def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out):
     assert res == obj
 
 
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_array_dumper(fmt_out):
+    t = Transformer()
+    dint = t.get_dumper([0], fmt_out)
+    assert dint.oid == builtins["int8"].array_oid
+    assert dint.sub_oid == builtins["int8"].oid
+
+    dstr = t.get_dumper([""], fmt_out)
+    assert dstr.oid == builtins["text"].array_oid
+    assert dstr.sub_oid == builtins["text"].oid
+    assert dstr is not dint
+
+    assert t.get_dumper([1], fmt_out) is dint
+    assert t.get_dumper([], fmt_out) is dstr
+    assert t.get_dumper([None, [1]], fmt_out) is dint
+    assert t.get_dumper([None, [None]], fmt_out) is dstr
+
+    L = []
+    L.append(L)
+    with pytest.raises(psycopg3.DataError):
+        assert t.get_dumper(L, fmt_out)
+
+
 @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 def test_none_type_argument(conn, fmt_in):
     cur = conn.cursor()
diff --git a/tests/types/test_array.py b/tests/types/test_array.py
index 9de51c14a213a2d6ff7c10414c17b60aeb5ee753..91aa8149f0a54f73e4c405eb1c31c8962a5eb74f 100644
@@ -141,8 +141,7 @@ def test_array_mixed_numbers(array, type):
     assert dumper.oid == builtins[type].array_oid
 
 
-@pytest.mark.xfail
-@pytest.mark.parametrize("fmt_in", [Format.BINARY])  # TODO: add Format.TEXT
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 def test_empty_list_mix(conn, fmt_in):
     ph = "%s" if fmt_in == Format.TEXT else "%b"
     objs = list(range(3))