From: Daniele Varrazzo Date: Sun, 8 May 2022 14:29:32 +0000 (+0200) Subject: Fix array oid dumping lists of integer X-Git-Tag: 3.1~117^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6f08ffb4c2348d7ee8424c75b667dc6ac4ac1cc1;p=thirdparty%2Fpsycopg.git Fix array oid dumping lists of integer The problem was only in the list text dumper, which was trying to do more than what was needed. The list binary dumper worked alright. Drop mixed array test, which doesn't really represent anything Postgres can work with. Close #293 --- diff --git a/docs/news.rst b/docs/news.rst index ac33b9a34..f636d9627 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -30,6 +30,8 @@ Psycopg 3.0.13 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - Fix `Cursor.stream()` slowness (:ticket:`#286`). +- Fix oid for lists of integers, which might cause the server choosing + bad plans (:ticket:`#293`). - Make `Connection.cancel()` on a closed connection a no-op instead of an error. diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py index a4cf19b3d..f2b28cdec 100644 --- a/psycopg/psycopg/types/array.py +++ b/psycopg/psycopg/types/array.py @@ -6,7 +6,6 @@ Adapters for arrays import re import struct -from decimal import Decimal from typing import Any, cast, Callable, Iterator, List from typing import Optional, Pattern, Set, Tuple, Type from functools import lru_cache @@ -50,10 +49,22 @@ class BaseListDumper(RecursiveDumper): """ it = self._flatiter(L, set()) try: - return next(it) + item = next(it) except StopIteration: return None + # Checking for precise type. If the type is a subclass (e.g. Int4) + # we assume the user knows what type they are passing. + if type(item) is not int: + return item + + # If we got an int, let's see what is the biggest one in order to + # choose the smallest OID and allow Postgres to do the right cast. 
+ it = self._flatiter(L, set()) + imax = max((i if i >= 0 else -i - 1 for i in it), default=0) + imax = max(item if item >= 0 else -item - 1, imax) + return imax + def _flatiter(self, L: List[Any], seen: Set[int]) -> Any: if id(L) in seen: raise e.DataError("cannot dump a recursive list") @@ -94,12 +105,6 @@ class ListDumper(BaseListDumper): if item is None: return self.cls - # If we got a number, let's dump them as numeric text array. - # Don't check for subclasses because if someone has used Int2 etc - # they probably know better what they want. - if type(item) in MixedNumbersListDumper.NUMBERS_TYPES: - return MixedNumbersListDumper - sd = self._tx.get_dumper(item, format) return (self.cls, sd.get_key(item, format)) # type: ignore @@ -113,9 +118,6 @@ class ListDumper(BaseListDumper): # Empty lists can only be dumped as text if the type is unknown. return self - if type(item) in MixedNumbersListDumper.NUMBERS_TYPES: - return MixedNumbersListDumper(self.cls, self._tx) - sd = self._tx.get_dumper(item, format.from_pq(self.format)) dumper = type(self)(self.cls, self._tx) dumper.sub_dumper = sd @@ -193,32 +195,6 @@ def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]: ) -class MixedItemsListDumper(ListDumper): - """ - An array dumper that doesn't assume that all the items are the same type. - - Such dumper can be only textual and return either unknown oid or something - that work for every type contained. - """ - - def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: - return self.cls - - def _dump_item(self, item: Any) -> Buffer: - # If we get here, the sub_dumper must have been set - return self._tx.get_dumper(item, PyFormat.TEXT).dump(item) - - -class MixedNumbersListDumper(MixedItemsListDumper): - """ - A text dumper to dump lists containing any number as numeric array. 
- """ - - NUMBERS_TYPES = (int, float, Decimal) - - oid = postgres.types["numeric"].array_oid - - class ListBinaryDumper(BaseListDumper): format = pq.Format.BINARY @@ -298,17 +274,6 @@ class ListBinaryDumper(BaseListDumper): data[1] = b"".join(_pack_dim(dim, 1) for dim in dims) return b"".join(data) - def _find_list_element(self, L: List[Any]) -> Any: - item = super()._find_list_element(L) - if not isinstance(item, int): - return item - - # If we got an int, let's see what is the biggest onw - it = self._flatiter(L, set()) - imax = max((i if i >= 0 else -i - 1 for i in it), default=0) - imax = max(item if item >= 0 else -item, imax) - return imax - class BaseArrayLoader(RecursiveLoader): base_oid: int diff --git a/tests/test_adapt.py b/tests/test_adapt.py index c4ade8ac4..a64ac0095 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -293,14 +293,9 @@ def test_array_dumper(conn, fmt_out): t = Transformer(conn) fmt_in = PyFormat.from_pq(fmt_out) dint = t.get_dumper([0], fmt_in) - if fmt_out == pq.Format.BINARY: - assert isinstance(dint, ListBinaryDumper) - assert dint.oid == builtins["int2"].array_oid - assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid - else: - assert isinstance(dint, ListDumper) - assert dint.oid == builtins["numeric"].array_oid - assert dint.sub_dumper is None + assert isinstance(dint, (ListDumper, ListBinaryDumper)) + assert dint.oid == builtins["int2"].array_oid + assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid dstr = t.get_dumper([""], fmt_in) if fmt_in == PyFormat.BINARY: diff --git a/tests/types/test_array.py b/tests/types/test_array.py index a0b27c3c3..1a747ef50 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -140,26 +140,27 @@ def test_array_of_unknown_builtin(conn): @pytest.mark.parametrize( - "array, type", + "num, type", [ - ([0], "int2"), - ([1, 2**15 - 1], "int2"), - ([1, -(2**15)], "int2"), - ([1, 2**15], "int4"), - ([1, 2**31 - 1], "int4"), - ([1, 
-(2**31)], "int4"), - ([1, 2**31], "int8"), - ([1, 2**63 - 1], "int8"), - ([1, -(2**63)], "int8"), - ([1, 2**63], "numeric"), + (0, "int2"), + (2**15 - 1, "int2"), + (-(2**15), "int2"), + (2**15, "int4"), + (2**31 - 1, "int4"), + (-(2**31), "int4"), + (2**31, "int8"), + (2**63 - 1, "int8"), + (-(2**63), "int8"), + (2**63, "numeric"), ], ) @pytest.mark.parametrize("fmt_in", PyFormat) -def test_numbers_array(array, type, fmt_in): - tx = Transformer() - dumper = tx.get_dumper(array, fmt_in) - dumper.dump(array) - assert dumper.oid == builtins[type].array_oid +def test_numbers_array(num, type, fmt_in): + for array in ([num], [1, num]): + tx = Transformer() + dumper = tx.get_dumper(array, fmt_in) + dumper.dump(array) + assert dumper.oid == builtins[type].array_oid @pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split()) @@ -182,15 +183,6 @@ def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out): assert type(i) is want_cls -def test_mix_types(conn): - cur = conn.cursor() - cur.execute("create table test (id serial primary key, data numeric[])") - cur.execute("insert into test (data) values (%s)", ([1, 2, 0.5],)) - cur.execute("select data from test") - assert cur.fetchone()[0] == [1, 2, Decimal("0.5")] - assert cur.description[0].type_code == builtins["numeric"].array_oid - - @pytest.mark.parametrize("fmt_in", PyFormat) def test_empty_list_mix(conn, fmt_in): objs = list(range(3))