]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Allow some form of dumping lists of mixed types
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Jul 2021 18:56:59 +0000 (20:56 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Jul 2021 18:59:37 +0000 (20:59 +0200)
Lists of numbers are now dumped as numeric[].

Default to dump text for arrays.

psycopg/psycopg/types/array.py
tests/test_adapt.py
tests/types/test_array.py
tests/types/test_range.py

index 1b82e76bac2ae1b5c46d83aee3b4d4ab458cf10f..31900355961b0ad9dc890de6bea490caca8f0e9b 100644 (file)
@@ -6,6 +6,7 @@ Adapters for arrays
 
 import re
 import struct
+from decimal import Decimal
 from typing import Any, Callable, Iterator, List, Optional, Set, Tuple, Type
 from typing import cast
 
@@ -37,51 +38,16 @@ class BaseListDumper(RecursiveDumper):
         super().__init__(cls, context)
         self.sub_dumper: Optional[Dumper] = None
 
-    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
-        item = self._find_list_element(obj)
-        if item is not None:
-            sd = self._tx.get_dumper(item, format)
-            return (self.cls, sd.get_key(item, format))  # type: ignore
-        else:
-            return (self.cls,)
-
-    def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
-        item = self._find_list_element(obj)
-        if item is None:
-            # Empty lists can only be dumped as text if the type is unknown.
-            return ListDumper(self.cls, self._tx)
-
-        sd = self._tx.get_dumper(item, format)
-        dcls = ListDumper if sd.format == pq.Format.TEXT else ListBinaryDumper
-        dumper = dcls(self.cls, self._tx)
-        dumper.sub_dumper = sd
-
-        # We consider an array of unknowns as unknown, so we can dump empty
-        # lists or lists containing only None elements.
-        if sd.oid != INVALID_OID:
-            dumper.oid = self._get_array_oid(sd.oid)
-        else:
-            dumper.oid = INVALID_OID
-
-        return dumper
-
     def _find_list_element(self, L: List[Any]) -> Any:
         """
         Find the first non-null element of an eventually nested list
         """
         it = self._flatiter(L, set())
         try:
-            item = next(it)
+            return next(it)
         except StopIteration:
             return None
 
-        if not isinstance(item, int):
-            return item
-
-        imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
-        imax = max(item if item >= 0 else -item, imax)
-        return imax
-
     def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
         if id(L) in seen:
             raise e.DataError("cannot dump a recursive list")
@@ -116,6 +82,49 @@ class ListDumper(BaseListDumper):
 
     format = pq.Format.TEXT
 
+    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+        if self.oid:
+            return self.cls
+
+        item = self._find_list_element(obj)
+        if item is None:
+            return self.cls
+
+        # If we got a number, let's dump them as numeric text array.
+        # Don't check for subclasses because if someone has used Int2 etc
+        # they probably know better what they want.
+        if type(item) in MixedNumbersListDumper.NUMBERS_TYPES:
+            return MixedNumbersListDumper
+
+        sd = self._tx.get_dumper(item, format)
+        return (self.cls, sd.get_key(item, format))  # type: ignore
+
+    def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+        # If we have an oid we don't need to upgrade
+        if self.oid:
+            return self
+
+        item = self._find_list_element(obj)
+        if item is None:
+            # Empty lists can only be dumped as text if the type is unknown.
+            return self
+
+        if type(item) in MixedNumbersListDumper.NUMBERS_TYPES:
+            return MixedNumbersListDumper(self.cls, self._tx)
+
+        sd = self._tx.get_dumper(item, format.from_pq(self.format))
+        dumper = type(self)(self.cls, self._tx)
+        dumper.sub_dumper = sd
+
+        # We consider an array of unknowns as unknown, so we can dump empty
+        # lists or lists containing only None elements.
+        if sd.oid != INVALID_OID:
+            dumper.oid = self._get_array_oid(sd.oid)
+        else:
+            dumper.oid = INVALID_OID
+
+        return dumper
+
     # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
     #
     # The array output routine will put double quotes around element values if
@@ -147,12 +156,11 @@ class ListDumper(BaseListDumper):
                 if isinstance(item, list):
                     dump_list(item)
                 elif item is not None:
-                    # If we get here, the sub_dumper must have been set
-                    ad = self.sub_dumper.dump(item)  # type: ignore[union-attr]
+                    ad = self._dump_item(item)
                     if self._re_needs_quotes.search(ad):
-                        ad = (
-                            b'"' + self._re_esc.sub(br"\\\1", bytes(ad)) + b'"'
-                        )
+                        if not isinstance(ad, bytes):
+                            ad = bytes(ad)
+                        ad = b'"' + self._re_esc.sub(br"\\\1", ad) + b'"'
                     tokens.append(ad)
                 else:
                     tokens.append(b"NULL")
@@ -165,11 +173,70 @@ class ListDumper(BaseListDumper):
 
         return b"".join(tokens)
 
+    def _dump_item(self, item: Any) -> Buffer:
+        if self.sub_dumper:
+            return self.sub_dumper.dump(item)
+        else:
+            return self._tx.get_dumper(item, PyFormat.TEXT).dump(item)
+
+
+class MixedItemsListDumper(ListDumper):
+    """
+    An array dumper that doesn't assume that all the items are the same type.
+
+    Such dumper can be only textual and return either unknown oid or something
+    that work for every type contained.
+    """
+
+    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+        return self.cls
+
+    def _dump_item(self, item: Any) -> Buffer:
+        # If we get here, the sub_dumper must have been set
+        return self._tx.get_dumper(item, PyFormat.TEXT).dump(item)
+
+
+class MixedNumbersListDumper(MixedItemsListDumper):
+    """
+    A text dumper to dump lists containing any number as numeric array.
+    """
+
+    NUMBERS_TYPES = (int, float, Decimal)
+
+    _oid = postgres.types["numeric"].array_oid
+
 
 class ListBinaryDumper(BaseListDumper):
 
     format = pq.Format.BINARY
 
+    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+        if self.oid:
+            return self.cls
+
+        item = self._find_list_element(obj)
+        if item is None:
+            return (self.cls,)
+
+        sd = self._tx.get_dumper(item, format)
+        return (self.cls, sd.get_key(item, format))  # type: ignore
+
+    def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+        # If we have an oid we don't need to upgrade
+        if self.oid:
+            return self
+
+        item = self._find_list_element(obj)
+        if item is None:
+            return ListDumper(self.cls, self._tx)
+
+        sd = self._tx.get_dumper(item, format.from_pq(self.format))
+        dumper = type(self)(self.cls, self._tx)
+        dumper.sub_dumper = sd
+        dumper.oid = self._get_array_oid(sd.oid)
+
+        return dumper
+
     def dump(self, obj: List[Any]) -> bytes:
         # Postgres won't take unknown for element oid: fall back on text
         sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
@@ -219,6 +286,17 @@ class ListBinaryDumper(BaseListDumper):
         data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
         return b"".join(data)
 
+    def _find_list_element(self, L: List[Any]) -> Any:
+        item = super()._find_list_element(L)
+        if not isinstance(item, int):
+            return item
+
+        # If we got an int, let's see what is the biggest onw
+        it = self._flatiter(L, set())
+        imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
+        imax = max(item if item >= 0 else -item, imax)
+        return imax
+
 
 class BaseArrayLoader(RecursiveLoader):
     base_oid: int
@@ -331,8 +409,10 @@ def register_adapters(
 
 
 def register_default_adapters(context: AdaptContext) -> None:
-    context.adapters.register_dumper(list, ListDumper)
+    # The text dumper is more flexible as it can handle lists of mixed type,
+    # so register it later.
     context.adapters.register_dumper(list, ListBinaryDumper)
+    context.adapters.register_dumper(list, ListDumper)
 
 
 def register_all_arrays(context: AdaptContext) -> None:
index ee836bffce39bef1a5a8baf9bfbc75391be75758..88033fc8a8b593ef419ec50e4d50d1d3cf11461a 100644 (file)
@@ -293,8 +293,12 @@ def test_array_dumper(conn, fmt_out):
     t = Transformer(conn)
     fmt_in = Format.from_pq(fmt_out)
     dint = t.get_dumper([0], fmt_in)
-    assert dint.oid == builtins["int2"].array_oid
-    assert dint.sub_dumper.oid == builtins["int2"].oid
+    if fmt_out == pq.Format.BINARY:
+        assert dint.oid == builtins["int2"].array_oid
+        assert dint.sub_dumper.oid == builtins["int2"].oid
+    else:
+        assert dint.oid == builtins["numeric"].array_oid
+        assert dint.sub_dumper is None
 
     dstr = t.get_dumper([""], fmt_in)
     if fmt_in == Format.BINARY:
index 4e7296654b077fde761d9e5f42cb31a3af0d4a1d..69e83d6666371a501c09be20f44761b3cb43eed5 100644 (file)
@@ -1,4 +1,7 @@
+from decimal import Decimal
+
 import pytest
+
 import psycopg
 from psycopg import pq
 from psycopg import sql
@@ -148,6 +151,15 @@ def test_array_mixed_numbers(array, type):
     assert dumper.oid == builtins[type].array_oid
 
 
+def test_mix_types(conn):
+    cur = conn.cursor()
+    cur.execute("create table test (id serial primary key, data numeric[])")
+    cur.execute("insert into test (data) values (%s)", ([1, 2, 0.5],))
+    cur.execute("select data from test")
+    assert cur.fetchone()[0] == [1, 2, Decimal("0.5")]
+    assert cur.description[0].type_code == builtins["numeric"].array_oid
+
+
 @pytest.mark.parametrize("fmt_in", fmts_in)
 def test_empty_list_mix(conn, fmt_in):
     objs = list(range(3))
index 10beb092ae9f49cd774a48650e482ec9426eef58..3ac0c11fe1c930ad545f176868818407cb5d3635 100644 (file)
@@ -8,6 +8,7 @@ import psycopg.errors
 from psycopg import pq
 from psycopg.sql import Identifier
 from psycopg.adapt import PyFormat as Format
+from psycopg.types import range as range_module
 from psycopg.types.range import Range, RangeInfo
 
 
@@ -59,11 +60,39 @@ def test_dump_builtin_empty(conn, pgtype, fmt_in):
     assert cur.fetchone()[0] is True
 
 
+@pytest.mark.parametrize(
+    "wrapper",
+    """
+    Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+    """.split(),
+)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in):
+    wrapper = getattr(range_module, wrapper)
+    r = wrapper(empty=True)
+    cur = conn.execute(f"select 'empty' = %{fmt_in}", (r,))
+    assert cur.fetchone()[0] is True
+
+
 @pytest.mark.parametrize(
     "pgtype",
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
-@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize(
+    "fmt_in",
+    [
+        Format.AUTO,
+        Format.TEXT,
+        # There are many ways to work around this (use text, use a cast on the
+        # placeholder, use specific Range subclasses).
+        pytest.param(
+            Format.BINARY,
+            marks=pytest.mark.xfail(
+                reason="can't dump an array of untypes binary range without cast"
+            ),
+        ),
+    ],
+)
 def test_dump_builtin_array(conn, pgtype, fmt_in):
     r1 = Range(empty=True)
     r2 = Range(bounds="()")
@@ -74,6 +103,38 @@ def test_dump_builtin_array(conn, pgtype, fmt_in):
     assert cur.fetchone()[0] is True
 
 
+@pytest.mark.parametrize(
+    "pgtype",
+    "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in):
+    r1 = Range(empty=True)
+    r2 = Range(bounds="()")
+    cur = conn.execute(
+        f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in}::{pgtype}[]",
+        ([r1, r2],),
+    )
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "wrapper",
+    """
+    Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+    """.split(),
+)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
+    wrapper = getattr(range_module, wrapper)
+    r1 = wrapper(empty=True)
+    r2 = wrapper(bounds="()")
+    cur = conn.execute(
+        f"""select '{{empty,"(,)"}}' = %{fmt_in}""", ([r1, r2],)
+    )
+    assert cur.fetchone()[0] is True
+
+
 @pytest.mark.parametrize("pgtype, min, max, bounds", samples)
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in):
@@ -191,13 +252,12 @@ def test_copy_in_empty(conn, min, max, bounds, format):
 @pytest.mark.parametrize("bounds", "() empty".split())
 @pytest.mark.parametrize(
     "wrapper",
-    """Int4Range Int8Range NumericRange
-       DateRange TimestampRange TimestamptzRange""".split(),
+    """
+    Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+    """.split(),
 )
 @pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY])
 def test_copy_in_empty_wrappers(conn, bounds, wrapper, format):
-    from psycopg.types import range as range_module
-
     cur = conn.cursor()
     cur.execute("create table copyrange (id serial primary key, r daterange)")