# Copyright (C) 2020 The Psycopg Team
-from typing import Any, Dict, List, Optional, Sequence, Tuple
-from typing import TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import cast, TYPE_CHECKING
from . import errors as e
from .pq import Format
from .pq.proto import PGresult
from .adapt import Dumper, Loader, AdaptersMap
from .connection import BaseConnection
+ from .types.array import BaseListDumper
+
+# Key of a dumpers cache: the class of the object to dump or, for lists,
+# the pair (list, class of the first non-None element found).
+DumperKey = Union[type, Tuple[type, type]]
+DumperCache = Dict[DumperKey, "Dumper"]
+
+# Key of a loaders cache: the oid of the value to load.
+LoaderKey = int
+LoaderCache = Dict[LoaderKey, "Loader"]
class Transformer(AdaptContext):
self._connection = None
# mapping class, fmt -> Dumper instance
- self._dumpers_cache: Tuple[Dict[type, "Dumper"], Dict[type, "Dumper"]]
- self._dumpers_cache = ({}, {})
+ self._dumpers_cache: Tuple[DumperCache, DumperCache] = ({}, {})
# mapping oid, fmt -> Loader instance
- self._loaders_cache: Tuple[Dict[int, "Loader"], Dict[int, "Loader"]]
- self._loaders_cache = ({}, {})
+ self._loaders_cache: Tuple[LoaderCache, LoaderCache] = ({}, {})
# sequence of load functions from value to python
# the length of the result columns
def get_dumper(self, obj: Any, format: Format) -> "Dumper":
# Fast path: return a Dumper class already instantiated from the same type
cls = type(obj)
+ if cls is not list:
+ key: DumperKey = cls
+ else:
+ # TODO: Can be probably generalised to handle other recursive types
+ subobj = self._find_list_element(obj)
+ if subobj is None:
+ subobj = ""
+ key = (cls, type(subobj))
+
try:
- return self._dumpers_cache[format][cls]
+ return self._dumpers_cache[format][key]
except KeyError:
pass
- dumper_class = self._adapters.get_dumper(cls, format)
- if dumper_class:
- d = self._dumpers_cache[format][cls] = dumper_class(cls, self)
- return d
+ dcls = self._adapters.get_dumper(cls, format)
+ if not dcls:
+ raise e.ProgrammingError(
+ f"cannot adapt type {cls.__name__}"
+ f" to format {Format(format).name}"
+ )
- raise e.ProgrammingError(
- f"cannot adapt type {cls.__name__} to format {Format(format).name}"
- )
+ d = self._dumpers_cache[format][key] = dcls(cls, self)
+ if cls is list:
+ sub_dumper = self.get_dumper(subobj, format)
+ cast("BaseListDumper", d).set_sub_dumper(sub_dumper)
+
+ return d
def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]:
res = self._pgresult
raise e.InterfaceError("unknown oid loader not found")
loader = self._loaders_cache[format][oid] = loader_cls(oid, self)
return loader
+
+ def _find_list_element(
+ self, L: List[Any], seen: Optional[Set[int]] = None
+ ) -> Any:
+ """
+ Find the first non-null element of an eventually nested list
+ """
+ if not seen:
+ seen = set()
+ if id(L) in seen:
+ raise e.DataError("cannot dump a recursive list")
+
+ seen.add(id(L))
+
+ for it in L:
+ if type(it) is list:
+ subit = self._find_list_element(it, seen)
+ if subit is not None:
+ return subit
+ elif it is not None:
+ return it
+
+ return None
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
    super().__init__(cls, context)
