--- /dev/null
+"""
+Support for multirange types adaptation.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from decimal import Decimal
+from typing import Any, Generic, List, Iterable
+from typing import MutableSequence, Optional, Union, overload
+from datetime import date, datetime
+
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._struct import pack_len, unpack_len
+from ..postgres import INVALID_OID, TEXT_OID
+from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here
+
+from .range import Range, T, load_range_text, load_range_binary
+from .range import dump_range_text, dump_range_binary, fail_dump
+
+
+class Multirange(MutableSequence[Range[T]]):
+ def __init__(self, items: Iterable[Range[T]] = ()):
+ self._ranges: List[Range[T]] = list(items)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self._ranges!r})"
+
+ def __str__(self) -> str:
+ return f"{{{', '.join(map(str, self._ranges))}}}"
+
+ @overload
+ def __getitem__(self, index: int) -> Range[T]:
+ ...
+
+ @overload
+ def __getitem__(self, index: slice) -> "Multirange[T]":
+ ...
+
+ def __getitem__(
+ self, index: Union[int, slice]
+ ) -> "Union[Range[T],Multirange[T]]":
+ if isinstance(index, int):
+ return self._ranges[index]
+ else:
+ return Multirange(self._ranges[index])
+
+ def __len__(self) -> int:
+ return len(self._ranges)
+
+ @overload
+ def __setitem__(self, index: int, value: Range[T]) -> None:
+ ...
+
+ @overload
+ def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
+ ...
+
+ def __setitem__(
+ self,
+ index: Union[int, slice],
+ value: Union[Range[T], Iterable[Range[T]]],
+ ) -> None:
+ self._ranges[index] = value # type: ignore
+
+ def __delitem__(self, index: Union[int, slice]) -> None:
+ del self._ranges[index]
+
+ def insert(self, index: int, value: Range[T]) -> None:
+ self._ranges.insert(index, value)
+
+
+# Subclasses to specify a specific subtype. Usually not needed
+
+
+class Int4Multirange(Multirange[int]):
+ pass
+
+
+class Int8Multirange(Multirange[int]):
+ pass
+
+
+class NumericMultirange(Multirange[Decimal]):
+ pass
+
+
+class DateMultirange(Multirange[date]):
+ pass
+
+
+class TimestampMultirange(Multirange[datetime]):
+ pass
+
+
+class TimestamptzMultirange(Multirange[datetime]):
+ pass
+
+
+class BaseMultirangeDumper(RecursiveDumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ self._adapt_format = PyFormat.from_pq(self.format)
+
+ def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Multirange:
+ return self.cls
+
+ item = self._get_item(obj)
+ if item is not None:
+ sd = self._tx.get_dumper(item, self._adapt_format)
+ return (self.cls, sd.get_key(item, format)) # type: ignore
+ else:
+ return (self.cls,)
+
+ def upgrade(
+ self, obj: Multirange[Any], format: PyFormat
+ ) -> "BaseMultirangeDumper":
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Multirange:
+ return self
+
+ item = self._get_item(obj)
+ if item is None:
+ return MultirangeDumper(self.cls)
+
+ dumper: BaseMultirangeDumper
+ if type(item) is int:
+ # postgres won't cast int4range -> int8range so we must use
+ # text format and unknown oid here
+ sd = self._tx.get_dumper(item, PyFormat.TEXT)
+ dumper = MultirangeDumper(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ dumper.oid = INVALID_OID
+ return dumper
+
+ sd = self._tx.get_dumper(item, format)
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ if sd.oid == INVALID_OID and isinstance(item, str):
+ # Work around the normal mapping where text is dumped as unknown
+ dumper.oid = self._get_multirange_oid(TEXT_OID)
+ else:
+ dumper.oid = self._get_multirange_oid(sd.oid)
+
+ return dumper
+
+ def _get_item(self, obj: Multirange[Any]) -> Any:
+ """
+ Return a member representative of the multirange
+ """
+ for r in obj:
+ if r.lower is not None:
+ return r.lower
+ if r.upper is not None:
+ return r.upper
+ return None
+
+ def _get_multirange_oid(self, sub_oid: int) -> int:
+ """
+ Return the oid of the range from the oid of its elements.
+ """
+ info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
+ return info.oid if info else INVALID_OID
+
+
+class MultirangeDumper(BaseMultirangeDumper):
+ """
+ Dumper for multirange types.
+
+ The dumper can upgrade to one specific for a different range type.
+ """
+
+ def dump(self, obj: Multirange[Any]) -> Buffer:
+ if not obj:
+ return b"{}"
+
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ out = [b"{"]
+ for r in obj:
+ out.append(dump_range_text(r, dump))
+ out.append(b",")
+ out[-1] = b"}"
+ return b"".join(out)
+
+
+class MultirangeBinaryDumper(BaseMultirangeDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Multirange[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ out = [pack_len(len(obj))]
+ for r in obj:
+ data = dump_range_binary(r, dump)
+ out.append(pack_len(len(data)))
+ out.append(data)
+ return b"".join(out)
+
+
+class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
+
+ subtype_oid: int
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._load = self._tx.get_loader(
+ self.subtype_oid, format=self.format
+ ).load
+
+
+class MultirangeLoader(BaseMultirangeLoader[T]):
+ def load(self, data: Buffer) -> Multirange[T]:
+ if not data or data[0] != _START_INT:
+ raise e.DataError(
+ f"malformed multirange starting with"
+ f" {bytes(data[:1]).decode('utf8', 'replace')}"
+ )
+
+ out = Multirange[T]()
+ if data == b"{}":
+ return out
+
+ pos = 1
+ data = data[pos:]
+ try:
+ while True:
+ r, pos = load_range_text(data, self._load)
+ out.append(r)
+
+ sep = data[pos] # can raise IndexError
+ if sep == _SEP_INT:
+ data = data[pos + 1 :]
+ continue
+ elif sep == _END_INT:
+ if len(data) == pos + 1:
+ return out
+ else:
+ raise e.DataError(
+ "malformed multirange: data after closing brace"
+ )
+ else:
+ raise e.DataError(
+ f"malformed multirange: found unexpected {chr(sep)}"
+ )
+
+ except IndexError:
+ raise e.DataError("malformed multirange: separator missing")
+
+ return out
+
+
+_SEP_INT = ord(",")
+_START_INT = ord("{")
+_END_INT = ord("}")
+
+
+class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Multirange[T]:
+ nelems = unpack_len(data, 0)[0]
+ pos = 4
+ out = Multirange[T]()
+ for i in range(nelems):
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ out.append(load_range_binary(data[pos : pos + length], self._load))
+ pos += length
+
+ if pos != len(data):
+ raise e.DataError("unexpected trailing data in multirange")
+
+ return out
+
+
+# Text dumpers for builtin multirange types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeDumper(MultirangeDumper):
+ oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeDumper(MultirangeDumper):
+ oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["tstzmultirange"].oid
+
+
+# Binary dumpers for builtin multirange types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["tstzmultirange"].oid
+
+
+# Text loaders for builtin multirange types
+
+
+class Int4MultirangeLoader(MultirangeLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeLoader(MultirangeLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeLoader(MultirangeLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeLoader(MultirangeLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeLoader(MultirangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+# Binary loaders for builtin multirange types
+
+
+class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Multirange, MultirangeBinaryDumper)
+ adapters.register_dumper(Multirange, MultirangeDumper)
+ adapters.register_dumper(Int4Multirange, Int4MultirangeDumper)
+ adapters.register_dumper(Int8Multirange, Int8MultirangeDumper)
+ adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
+ adapters.register_dumper(DateMultirange, DateMultirangeDumper)
+ adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
+ adapters.register_dumper(
+ TimestamptzMultirange, TimestamptzMultirangeDumper
+ )
+ adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
+ adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
+ adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
+ adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
+ adapters.register_dumper(
+ TimestampMultirange, TimestampMultirangeBinaryDumper
+ )
+ adapters.register_dumper(
+ TimestamptzMultirange, TimestamptzMultirangeBinaryDumper
+ )
+ adapters.register_loader("int4multirange", Int4MultirangeLoader)
+ adapters.register_loader("int8multirange", Int8MultirangeLoader)
+ adapters.register_loader("nummultirange", NumericMultirangeLoader)
+ adapters.register_loader("datemultirange", DateMultirangeLoader)
+ adapters.register_loader("tsmultirange", TimestampMultirangeLoader)
+ adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader)
+ adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader)
+ adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader)
+ adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
+ adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
+ adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
+ adapters.register_loader(
+ "tstzmultirange", TimestampTZMultirangeBinaryLoader
+ )