]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: allow dumping mixed types list, as long as they use the same dumper
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 13 May 2022 23:29:53 +0000 (01:29 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 May 2022 00:01:27 +0000 (02:01 +0200)
Necessary to dump a mix of ipv4/ipv6 addresses, which result in the same
Postgres type, so it's a legit array.

psycopg/psycopg/types/array.py

index fe1e311bd0dc64bf129712e2b9a413456c64707f..d8d7f8756dc869c2108b313a9f58564069edfa9a 100644 (file)
@@ -43,25 +43,34 @@ class BaseListDumper(RecursiveDumper):
             sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
             self.sub_dumper = sdclass(NoneType, context)
 
-    def _find_list_element(self, L: List[Any]) -> Any:
+    def _find_list_element(self, L: List[Any], format: PyFormat) -> Any:
         """
         Find the first non-null element of an eventually nested list
         """
         items = list(self._flatiter(L, set()))
-        types = set(map(type, items))
+        types = {type(item): item for item in items}
         if not types:
             return None
-        if len(types) > 1:
-            raise e.DataError(
-                "cannot dump lists of mixed types;"
-                f" got: {', '.join(sorted(t.__name__ for t in types))}"
-            )
-        t = types.pop()
+
+        if len(types) == 1:
+            t, v = types.popitem()
+        else:
+            # More than one type in the list. It might be still good, as long
+            # as they dump with the same oid (e.g. IPv4Network, IPv6Network).
+            dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
+            oids = set(d.oid for d in dumpers)
+            if len(oids) == 1:
+                t, v = types.popitem()
+            else:
+                raise e.DataError(
+                    "cannot dump lists of mixed types;"
+                    f" got: {', '.join(sorted(t.__name__ for t in types))}"
+                )
 
         # Checking for precise type. If the type is a subclass (e.g. Int4)
         # we assume the user knows what type they are passing.
         if t is not int:
-            return items[0]
+            return v
 
         # If we got an int, let's see what is the biggest one in order to
         # choose the smallest OID and allow Postgres to do the right cast.
@@ -108,7 +117,7 @@ class ListDumper(BaseListDumper):
         if self.oid:
             return self.cls
 
-        item = self._find_list_element(obj)
+        item = self._find_list_element(obj, format)
         if item is None:
             return self.cls
 
@@ -120,7 +129,7 @@ class ListDumper(BaseListDumper):
         if self.oid:
             return self
 
-        item = self._find_list_element(obj)
+        item = self._find_list_element(obj, format)
         if item is None:
             # Empty lists can only be dumped as text if the type is unknown.
             return self
@@ -210,7 +219,7 @@ class ListBinaryDumper(BaseListDumper):
         if self.oid:
             return self.cls
 
-        item = self._find_list_element(obj)
+        item = self._find_list_element(obj, format)
         if item is None:
             return (self.cls,)
 
@@ -222,7 +231,7 @@ class ListBinaryDumper(BaseListDumper):
         if self.oid:
             return self
 
-        item = self._find_list_element(obj)
+        item = self._find_list_element(obj, format)
         if item is None:
             return ListDumper(self.cls, self._tx)