import re
import struct
-from decimal import Decimal
from typing import Any, cast, Callable, Iterator, List
from typing import Optional, Pattern, Set, Tuple, Type
from functools import lru_cache
"""
it = self._flatiter(L, set())
try:
- return next(it)
+ item = next(it)
except StopIteration:
return None
+ # Checking for precise type. If the type is a subclass (e.g. Int4)
+ # we assume the user knows what type they are passing.
+ if type(item) is not int:
+ return item
+
+ # If we got an int, let's see what is the biggest one in order to
+ # choose the smallest OID and allow Postgres to do the right cast.
+ it = self._flatiter(L, set())
+ imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
+ imax = max(item if item >= 0 else -item - 1, imax)
+ return imax
+
def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
if id(L) in seen:
raise e.DataError("cannot dump a recursive list")
if item is None:
return self.cls
- # If we got a number, let's dump them as numeric text array.
- # Don't check for subclasses because if someone has used Int2 etc
- # they probably know better what they want.
- if type(item) in MixedNumbersListDumper.NUMBERS_TYPES:
- return MixedNumbersListDumper
-
sd = self._tx.get_dumper(item, format)
return (self.cls, sd.get_key(item, format)) # type: ignore
# Empty lists can only be dumped as text if the type is unknown.
return self
- if type(item) in MixedNumbersListDumper.NUMBERS_TYPES:
- return MixedNumbersListDumper(self.cls, self._tx)
-
sd = self._tx.get_dumper(item, format.from_pq(self.format))
dumper = type(self)(self.cls, self._tx)
dumper.sub_dumper = sd
)
-class MixedItemsListDumper(ListDumper):
- """
- An array dumper that doesn't assume that all the items are the same type.
-
- Such dumper can be only textual and return either unknown oid or something
- that work for every type contained.
- """
-
- def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
- return self.cls
-
- def _dump_item(self, item: Any) -> Buffer:
- # If we get here, the sub_dumper must have been set
- return self._tx.get_dumper(item, PyFormat.TEXT).dump(item)
-
-
-class MixedNumbersListDumper(MixedItemsListDumper):
- """
- A text dumper to dump lists containing any number as numeric array.
- """
-
- NUMBERS_TYPES = (int, float, Decimal)
-
- oid = postgres.types["numeric"].array_oid
-
-
class ListBinaryDumper(BaseListDumper):
format = pq.Format.BINARY
data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
return b"".join(data)
- def _find_list_element(self, L: List[Any]) -> Any:
- item = super()._find_list_element(L)
- if not isinstance(item, int):
- return item
-
- # If we got an int, let's see what is the biggest onw
- it = self._flatiter(L, set())
- imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
- imax = max(item if item >= 0 else -item, imax)
- return imax
-
class BaseArrayLoader(RecursiveLoader):
base_oid: int
t = Transformer(conn)
fmt_in = PyFormat.from_pq(fmt_out)
dint = t.get_dumper([0], fmt_in)
- if fmt_out == pq.Format.BINARY:
- assert isinstance(dint, ListBinaryDumper)
- assert dint.oid == builtins["int2"].array_oid
- assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid
- else:
- assert isinstance(dint, ListDumper)
- assert dint.oid == builtins["numeric"].array_oid
- assert dint.sub_dumper is None
+ assert isinstance(dint, (ListDumper, ListBinaryDumper))
+ assert dint.oid == builtins["int2"].array_oid
+ assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid
dstr = t.get_dumper([""], fmt_in)
if fmt_in == PyFormat.BINARY:
@pytest.mark.parametrize(
- "array, type",
+ "num, type",
[
- ([0], "int2"),
- ([1, 2**15 - 1], "int2"),
- ([1, -(2**15)], "int2"),
- ([1, 2**15], "int4"),
- ([1, 2**31 - 1], "int4"),
- ([1, -(2**31)], "int4"),
- ([1, 2**31], "int8"),
- ([1, 2**63 - 1], "int8"),
- ([1, -(2**63)], "int8"),
- ([1, 2**63], "numeric"),
+ (0, "int2"),
+ (2**15 - 1, "int2"),
+ (-(2**15), "int2"),
+ (2**15, "int4"),
+ (2**31 - 1, "int4"),
+ (-(2**31), "int4"),
+ (2**31, "int8"),
+ (2**63 - 1, "int8"),
+ (-(2**63), "int8"),
+ (2**63, "numeric"),
],
)
@pytest.mark.parametrize("fmt_in", PyFormat)
-def test_numbers_array(array, type, fmt_in):
- tx = Transformer()
- dumper = tx.get_dumper(array, fmt_in)
- dumper.dump(array)
- assert dumper.oid == builtins[type].array_oid
+def test_numbers_array(num, type, fmt_in):
+ for array in ([num], [1, num]):
+ tx = Transformer()
+ dumper = tx.get_dumper(array, fmt_in)
+ dumper.dump(array)
+ assert dumper.oid == builtins[type].array_oid
@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split())
assert type(i) is want_cls
-def test_mix_types(conn):
- cur = conn.cursor()
- cur.execute("create table test (id serial primary key, data numeric[])")
- cur.execute("insert into test (data) values (%s)", ([1, 2, 0.5],))
- cur.execute("select data from test")
- assert cur.fetchone()[0] == [1, 2, Decimal("0.5")]
- assert cur.description[0].type_code == builtins["numeric"].array_oid
-
-
@pytest.mark.parametrize("fmt_in", PyFormat)
def test_empty_list_mix(conn, fmt_in):
objs = list(range(3))