-   self._tx = Transformer(context)
+   # Start with the dumper for str as element dumper; it can be replaced
+   # later by calling set_sub_dumper().
+   tx = Transformer(context)
+   self.set_sub_dumper(tx.get_dumper("", self.format))
+
+    def set_sub_dumper(self, dumper: Dumper) -> None:
+        """
+        Set the dumper used to dump the elements of the list.
+
+        Also update `oid`, looked up from the element dumper's oid, and
+        `sub_oid`, the element oid itself.
+        """
+        self.sub_dumper = dumper
+        # NOTE(review): the array oid is looked up from the raw element oid,
+        # without the TEXT_OID fallback used for sub_oid below - confirm
+        # that _get_array_oid() copes with an element oid of 0.
+        self.oid = self._get_array_oid(dumper.oid)
+        # An element dumper with a falsy oid is declared as text
+        self.sub_oid = dumper.oid or TEXT_OID
def _get_array_oid(self, base_oid: int) -> int:
"""
class ListDumper(BaseListDumper):
+
+ format = Format.TEXT
+
# from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
#
# The array output routine will put double quotes around element values if
# they are empty strings, contain curly braces, delimiter characters,
# double quotes, backslashes, or white space, or match the word NULL.
# TODO: recognise only , as delimiter. Should be configured
-
- format = Format.TEXT
-
_re_needs_quotes = re.compile(
br"""(?xi)
^$ # the empty string
def dump(self, obj: List[Any]) -> bytes:
tokens: List[bytes] = []
- oid = 0
def dump_list(obj: List[Any]) -> None:
- nonlocal oid
-
if not obj:
tokens.append(b"{}")
return
if isinstance(item, list):
dump_list(item)
elif item is not None:
- dumper = self._tx.get_dumper(item, Format.TEXT)
- ad = dumper.dump(item)
+ ad = self.sub_dumper.dump(item)
if self._re_needs_quotes.search(ad):
ad = (
b'"' + self._re_esc.sub(br"\\\1", bytes(ad)) + b'"'
)
tokens.append(ad)
- if not oid:
- oid = dumper.oid
else:
tokens.append(b"NULL")
dump_list(obj)
- if oid:
- self.oid = self._get_array_oid(oid)
-
return b"".join(tokens)
data: List[bytes] = [b"", b""] # placeholders to avoid a resize
dims: List[int] = []
hasnull = 0
- oid = 0
def calc_dims(L: List[Any]) -> None:
if isinstance(L, self.cls):
calc_dims(obj)
def dump_list(L: List[Any], dim: int) -> None:
- nonlocal oid, hasnull
+ nonlocal hasnull
if len(L) != dims[dim]:
raise e.DataError("nested lists have inconsistent lengths")
if dim == len(dims) - 1:
for item in L:
if item is not None:
- dumper = self._tx.get_dumper(item, Format.BINARY)
- ad = dumper.dump(item)
+ ad = self.sub_dumper.dump(item)
data.append(_struct_len.pack(len(ad)))
data.append(ad)
- if not oid:
- oid = dumper.oid
else:
hasnull = 1
data.append(b"\xff\xff\xff\xff")
dump_list(obj, 0)
- if not oid:
- oid = TEXT_OID
-
- self.oid = self._get_array_oid(oid)
-
- data[0] = _struct_head.pack(len(dims), hasnull, oid)
+ data[0] = _struct_head.pack(len(dims), hasnull, self.sub_oid)
data[1] = b"".join(_struct_dim.pack(dim, 1) for dim in dims)
return b"".join(data)
# Copyright (C) 2020 The Psycopg Team
from cpython.ref cimport Py_INCREF
+from cpython.set cimport PySet_Add, PySet_Contains
from cpython.dict cimport PyDict_GetItem, PyDict_SetItem
from cpython.list cimport (
- PyList_New, PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE)
+ PyList_New, PyList_CheckExact,
+ PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE)
from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
cdef PyObject *ptr
cls = type(obj)
+ if cls is not list:
+ key = cls
+ else:
+ subobj = self._find_list_element(obj, set())
+ if subobj is None:
+ subobj = ""
+ key = (cls, type(subobj))
+
cache = self._binary_dumpers if format else self._text_dumpers
- ptr = PyDict_GetItem(cache, cls)
+ ptr = PyDict_GetItem(cache, key)
if ptr != NULL:
return <object>ptr
- dumper_class = PyObject_CallFunctionObjArgs(
+ dcls = PyObject_CallFunctionObjArgs(
self.adapters.get_dumper, <PyObject *>cls, <PyObject *>format, NULL)
- if dumper_class is not None:
- d = PyObject_CallFunctionObjArgs(
- dumper_class, <PyObject *>cls, <PyObject *>self, NULL)
- PyDict_SetItem(cache, cls, d)
- return d
+ if dcls is None:
+ raise e.ProgrammingError(
+ f"cannot adapt type {cls.__name__}"
+ f" to format {Format(format).name}"
+ )
+
+ d = PyObject_CallFunctionObjArgs(
+ dcls, <PyObject *>cls, <PyObject *>self, NULL)
+ if cls is list:
+ sub_dumper = self.get_dumper(subobj, format)
+ d.set_sub_dumper(sub_dumper)
- raise e.ProgrammingError(
- f"cannot adapt type {cls.__name__} to format {Format(format).name}"
- )
+ PyDict_SetItem(cache, key, d)
+ return d
cpdef dump_sequence(self, object params, object formats):
# Verify that they are not none and that PyList_GET_ITEM won't blow up
loader_cls, oid, <PyObject *>self, NULL)
PyDict_SetItem(<object>cache, <object>oid, loader)
return loader
+
+    cdef object _find_list_element(self, object L, set seen):
+        """
+        Find the first non-null element of a possibly nested list.
+
+        Return None if `L` and all the lists nested in it are empty or
+        contain only None. Raise DataError if a list contains itself.
+
+        `seen` holds the ids of the lists being visited, from the outermost
+        one down to `L`; the top-level caller passes an empty set.
+        """
+        cdef Py_ssize_t i
+        cdef PyObject *it
+        # id() instead of a pointer-to-long cast: a C long can be narrower
+        # than a pointer (e.g. on 64-bit Windows), and a truncated address
+        # could make two distinct lists look like the same one.
+        cdef object list_id = id(L)
+
+        if PySet_Contains(seen, list_id):
+            raise e.DataError("cannot dump a recursive list")
+
+        PySet_Add(seen, list_id)
+
+        for i in range(PyList_GET_SIZE(L)):
+            it = PyList_GET_ITEM(L, i)
+            if PyList_CheckExact(<object>it):
+                subit = self._find_list_element(<object>it, seen)
+                if subit is not None:
+                    return subit
+            elif <object>it is not None:
+                return <object>it
+
+        # Nothing found in this list: forget it, so that meeting the same
+        # list object again as a sibling (e.g. in `[[None] * 2] * 3`) is
+        # not mistaken for a cycle. Only the lists still being visited
+        # must stay in `seen`.
+        seen.discard(list_id)
+        return None
assert res == obj
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_array_dumper(fmt_out):
+    # Check which list dumper the Transformer returns for lists with
+    # different element types, and that the instances are reused.
+    t = Transformer()
+
+    # A list of ints gets a dumper for int8 arrays
+    dint = t.get_dumper([0], fmt_out)
+    assert dint.oid == builtins["int8"].array_oid
+    assert dint.sub_oid == builtins["int8"].oid
+
+    # A list of strings gets a distinct dumper, for text arrays
+    dstr = t.get_dumper([""], fmt_out)
+    assert dstr.oid == builtins["text"].array_oid
+    assert dstr.sub_oid == builtins["text"].oid
+    assert dstr is not dint
+
+    # The dumper is chosen on the first non-None element, even if nested;
+    # lists with no such element share the dumper of the lists of strings
+    assert t.get_dumper([1], fmt_out) is dint
+    assert t.get_dumper([], fmt_out) is dstr
+    assert t.get_dumper([None, [1]], fmt_out) is dint
+    assert t.get_dumper([None, [None]], fmt_out) is dstr
+
+    # A list containing itself is rejected
+    L = []
+    L.append(L)
+    with pytest.raises(psycopg3.DataError):
+        assert t.get_dumper(L, fmt_out)
+
+
@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
def test_none_type_argument(conn, fmt_in):
cur = conn.cursor()
assert dumper.oid == builtins[type].array_oid
-@pytest.mark.xfail
-@pytest.mark.parametrize("fmt_in", [Format.BINARY]) # TODO: add Format.TEXT
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
def test_empty_list_mix(conn, fmt_in):
ph = "%s" if fmt_in == Format.TEXT else "%b"
objs = list(range(3))