From: Daniele Varrazzo Date: Fri, 13 May 2022 22:37:48 +0000 (+0200) Subject: fix: Raise DataError dumping lists of mixed types X-Git-Tag: 3.1~101^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ff38106d4ea604939c39b7afc2751103585e2ed7;p=thirdparty%2Fpsycopg.git fix: Raise DataError dumping lists of mixed types This requires more checking dumping lists, but the alternative is to guard every single dumper, which is painful as well, and more brittle. --- diff --git a/docs/news.rst b/docs/news.rst index b8ca1e21a..4f8f8cc21 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -31,7 +31,7 @@ Psycopg 3.1 (unreleased) Psycopg 3.0.14 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- Fail dumping arrays of mixed numbers with DataError (:ticket:`#301`). +- Raise `DataError` dumping arrays of mixed types (:ticket:`#301`). Current release diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py index f2b28cdec..fe1e311bd 100644 --- a/psycopg/psycopg/types/array.py +++ b/psycopg/psycopg/types/array.py @@ -47,23 +47,30 @@ class BaseListDumper(RecursiveDumper): """ Find the first non-null element of an eventually nested list """ - it = self._flatiter(L, set()) - try: - item = next(it) - except StopIteration: + items = list(self._flatiter(L, set())) + types = set(map(type, items)) + if not types: return None + if len(types) > 1: + raise e.DataError( + "cannot dump lists of mixed types;" + f" got: {', '.join(sorted(t.__name__ for t in types))}" + ) + t = types.pop() # Checking for precise type. If the type is a subclass (e.g. Int4) # we assume the user knows what type they are passing. - if type(item) is not int: - return item + if t is not int: + return items[0] # If we got an int, let's see what is the biggest one in order to # choose the smallest OID and allow Postgres to do the right cast. - it = self._flatiter(L, set()) - imax = max((i if i >= 0 else -i - 1 for i in it), default=0) - imax = max(item if item >= 0 else -item - 1, imax) - return imax + imax: int = max(items) + imin: int = min(items) + if imin >= 0: + return imax + else: + return max(imax, -imin - 1) def _flatiter(self, L: List[Any], seen: Set[int]) -> Any: if id(L) in seen: @@ -118,7 +125,7 @@ class ListDumper(BaseListDumper): # Empty lists can only be dumped as text if the type is unknown. return self - sd = self._tx.get_dumper(item, format.from_pq(self.format)) + sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format)) dumper = type(self)(self.cls, self._tx) dumper.sub_dumper = sd diff --git a/tests/types/test_array.py b/tests/types/test_array.py index 5abf207e9..6ea197101 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -1,4 +1,3 @@ -from enum import Enum from decimal import Decimal import pytest @@ -185,17 +184,6 @@ def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out): def test_mix_types(conn): - class MyEnum(int, Enum): - ONE = 2**30 - - cur = conn.execute("select %s", ([1, MyEnum.ONE],)) - assert cur.fetchone() == ([1, 2**30],) - assert cur.description[0].type_code == cur.adapters.types["int4"].array_oid - - cur = conn.execute("select %s", ([1, psycopg.types.numeric.Int8(2**60)],)) - assert cur.fetchone() == ([1, 2**60],) - assert cur.description[0].type_code == cur.adapters.types["int8"].array_oid - with pytest.raises(psycopg.DataError): conn.execute("select %s", ([1, 0.5],))