From: Daniele Varrazzo Date: Mon, 26 Jul 2021 18:56:59 +0000 (+0200) Subject: Allow some form of dumping lists of mixed types X-Git-Tag: 3.0.dev2~29 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=cd594df2bdb0572ab07e7ab08eff5d88cb6b4ea3;p=thirdparty%2Fpsycopg.git Allow some form of dumping lists of mixed types Lists of numbers are now dumped as numeric[]. Default to dump text for arrays. --- diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py index 1b82e76ba..319003559 100644 --- a/psycopg/psycopg/types/array.py +++ b/psycopg/psycopg/types/array.py @@ -6,6 +6,7 @@ Adapters for arrays import re import struct +from decimal import Decimal from typing import Any, Callable, Iterator, List, Optional, Set, Tuple, Type from typing import cast @@ -37,51 +38,16 @@ class BaseListDumper(RecursiveDumper): super().__init__(cls, context) self.sub_dumper: Optional[Dumper] = None - def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: - item = self._find_list_element(obj) - if item is not None: - sd = self._tx.get_dumper(item, format) - return (self.cls, sd.get_key(item, format)) # type: ignore - else: - return (self.cls,) - - def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": - item = self._find_list_element(obj) - if item is None: - # Empty lists can only be dumped as text if the type is unknown. - return ListDumper(self.cls, self._tx) - - sd = self._tx.get_dumper(item, format) - dcls = ListDumper if sd.format == pq.Format.TEXT else ListBinaryDumper - dumper = dcls(self.cls, self._tx) - dumper.sub_dumper = sd - - # We consider an array of unknowns as unknown, so we can dump empty - # lists or lists containing only None elements. 
- if sd.oid != INVALID_OID: - dumper.oid = self._get_array_oid(sd.oid) - else: - dumper.oid = INVALID_OID - - return dumper - def _find_list_element(self, L: List[Any]) -> Any: """ Find the first non-null element of an eventually nested list """ it = self._flatiter(L, set()) try: - item = next(it) + return next(it) except StopIteration: return None - if not isinstance(item, int): - return item - - imax = max((i if i >= 0 else -i - 1 for i in it), default=0) - imax = max(item if item >= 0 else -item, imax) - return imax - def _flatiter(self, L: List[Any], seen: Set[int]) -> Any: if id(L) in seen: raise e.DataError("cannot dump a recursive list") @@ -116,6 +82,49 @@ class ListDumper(BaseListDumper): format = pq.Format.TEXT + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + if self.oid: + return self.cls + + item = self._find_list_element(obj) + if item is None: + return self.cls + + # If we got a number, let's dump them as numeric text array. + # Don't check for subclasses because if someone has used Int2 etc + # they probably know better what they want. + if type(item) in MixedNumbersListDumper.NUMBERS_TYPES: + return MixedNumbersListDumper + + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.get_key(item, format)) # type: ignore + + def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": + # If we have an oid we don't need to upgrade + if self.oid: + return self + + item = self._find_list_element(obj) + if item is None: + # Empty lists can only be dumped as text if the type is unknown. + return self + + if type(item) in MixedNumbersListDumper.NUMBERS_TYPES: + return MixedNumbersListDumper(self.cls, self._tx) + + sd = self._tx.get_dumper(item, format.from_pq(self.format)) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + + # We consider an array of unknowns as unknown, so we can dump empty + # lists or lists containing only None elements. 
+ if sd.oid != INVALID_OID: + dumper.oid = self._get_array_oid(sd.oid) + else: + dumper.oid = INVALID_OID + + return dumper + # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO # # The array output routine will put double quotes around element values if @@ -147,12 +156,11 @@ class ListDumper(BaseListDumper): if isinstance(item, list): dump_list(item) elif item is not None: - # If we get here, the sub_dumper must have been set - ad = self.sub_dumper.dump(item) # type: ignore[union-attr] + ad = self._dump_item(item) if self._re_needs_quotes.search(ad): - ad = ( - b'"' + self._re_esc.sub(br"\\\1", bytes(ad)) + b'"' - ) + if not isinstance(ad, bytes): + ad = bytes(ad) + ad = b'"' + self._re_esc.sub(br"\\\1", ad) + b'"' tokens.append(ad) else: tokens.append(b"NULL") @@ -165,11 +173,70 @@ class ListDumper(BaseListDumper): return b"".join(tokens) + def _dump_item(self, item: Any) -> Buffer: + if self.sub_dumper: + return self.sub_dumper.dump(item) + else: + return self._tx.get_dumper(item, PyFormat.TEXT).dump(item) + + +class MixedItemsListDumper(ListDumper): + """ + An array dumper that doesn't assume that all the items are the same type. + + Such a dumper can only be textual and returns either the unknown oid or something + that works for every type contained. + """ + + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + return self.cls + + def _dump_item(self, item: Any) -> Buffer: + # Items may be of different types: look up a text dumper for each one + return self._tx.get_dumper(item, PyFormat.TEXT).dump(item) + + +class MixedNumbersListDumper(MixedItemsListDumper): + """ + A text dumper to dump lists containing any number as numeric array. 
+ """ + + NUMBERS_TYPES = (int, float, Decimal) + + _oid = postgres.types["numeric"].array_oid + class ListBinaryDumper(BaseListDumper): format = pq.Format.BINARY + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + if self.oid: + return self.cls + + item = self._find_list_element(obj) + if item is None: + return (self.cls,) + + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.get_key(item, format)) # type: ignore + + def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": + # If we have an oid we don't need to upgrade + if self.oid: + return self + + item = self._find_list_element(obj) + if item is None: + return ListDumper(self.cls, self._tx) + + sd = self._tx.get_dumper(item, format.from_pq(self.format)) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = self._get_array_oid(sd.oid) + + return dumper + def dump(self, obj: List[Any]) -> bytes: # Postgres won't take unknown for element oid: fall back on text sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID @@ -219,6 +286,17 @@ class ListBinaryDumper(BaseListDumper): data[1] = b"".join(_pack_dim(dim, 1) for dim in dims) return b"".join(data) + def _find_list_element(self, L: List[Any]) -> Any: + item = super()._find_list_element(L) + if not isinstance(item, int): + return item + + # If we got an int, let's see what is the biggest one + it = self._flatiter(L, set()) + imax = max((i if i >= 0 else -i - 1 for i in it), default=0) + imax = max(item if item >= 0 else -item, imax) + return imax + class BaseArrayLoader(RecursiveLoader): base_oid: int @@ -331,8 +409,10 @@ def register_adapters( def register_default_adapters(context: AdaptContext) -> None: - context.adapters.register_dumper(list, ListDumper) + # The text dumper is more flexible as it can handle lists of mixed type, + # so register it later. 
context.adapters.register_dumper(list, ListBinaryDumper) + context.adapters.register_dumper(list, ListDumper) def register_all_arrays(context: AdaptContext) -> None: diff --git a/tests/test_adapt.py b/tests/test_adapt.py index ee836bffc..88033fc8a 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -293,8 +293,12 @@ def test_array_dumper(conn, fmt_out): t = Transformer(conn) fmt_in = Format.from_pq(fmt_out) dint = t.get_dumper([0], fmt_in) - assert dint.oid == builtins["int2"].array_oid - assert dint.sub_dumper.oid == builtins["int2"].oid + if fmt_out == pq.Format.BINARY: + assert dint.oid == builtins["int2"].array_oid + assert dint.sub_dumper.oid == builtins["int2"].oid + else: + assert dint.oid == builtins["numeric"].array_oid + assert dint.sub_dumper is None dstr = t.get_dumper([""], fmt_in) if fmt_in == Format.BINARY: diff --git a/tests/types/test_array.py b/tests/types/test_array.py index 4e7296654..69e83d666 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -1,4 +1,7 @@ +from decimal import Decimal + import pytest + import psycopg from psycopg import pq from psycopg import sql @@ -148,6 +151,15 @@ def test_array_mixed_numbers(array, type): assert dumper.oid == builtins[type].array_oid +def test_mix_types(conn): + cur = conn.cursor() + cur.execute("create table test (id serial primary key, data numeric[])") + cur.execute("insert into test (data) values (%s)", ([1, 2, 0.5],)) + cur.execute("select data from test") + assert cur.fetchone()[0] == [1, 2, Decimal("0.5")] + assert cur.description[0].type_code == builtins["numeric"].array_oid + + @pytest.mark.parametrize("fmt_in", fmts_in) def test_empty_list_mix(conn, fmt_in): objs = list(range(3)) diff --git a/tests/types/test_range.py b/tests/types/test_range.py index 10beb092a..3ac0c11fe 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -8,6 +8,7 @@ import psycopg.errors from psycopg import pq from psycopg.sql import Identifier from psycopg.adapt import 
PyFormat as Format +from psycopg.types import range as range_module from psycopg.types.range import Range, RangeInfo @@ -59,11 +60,39 @@ def test_dump_builtin_empty(conn, pgtype, fmt_in): assert cur.fetchone()[0] is True +@pytest.mark.parametrize( + "wrapper", + """ + Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange + """.split(), +) +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(range_module, wrapper) + r = wrapper(empty=True) + cur = conn.execute(f"select 'empty' = %{fmt_in}", (r,)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize( "pgtype", "int4range int8range numrange daterange tsrange tstzrange".split(), ) -@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize( + "fmt_in", + [ + Format.AUTO, + Format.TEXT, + # There are many ways to work around this (use text, use a cast on the + # placeholder, use specific Range subclasses). 
+ pytest.param( + Format.BINARY, + marks=pytest.mark.xfail( + reason="can't dump an array of untypes binary range without cast" + ), + ), + ], +) def test_dump_builtin_array(conn, pgtype, fmt_in): r1 = Range(empty=True) r2 = Range(bounds="()") @@ -74,6 +103,38 @@ def test_dump_builtin_array(conn, pgtype, fmt_in): assert cur.fetchone()[0] is True +@pytest.mark.parametrize( + "pgtype", + "int4range int8range numrange daterange tsrange tstzrange".split(), +) +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in): + r1 = Range(empty=True) + r2 = Range(bounds="()") + cur = conn.execute( + f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in}::{pgtype}[]", + ([r1, r2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "wrapper", + """ + Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange + """.split(), +) +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(range_module, wrapper) + r1 = wrapper(empty=True) + r2 = wrapper(bounds="()") + cur = conn.execute( + f"""select '{{empty,"(,)"}}' = %{fmt_in}""", ([r1, r2],) + ) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize("pgtype, min, max, bounds", samples) @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in): @@ -191,13 +252,12 @@ def test_copy_in_empty(conn, min, max, bounds, format): @pytest.mark.parametrize("bounds", "() empty".split()) @pytest.mark.parametrize( "wrapper", - """Int4Range Int8Range NumericRange - DateRange TimestampRange TimestamptzRange""".split(), + """ + Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange + """.split(), ) @pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY]) def test_copy_in_empty_wrappers(conn, 
bounds, wrapper, format): - from psycopg.types import range as range_module - cur = conn.cursor() cur.execute("create table copyrange (id serial primary key, r daterange)")