]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix array oid dumping lists of integer
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 8 May 2022 14:29:32 +0000 (16:29 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 8 May 2022 19:11:44 +0000 (21:11 +0200)
The problem was only in the list text dumper, which was trying to do more
than what needed. The list binary dumper worked alright.

Drop mixed array test, which don't really represent anything Postgres
can work with.

Close #293

docs/news.rst
psycopg/psycopg/types/array.py
tests/test_adapt.py
tests/types/test_array.py

index ac33b9a3400d50dbbc4c0d2d9389e39f84795387..f636d9627c7371f3cac5522da78783e5a9c5ec55 100644 (file)
@@ -30,6 +30,8 @@ Psycopg 3.0.13 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 - Fix `Cursor.stream()` slowness (:ticket:`#286`).
+- Fix oid for lists of integers, which might cause the server choosing
+  bad plans (:ticket:`#293`).
 - Make `Connection.cancel()` on a closed connection a no-op instead of an
   error.
 
index a4cf19b3d5e668cfd479e691f27cb29ae1eece12..f2b28cdec068aa5b2086c1a51395140a29035093 100644 (file)
@@ -6,7 +6,6 @@ Adapters for arrays
 
 import re
 import struct
-from decimal import Decimal
 from typing import Any, cast, Callable, Iterator, List
 from typing import Optional, Pattern, Set, Tuple, Type
 from functools import lru_cache
@@ -50,10 +49,22 @@ class BaseListDumper(RecursiveDumper):
         """
         it = self._flatiter(L, set())
         try:
-            return next(it)
+            item = next(it)
         except StopIteration:
             return None
 
+        # Checking for precise type. If the type is a subclass (e.g. Int4)
+        # we assume the user knows what type they are passing.
+        if type(item) is not int:
+            return item
+
+        # If we got an int, let's see what is the biggest one in order to
+        # choose the smallest OID and allow Postgres to do the right cast.
+        it = self._flatiter(L, set())
+        imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
+        imax = max(item if item >= 0 else -item - 1, imax)
+        return imax
+
     def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
         if id(L) in seen:
             raise e.DataError("cannot dump a recursive list")
@@ -94,12 +105,6 @@ class ListDumper(BaseListDumper):
         if item is None:
             return self.cls
 
-        # If we got a number, let's dump them as numeric text array.
-        # Don't check for subclasses because if someone has used Int2 etc
-        # they probably know better what they want.
-        if type(item) in MixedNumbersListDumper.NUMBERS_TYPES:
-            return MixedNumbersListDumper
-
         sd = self._tx.get_dumper(item, format)
         return (self.cls, sd.get_key(item, format))  # type: ignore
 
@@ -113,9 +118,6 @@ class ListDumper(BaseListDumper):
             # Empty lists can only be dumped as text if the type is unknown.
             return self
 
-        if type(item) in MixedNumbersListDumper.NUMBERS_TYPES:
-            return MixedNumbersListDumper(self.cls, self._tx)
-
         sd = self._tx.get_dumper(item, format.from_pq(self.format))
         dumper = type(self)(self.cls, self._tx)
         dumper.sub_dumper = sd
@@ -193,32 +195,6 @@ def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
     )
 
 
-class MixedItemsListDumper(ListDumper):
-    """
-    An array dumper that doesn't assume that all the items are the same type.
-
-    Such dumper can be only textual and return either unknown oid or something
-    that work for every type contained.
-    """
-
-    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
-        return self.cls
-
-    def _dump_item(self, item: Any) -> Buffer:
-        # If we get here, the sub_dumper must have been set
-        return self._tx.get_dumper(item, PyFormat.TEXT).dump(item)
-
-
-class MixedNumbersListDumper(MixedItemsListDumper):
-    """
-    A text dumper to dump lists containing any number as numeric array.
-    """
-
-    NUMBERS_TYPES = (int, float, Decimal)
-
-    oid = postgres.types["numeric"].array_oid
-
-
 class ListBinaryDumper(BaseListDumper):
 
     format = pq.Format.BINARY
@@ -298,17 +274,6 @@ class ListBinaryDumper(BaseListDumper):
         data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
         return b"".join(data)
 
-    def _find_list_element(self, L: List[Any]) -> Any:
-        item = super()._find_list_element(L)
-        if not isinstance(item, int):
-            return item
-
-        # If we got an int, let's see what is the biggest onw
-        it = self._flatiter(L, set())
-        imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
-        imax = max(item if item >= 0 else -item, imax)
-        return imax
-
 
 class BaseArrayLoader(RecursiveLoader):
     base_oid: int
index c4ade8ac47e6ba6a2cadc5353f15d27585b20491..a64ac00959976d930215eda86b5467dd9923e471 100644 (file)
@@ -293,14 +293,9 @@ def test_array_dumper(conn, fmt_out):
     t = Transformer(conn)
     fmt_in = PyFormat.from_pq(fmt_out)
     dint = t.get_dumper([0], fmt_in)
-    if fmt_out == pq.Format.BINARY:
-        assert isinstance(dint, ListBinaryDumper)
-        assert dint.oid == builtins["int2"].array_oid
-        assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid
-    else:
-        assert isinstance(dint, ListDumper)
-        assert dint.oid == builtins["numeric"].array_oid
-        assert dint.sub_dumper is None
+    assert isinstance(dint, (ListDumper, ListBinaryDumper))
+    assert dint.oid == builtins["int2"].array_oid
+    assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid
 
     dstr = t.get_dumper([""], fmt_in)
     if fmt_in == PyFormat.BINARY:
index a0b27c3c3b94d259683c8893e19249de65d830c8..1a747ef504c7aacba64f4a0d1429b291b7aaa732 100644 (file)
@@ -140,26 +140,27 @@ def test_array_of_unknown_builtin(conn):
 
 
 @pytest.mark.parametrize(
-    "array, type",
+    "num, type",
     [
-        ([0], "int2"),
-        ([1, 2**15 - 1], "int2"),
-        ([1, -(2**15)], "int2"),
-        ([1, 2**15], "int4"),
-        ([1, 2**31 - 1], "int4"),
-        ([1, -(2**31)], "int4"),
-        ([1, 2**31], "int8"),
-        ([1, 2**63 - 1], "int8"),
-        ([1, -(2**63)], "int8"),
-        ([1, 2**63], "numeric"),
+        (0, "int2"),
+        (2**15 - 1, "int2"),
+        (-(2**15), "int2"),
+        (2**15, "int4"),
+        (2**31 - 1, "int4"),
+        (-(2**31), "int4"),
+        (2**31, "int8"),
+        (2**63 - 1, "int8"),
+        (-(2**63), "int8"),
+        (2**63, "numeric"),
     ],
 )
 @pytest.mark.parametrize("fmt_in", PyFormat)
-def test_numbers_array(array, type, fmt_in):
-    tx = Transformer()
-    dumper = tx.get_dumper(array, fmt_in)
-    dumper.dump(array)
-    assert dumper.oid == builtins[type].array_oid
+def test_numbers_array(num, type, fmt_in):
+    for array in ([num], [1, num]):
+        tx = Transformer()
+        dumper = tx.get_dumper(array, fmt_in)
+        dumper.dump(array)
+        assert dumper.oid == builtins[type].array_oid
 
 
 @pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split())
@@ -182,15 +183,6 @@ def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
             assert type(i) is want_cls
 
 
-def test_mix_types(conn):
-    cur = conn.cursor()
-    cur.execute("create table test (id serial primary key, data numeric[])")
-    cur.execute("insert into test (data) values (%s)", ([1, 2, 0.5],))
-    cur.execute("select data from test")
-    assert cur.fetchone()[0] == [1, 2, Decimal("0.5")]
-    assert cur.description[0].type_code == builtins["numeric"].array_oid
-
-
 @pytest.mark.parametrize("fmt_in", PyFormat)
 def test_empty_list_mix(conn, fmt_in):
     objs = list(range(3))