From: Daniele Varrazzo Date: Sat, 2 Oct 2021 03:31:23 +0000 (+0200) Subject: Add multirange adaptation X-Git-Tag: 3.0~32^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2061ce1d2912baa0c91de44842d7c2d8bbb5ba2e;p=thirdparty%2Fpsycopg.git Add multirange adaptation To be tested yet Close #75 --- diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py index fe878a953..474cadce4 100644 --- a/psycopg/psycopg/postgres.py +++ b/psycopg/psycopg/postgres.py @@ -119,7 +119,7 @@ TEXT_ARRAY_OID = types["text"].array_oid def register_default_adapters(context: AdaptContext) -> None: - from .types import array, bool, composite, datetime, json + from .types import array, bool, composite, datetime, json, multirange from .types import net, none, numeric, range, string, uuid array.register_default_adapters(context) @@ -127,6 +127,7 @@ def register_default_adapters(context: AdaptContext) -> None: composite.register_default_adapters(context) datetime.register_default_adapters(context) json.register_default_adapters(context) + multirange.register_default_adapters(context) net.register_default_adapters(context) none.register_default_adapters(context) numeric.register_default_adapters(context) diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py new file mode 100644 index 000000000..dd80ce1d1 --- /dev/null +++ b/psycopg/psycopg/types/multirange.py @@ -0,0 +1,440 @@ +""" +Support for multirange types adaptation. +""" + +# Copyright (C) 2021 The Psycopg Team + +from decimal import Decimal +from typing import Any, Generic, List, Iterable +from typing import MutableSequence, Optional, Union, overload +from datetime import date, datetime + +from .. import errors as e +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext, Buffer, Dumper, DumperKey +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._struct import pack_len, unpack_len +from ..postgres import INVALID_OID, TEXT_OID +from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here + +from .range import Range, T, load_range_text, load_range_binary +from .range import dump_range_text, dump_range_binary, fail_dump + + +class Multirange(MutableSequence[Range[T]]): + def __init__(self, items: Iterable[Range[T]] = ()): + self._ranges: List[Range[T]] = list(items) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._ranges!r})" + + def __str__(self) -> str: + return f"{{{', '.join(map(str, self._ranges))}}}" + + @overload + def __getitem__(self, index: int) -> Range[T]: + ... + + @overload + def __getitem__(self, index: slice) -> "Multirange[T]": + ... + + def __getitem__( + self, index: Union[int, slice] + ) -> "Union[Range[T],Multirange[T]]": + if isinstance(index, int): + return self._ranges[index] + else: + return Multirange(self._ranges[index]) + + def __len__(self) -> int: + return len(self._ranges) + + @overload + def __setitem__(self, index: int, value: Range[T]) -> None: + ... + + @overload + def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: + ... + + def __setitem__( + self, + index: Union[int, slice], + value: Union[Range[T], Iterable[Range[T]]], + ) -> None: + self._ranges[index] = value # type: ignore + + def __delitem__(self, index: Union[int, slice]) -> None: + del self._ranges[index] + + def insert(self, index: int, value: Range[T]) -> None: + self._ranges.insert(index, value) + + +# Subclasses to specify a specific subtype. Usually not needed + + +class Int4Multirange(Multirange[int]): + pass + + +class Int8Multirange(Multirange[int]): + pass + + +class NumericMultirange(Multirange[Decimal]): + pass + + +class DateMultirange(Multirange[date]): + pass + + +class TimestampMultirange(Multirange[datetime]): + pass + + +class TimestamptzMultirange(Multirange[datetime]): + pass + + +class BaseMultirangeDumper(RecursiveDumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + self._adapt_format = PyFormat.from_pq(self.format) + + def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey: + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Multirange: + return self.cls + + item = self._get_item(obj) + if item is not None: + sd = self._tx.get_dumper(item, self._adapt_format) + return (self.cls, sd.get_key(item, format)) # type: ignore + else: + return (self.cls,) + + def upgrade( + self, obj: Multirange[Any], format: PyFormat + ) -> "BaseMultirangeDumper": + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Multirange: + return self + + item = self._get_item(obj) + if item is None: + return MultirangeDumper(self.cls) + + dumper: BaseMultirangeDumper + if type(item) is int: + # postgres won't cast int4range -> int8range so we must use + # text format and unknown oid here + sd = self._tx.get_dumper(item, PyFormat.TEXT) + dumper = MultirangeDumper(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = INVALID_OID + return dumper + + sd = self._tx.get_dumper(item, format) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + if sd.oid == INVALID_OID and isinstance(item, str): + # Work around the normal mapping where text is dumped as unknown + dumper.oid = self._get_multirange_oid(TEXT_OID) + else: + dumper.oid = self._get_multirange_oid(sd.oid) + + return dumper + + def _get_item(self, obj: Multirange[Any]) -> Any: + """ + Return a member representative of the multirange + """ + for r in obj: + if r.lower is not None: + return r.lower + if r.upper is not None: + return r.upper + return None + + def _get_multirange_oid(self, sub_oid: int) -> int: + """ + Return the oid of the range from the oid of its elements. + """ + info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid) + return info.oid if info else INVALID_OID + + +class MultirangeDumper(BaseMultirangeDumper): + """ + Dumper for multirange types. + + The dumper can upgrade to one specific for a different range type. + """ + + def dump(self, obj: Multirange[Any]) -> Buffer: + if not obj: + return b"{}" + + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + out = [b"{"] + for r in obj: + out.append(dump_range_text(r, dump)) + out.append(b",") + out[-1] = b"}" + return b"".join(out) + + +class MultirangeBinaryDumper(BaseMultirangeDumper): + + format = Format.BINARY + + def dump(self, obj: Multirange[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + out = [pack_len(len(obj))] + for r in obj: + data = dump_range_binary(r, dump) + out.append(pack_len(len(data))) + out.append(data) + return b"".join(out) + + +class BaseMultirangeLoader(RecursiveLoader, Generic[T]): + + subtype_oid: int + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._load = self._tx.get_loader( + self.subtype_oid, format=self.format + ).load + + +class MultirangeLoader(BaseMultirangeLoader[T]): + def load(self, data: Buffer) -> Multirange[T]: + if not data or data[0] != _START_INT: + raise e.DataError( + f"malformed multirange starting with" + f" {bytes(data[:1]).decode('utf8', 'replace')}" + ) + + out = Multirange[T]() + if data == b"{}": + return out + + pos = 1 + data = data[pos:] + try: + while True: + r, pos = load_range_text(data, self._load) + out.append(r) + + sep = data[pos] # can raise IndexError + if sep == _SEP_INT: + data = data[pos + 1 :] + continue + elif sep == _END_INT: + if len(data) == pos + 1: + return out + else: + raise e.DataError( + "malformed multirange: data after closing brace" + ) + else: + raise e.DataError( + f"malformed multirange: found unexpected {chr(sep)}" + ) + + except IndexError: + raise e.DataError("malformed multirange: separator missing") + + return out + + +_SEP_INT = ord(",") +_START_INT = ord("{") +_END_INT = ord("}") + + +class MultirangeBinaryLoader(BaseMultirangeLoader[T]): + + format = Format.BINARY + + def load(self, data: Buffer) -> Multirange[T]: + nelems = unpack_len(data, 0)[0] + pos = 4 + out = Multirange[T]() + for i in range(nelems): + length = unpack_len(data, pos)[0] + pos += 4 + out.append(load_range_binary(data[pos : pos + length], self._load)) + pos += length + + if pos != len(data): + raise e.DataError("unexpected trailing data in multirange") + + return out + + +# Text dumpers for builtin multirange types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4MultirangeDumper(MultirangeDumper): + oid = postgres.types["int4multirange"].oid + + +class Int8MultirangeDumper(MultirangeDumper): + oid = postgres.types["int8multirange"].oid + + +class NumericMultirangeDumper(MultirangeDumper): + oid = postgres.types["nummultirange"].oid + + +class DateMultirangeDumper(MultirangeDumper): + oid = postgres.types["datemultirange"].oid + + +class TimestampMultirangeDumper(MultirangeDumper): + oid = postgres.types["tsmultirange"].oid + + +class TimestamptzMultirangeDumper(MultirangeDumper): + oid = postgres.types["tstzmultirange"].oid + + +# Binary dumpers for builtin multirange types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4MultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["int4multirange"].oid + + +class Int8MultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["int8multirange"].oid + + +class NumericMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["nummultirange"].oid + + +class DateMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["datemultirange"].oid + + +class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["tsmultirange"].oid + + +class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["tstzmultirange"].oid + + +# Text loaders for builtin multirange types + + +class Int4MultirangeLoader(MultirangeLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8MultirangeLoader(MultirangeLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericMultirangeLoader(MultirangeLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateMultirangeLoader(MultirangeLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampMultirangeLoader(MultirangeLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZMultirangeLoader(MultirangeLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +# Binary loaders for builtin multirange types + + +class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Multirange, MultirangeBinaryDumper) + adapters.register_dumper(Multirange, MultirangeDumper) + adapters.register_dumper(Int4Multirange, Int4MultirangeDumper) + adapters.register_dumper(Int8Multirange, Int8MultirangeDumper) + adapters.register_dumper(NumericMultirange, NumericMultirangeDumper) + adapters.register_dumper(DateMultirange, DateMultirangeDumper) + adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper) + adapters.register_dumper( + TimestamptzMultirange, TimestamptzMultirangeDumper + ) + adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper) + adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper) + adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper) + adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper) + adapters.register_dumper( + TimestampMultirange, TimestampMultirangeBinaryDumper + ) + adapters.register_dumper( + TimestamptzMultirange, TimestamptzMultirangeBinaryDumper + ) + adapters.register_loader("int4multirange", Int4MultirangeLoader) + adapters.register_loader("int8multirange", Int8MultirangeLoader) + adapters.register_loader("nummultirange", NumericMultirangeLoader) + adapters.register_loader("datemultirange", DateMultirangeLoader) + adapters.register_loader("tsmultirange", TimestampMultirangeLoader) + adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader) + adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader) + adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader) + adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader) + adapters.register_loader("datemultirange", DateMultirangeBinaryLoader) + adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader) + adapters.register_loader( + "tstzmultirange", TimestampTZMultirangeBinaryLoader + ) diff --git a/tests/fix_faker.py b/tests/fix_faker.py index 149ef68c2..beb8a93d3 100644 --- a/tests/fix_faker.py +++ b/tests/fix_faker.py @@ -238,6 +238,12 @@ class Faker: for cls in dumpers.keys(): if isinstance(cls, str): cls = deep_import(cls) + if ( + issubclass(cls, psycopg.types.multirange.Multirange) + and self.conn.info.server_version < 140000 + ): + continue + rv.add(cls) # check all the types are handled diff --git a/tests/types/test_multirange.py b/tests/types/test_multirange.py new file mode 100644 index 000000000..8416c1e57 --- /dev/null +++ b/tests/types/test_multirange.py @@ -0,0 +1,20 @@ +import pytest + +from psycopg.adapt import PyFormat +from psycopg.types.multirange import Multirange + +pytestmark = pytest.mark.pg(">= 14") + +mr_names = """int4multirange int8multirange nummultirange + datemultirange tsmultirange tstzmultirange""".split() + +mr_classes = """Int4Multirange Int8Multirange NumericMultirange + DateMultirange TimestampMultirange TimestamptzMultirange""".split() + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty(conn, pgtype, fmt_in): + mr = Multirange() + cur = conn.execute(f"select '{{}}'::{pgtype} = %{fmt_in}", (mr,)) + assert cur.fetchone()[0] is True