import re
import struct
+from decimal import Decimal
from typing import Any, Callable, Iterator, List, Optional, Set, Tuple, Type
from typing import cast
super().__init__(cls, context)
self.sub_dumper: Optional[Dumper] = None
- def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
- item = self._find_list_element(obj)
- if item is not None:
- sd = self._tx.get_dumper(item, format)
- return (self.cls, sd.get_key(item, format)) # type: ignore
- else:
- return (self.cls,)
-
- def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
- item = self._find_list_element(obj)
- if item is None:
- # Empty lists can only be dumped as text if the type is unknown.
- return ListDumper(self.cls, self._tx)
-
- sd = self._tx.get_dumper(item, format)
- dcls = ListDumper if sd.format == pq.Format.TEXT else ListBinaryDumper
- dumper = dcls(self.cls, self._tx)
- dumper.sub_dumper = sd
-
- # We consider an array of unknowns as unknown, so we can dump empty
- # lists or lists containing only None elements.
- if sd.oid != INVALID_OID:
- dumper.oid = self._get_array_oid(sd.oid)
- else:
- dumper.oid = INVALID_OID
-
- return dumper
-
def _find_list_element(self, L: List[Any]) -> Any:
"""
Find the first non-null element of an eventually nested list
"""
it = self._flatiter(L, set())
try:
- item = next(it)
+ return next(it)
except StopIteration:
return None
- if not isinstance(item, int):
- return item
-
- imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
- imax = max(item if item >= 0 else -item, imax)
- return imax
-
def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
if id(L) in seen:
raise e.DataError("cannot dump a recursive list")
format = pq.Format.TEXT
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ if self.oid:
+ return self.cls
+
+ item = self._find_list_element(obj)
+ if item is None:
+ return self.cls
+
+ # If we got a number, let's dump them as numeric text array.
+ # Don't check for subclasses because if someone has used Int2 etc
+ # they probably know better what they want.
+ if type(item) in MixedNumbersListDumper.NUMBERS_TYPES:
+ return MixedNumbersListDumper
+
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.get_key(item, format)) # type: ignore
+
+ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+ # If we have an oid we don't need to upgrade
+ if self.oid:
+ return self
+
+ item = self._find_list_element(obj)
+ if item is None:
+ # Empty lists can only be dumped as text if the type is unknown.
+ return self
+
+ if type(item) in MixedNumbersListDumper.NUMBERS_TYPES:
+ return MixedNumbersListDumper(self.cls, self._tx)
+
+ sd = self._tx.get_dumper(item, format.from_pq(self.format))
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+
+ # We consider an array of unknowns as unknown, so we can dump empty
+ # lists or lists containing only None elements.
+ if sd.oid != INVALID_OID:
+ dumper.oid = self._get_array_oid(sd.oid)
+ else:
+ dumper.oid = INVALID_OID
+
+ return dumper
+
# from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
#
# The array output routine will put double quotes around element values if
if isinstance(item, list):
dump_list(item)
elif item is not None:
- # If we get here, the sub_dumper must have been set
- ad = self.sub_dumper.dump(item) # type: ignore[union-attr]
+ ad = self._dump_item(item)
if self._re_needs_quotes.search(ad):
- ad = (
- b'"' + self._re_esc.sub(br"\\\1", bytes(ad)) + b'"'
- )
+ if not isinstance(ad, bytes):
+ ad = bytes(ad)
+ ad = b'"' + self._re_esc.sub(br"\\\1", ad) + b'"'
tokens.append(ad)
else:
tokens.append(b"NULL")
return b"".join(tokens)
+ def _dump_item(self, item: Any) -> Buffer:
+ if self.sub_dumper:
+ return self.sub_dumper.dump(item)
+ else:
+ return self._tx.get_dumper(item, PyFormat.TEXT).dump(item)
+
+
+class MixedItemsListDumper(ListDumper):
+ """
+ An array dumper that doesn't assume that all the items are the same type.
+
+ Such dumper can be only textual and return either unknown oid or something
+ that work for every type contained.
+ """
+
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ return self.cls
+
+ def _dump_item(self, item: Any) -> Buffer:
+ # If we get here, the sub_dumper must have been set
+ return self._tx.get_dumper(item, PyFormat.TEXT).dump(item)
+
+
+class MixedNumbersListDumper(MixedItemsListDumper):
+ """
+ A text dumper to dump lists containing any number as numeric array.
+ """
+
+ NUMBERS_TYPES = (int, float, Decimal)
+
+ _oid = postgres.types["numeric"].array_oid
+
class ListBinaryDumper(BaseListDumper):
format = pq.Format.BINARY
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ if self.oid:
+ return self.cls
+
+ item = self._find_list_element(obj)
+ if item is None:
+ return (self.cls,)
+
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.get_key(item, format)) # type: ignore
+
+ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+ # If we have an oid we don't need to upgrade
+ if self.oid:
+ return self
+
+ item = self._find_list_element(obj)
+ if item is None:
+ return ListDumper(self.cls, self._tx)
+
+ sd = self._tx.get_dumper(item, format.from_pq(self.format))
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ dumper.oid = self._get_array_oid(sd.oid)
+
+ return dumper
+
def dump(self, obj: List[Any]) -> bytes:
# Postgres won't take unknown for element oid: fall back on text
sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
return b"".join(data)
+ def _find_list_element(self, L: List[Any]) -> Any:
+ item = super()._find_list_element(L)
+ if not isinstance(item, int):
+ return item
+
+ # If we got an int, let's see what is the biggest onw
+ it = self._flatiter(L, set())
+ imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
+ imax = max(item if item >= 0 else -item, imax)
+ return imax
+
class BaseArrayLoader(RecursiveLoader):
base_oid: int
def register_default_adapters(context: AdaptContext) -> None:
- context.adapters.register_dumper(list, ListDumper)
+ # The text dumper is more flexible as it can handle lists of mixed type,
+ # so register it later.
context.adapters.register_dumper(list, ListBinaryDumper)
+ context.adapters.register_dumper(list, ListDumper)
def register_all_arrays(context: AdaptContext) -> None:
from psycopg import pq
from psycopg.sql import Identifier
from psycopg.adapt import PyFormat as Format
+from psycopg.types import range as range_module
from psycopg.types.range import Range, RangeInfo
assert cur.fetchone()[0] is True
+@pytest.mark.parametrize(
+ "wrapper",
+ """
+ Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+ """.split(),
+)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(range_module, wrapper)
+ r = wrapper(empty=True)
+ cur = conn.execute(f"select 'empty' = %{fmt_in}", (r,))
+ assert cur.fetchone()[0] is True
+
+
@pytest.mark.parametrize(
"pgtype",
"int4range int8range numrange daterange tsrange tstzrange".split(),
)
-@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize(
+ "fmt_in",
+ [
+ Format.AUTO,
+ Format.TEXT,
+ # There are many ways to work around this (use text, use a cast on the
+ # placeholder, use specific Range subclasses).
+ pytest.param(
+ Format.BINARY,
+ marks=pytest.mark.xfail(
+ reason="can't dump an array of untypes binary range without cast"
+ ),
+ ),
+ ],
+)
def test_dump_builtin_array(conn, pgtype, fmt_in):
r1 = Range(empty=True)
r2 = Range(bounds="()")
assert cur.fetchone()[0] is True
+@pytest.mark.parametrize(
+ "pgtype",
+ "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in):
+ r1 = Range(empty=True)
+ r2 = Range(bounds="()")
+ cur = conn.execute(
+ f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in}::{pgtype}[]",
+ ([r1, r2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "wrapper",
+ """
+ Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+ """.split(),
+)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(range_module, wrapper)
+ r1 = wrapper(empty=True)
+ r2 = wrapper(bounds="()")
+ cur = conn.execute(
+ f"""select '{{empty,"(,)"}}' = %{fmt_in}""", ([r1, r2],)
+ )
+ assert cur.fetchone()[0] is True
+
+
@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in):
@pytest.mark.parametrize("bounds", "() empty".split())
@pytest.mark.parametrize(
"wrapper",
- """Int4Range Int8Range NumericRange
- DateRange TimestampRange TimestamptzRange""".split(),
+ """
+ Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+ """.split(),
)
@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY])
def test_copy_in_empty_wrappers(conn, bounds, wrapper, format):
- from psycopg.types import range as range_module
-
cur = conn.cursor()
cur.execute("create table copyrange (id serial primary key, r daterange)")