]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: Raise DataError dumping lists of mixed types
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 13 May 2022 22:37:48 +0000 (00:37 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 May 2022 00:01:27 +0000 (02:01 +0200)
This requires more checking dumping lists, but the alternative is to
guard every single dumper, which is painful as well, and more brittle.

docs/news.rst
psycopg/psycopg/types/array.py
tests/types/test_array.py

index b8ca1e21a3e134667b77185c71c1052add6f2335..4f8f8cc2195839a829c573d80241c41e312fc36c 100644 (file)
@@ -31,7 +31,7 @@ Psycopg 3.1 (unreleased)
 Psycopg 3.0.14 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-- Fail dumping arrays of mixed numbers with DataError (:ticket:`#301`).
+- Raise `DataError` dumping arrays of mixed types (:ticket:`#301`).
 
 
 Current release
index f2b28cdec068aa5b2086c1a51395140a29035093..fe1e311bd0dc64bf129712e2b9a413456c64707f 100644 (file)
@@ -47,23 +47,30 @@ class BaseListDumper(RecursiveDumper):
         """
         Find the first non-null element of an eventually nested list
         """
-        it = self._flatiter(L, set())
-        try:
-            item = next(it)
-        except StopIteration:
+        items = list(self._flatiter(L, set()))
+        types = set(map(type, items))
+        if not types:
             return None
+        if len(types) > 1:
+            raise e.DataError(
+                "cannot dump lists of mixed types;"
+                f" got: {', '.join(sorted(t.__name__ for t in types))}"
+            )
+        t = types.pop()
 
         # Checking for precise type. If the type is a subclass (e.g. Int4)
         # we assume the user knows what type they are passing.
-        if type(item) is not int:
-            return item
+        if t is not int:
+            return items[0]
 
         # If we got an int, let's see what is the biggest one in order to
         # choose the smallest OID and allow Postgres to do the right cast.
-        it = self._flatiter(L, set())
-        imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
-        imax = max(item if item >= 0 else -item - 1, imax)
-        return imax
+        imax: int = max(items)
+        imin: int = min(items)
+        if imin >= 0:
+            return imax
+        else:
+            return max(imax, -imin - 1)
 
     def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
         if id(L) in seen:
@@ -118,7 +125,7 @@ class ListDumper(BaseListDumper):
             # Empty lists can only be dumped as text if the type is unknown.
             return self
 
-        sd = self._tx.get_dumper(item, format.from_pq(self.format))
+        sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
         dumper = type(self)(self.cls, self._tx)
         dumper.sub_dumper = sd
 
index 5abf207e90132a08931688ae44736163bb68f2f6..6ea1971012ffc6de54906c56cfe415b5a19a0270 100644 (file)
@@ -1,4 +1,3 @@
-from enum import Enum
 from decimal import Decimal
 
 import pytest
@@ -185,17 +184,6 @@ def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
 
 
 def test_mix_types(conn):
-    class MyEnum(int, Enum):
-        ONE = 2**30
-
-    cur = conn.execute("select %s", ([1, MyEnum.ONE],))
-    assert cur.fetchone() == ([1, 2**30],)
-    assert cur.description[0].type_code == cur.adapters.types["int4"].array_oid
-
-    cur = conn.execute("select %s", ([1, psycopg.types.numeric.Int8(2**60)],))
-    assert cur.fetchone() == ([1, 2**60],)
-    assert cur.description[0].type_code == cur.adapters.types["int8"].array_oid
-
     with pytest.raises(psycopg.DataError):
         conn.execute("select %s", ([1, 0.5],))