]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add multirange adaptation
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Oct 2021 03:31:23 +0000 (05:31 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 4 Oct 2021 12:45:56 +0000 (14:45 +0200)
To be tested yet

Close #75

psycopg/psycopg/postgres.py
psycopg/psycopg/types/multirange.py [new file with mode: 0644]
tests/fix_faker.py
tests/types/test_multirange.py [new file with mode: 0644]

index fe878a953efe88fd5dbbcf6689ac453fb06761c4..474cadce4461dca85e75c3aaac307890e58bd98e 100644 (file)
@@ -119,7 +119,7 @@ TEXT_ARRAY_OID = types["text"].array_oid
 
 def register_default_adapters(context: AdaptContext) -> None:
 
-    from .types import array, bool, composite, datetime, json
+    from .types import array, bool, composite, datetime, json, multirange
     from .types import net, none, numeric, range, string, uuid
 
     array.register_default_adapters(context)
@@ -127,6 +127,7 @@ def register_default_adapters(context: AdaptContext) -> None:
     composite.register_default_adapters(context)
     datetime.register_default_adapters(context)
     json.register_default_adapters(context)
+    multirange.register_default_adapters(context)
     net.register_default_adapters(context)
     none.register_default_adapters(context)
     numeric.register_default_adapters(context)
diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py
new file mode 100644 (file)
index 0000000..dd80ce1
--- /dev/null
@@ -0,0 +1,440 @@
+"""
+Support for multirange types adaptation.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from decimal import Decimal
+from typing import Any, Generic, List, Iterable
+from typing import MutableSequence, Optional, Union, overload
+from datetime import date, datetime
+
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._struct import pack_len, unpack_len
+from ..postgres import INVALID_OID, TEXT_OID
+from .._typeinfo import MultirangeInfo as MultirangeInfo  # exported here
+
+from .range import Range, T, load_range_text, load_range_binary
+from .range import dump_range_text, dump_range_binary, fail_dump
+
+
+class Multirange(MutableSequence[Range[T]]):
+    def __init__(self, items: Iterable[Range[T]] = ()):
+        self._ranges: List[Range[T]] = list(items)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self._ranges!r})"
+
+    def __str__(self) -> str:
+        return f"{{{', '.join(map(str, self._ranges))}}}"
+
+    @overload
+    def __getitem__(self, index: int) -> Range[T]:
+        ...
+
+    @overload
+    def __getitem__(self, index: slice) -> "Multirange[T]":
+        ...
+
+    def __getitem__(
+        self, index: Union[int, slice]
+    ) -> "Union[Range[T],Multirange[T]]":
+        if isinstance(index, int):
+            return self._ranges[index]
+        else:
+            return Multirange(self._ranges[index])
+
+    def __len__(self) -> int:
+        return len(self._ranges)
+
+    @overload
+    def __setitem__(self, index: int, value: Range[T]) -> None:
+        ...
+
+    @overload
+    def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
+        ...
+
+    def __setitem__(
+        self,
+        index: Union[int, slice],
+        value: Union[Range[T], Iterable[Range[T]]],
+    ) -> None:
+        self._ranges[index] = value  # type: ignore
+
+    def __delitem__(self, index: Union[int, slice]) -> None:
+        del self._ranges[index]
+
+    def insert(self, index: int, value: Range[T]) -> None:
+        self._ranges.insert(index, value)
+
+
+# Subclasses to specify a specific subtype. Usually not needed
+
+
+class Int4Multirange(Multirange[int]):
+    pass
+
+
+class Int8Multirange(Multirange[int]):
+    pass
+
+
+class NumericMultirange(Multirange[Decimal]):
+    pass
+
+
+class DateMultirange(Multirange[date]):
+    pass
+
+
+class TimestampMultirange(Multirange[datetime]):
+    pass
+
+
+class TimestamptzMultirange(Multirange[datetime]):
+    pass
+
+
+class BaseMultirangeDumper(RecursiveDumper):
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        super().__init__(cls, context)
+        self.sub_dumper: Optional[Dumper] = None
+        self._adapt_format = PyFormat.from_pq(self.format)
+
+    def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
+        # If we are a subclass whose oid is specified we don't need to upgrade
+        if self.cls is not Multirange:
+            return self.cls
+
+        item = self._get_item(obj)
+        if item is not None:
+            sd = self._tx.get_dumper(item, self._adapt_format)
+            return (self.cls, sd.get_key(item, format))  # type: ignore
+        else:
+            return (self.cls,)
+
+    def upgrade(
+        self, obj: Multirange[Any], format: PyFormat
+    ) -> "BaseMultirangeDumper":
+        # If we are a subclass whose oid is specified we don't need to upgrade
+        if self.cls is not Multirange:
+            return self
+
+        item = self._get_item(obj)
+        if item is None:
+            return MultirangeDumper(self.cls)
+
+        dumper: BaseMultirangeDumper
+        if type(item) is int:
+            # postgres won't cast int4range -> int8range so we must use
+            # text format and unknown oid here
+            sd = self._tx.get_dumper(item, PyFormat.TEXT)
+            dumper = MultirangeDumper(self.cls, self._tx)
+            dumper.sub_dumper = sd
+            dumper.oid = INVALID_OID
+            return dumper
+
+        sd = self._tx.get_dumper(item, format)
+        dumper = type(self)(self.cls, self._tx)
+        dumper.sub_dumper = sd
+        if sd.oid == INVALID_OID and isinstance(item, str):
+            # Work around the normal mapping where text is dumped as unknown
+            dumper.oid = self._get_multirange_oid(TEXT_OID)
+        else:
+            dumper.oid = self._get_multirange_oid(sd.oid)
+
+        return dumper
+
+    def _get_item(self, obj: Multirange[Any]) -> Any:
+        """
+        Return a member representative of the multirange
+        """
+        for r in obj:
+            if r.lower is not None:
+                return r.lower
+            if r.upper is not None:
+                return r.upper
+        return None
+
+    def _get_multirange_oid(self, sub_oid: int) -> int:
+        """
+        Return the oid of the multirange from the oid of its elements.
+        """
+        info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
+        return info.oid if info else INVALID_OID
+
+
+class MultirangeDumper(BaseMultirangeDumper):
+    """
+    Dumper for multirange types.
+
+    The dumper can upgrade to one specific for a different multirange type.
+    """
+
+    def dump(self, obj: Multirange[Any]) -> Buffer:
+        if not obj:
+            return b"{}"
+
+        item = self._get_item(obj)
+        if item is not None:
+            dump = self._tx.get_dumper(item, self._adapt_format).dump
+        else:
+            dump = fail_dump
+
+        out = [b"{"]
+        for r in obj:
+            out.append(dump_range_text(r, dump))
+            out.append(b",")
+        out[-1] = b"}"
+        return b"".join(out)
+
+
+class MultirangeBinaryDumper(BaseMultirangeDumper):
+
+    format = Format.BINARY
+
+    def dump(self, obj: Multirange[Any]) -> Buffer:
+        item = self._get_item(obj)
+        if item is not None:
+            dump = self._tx.get_dumper(item, self._adapt_format).dump
+        else:
+            dump = fail_dump
+
+        out = [pack_len(len(obj))]
+        for r in obj:
+            data = dump_range_binary(r, dump)
+            out.append(pack_len(len(data)))
+            out.append(data)
+        return b"".join(out)
+
+
+class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
+
+    subtype_oid: int
+
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self._load = self._tx.get_loader(
+            self.subtype_oid, format=self.format
+        ).load
+
+
+class MultirangeLoader(BaseMultirangeLoader[T]):
+    def load(self, data: Buffer) -> Multirange[T]:
+        if not data or data[0] != _START_INT:
+            raise e.DataError(
+                f"malformed multirange starting with"
+                f" {bytes(data[:1]).decode('utf8', 'replace')}"
+            )
+
+        out = Multirange[T]()
+        if data == b"{}":
+            return out
+
+        pos = 1
+        data = data[pos:]
+        try:
+            while True:
+                r, pos = load_range_text(data, self._load)
+                out.append(r)
+
+                sep = data[pos]  # can raise IndexError
+                if sep == _SEP_INT:
+                    data = data[pos + 1 :]
+                    continue
+                elif sep == _END_INT:
+                    if len(data) == pos + 1:
+                        return out
+                    else:
+                        raise e.DataError(
+                            "malformed multirange: data after closing brace"
+                        )
+                else:
+                    raise e.DataError(
+                        f"malformed multirange: found unexpected {chr(sep)}"
+                    )
+
+        except IndexError:
+            raise e.DataError("malformed multirange: separator missing")
+
+        return out
+
+
+_SEP_INT = ord(",")
+_START_INT = ord("{")
+_END_INT = ord("}")
+
+
+class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
+
+    format = Format.BINARY
+
+    def load(self, data: Buffer) -> Multirange[T]:
+        nelems = unpack_len(data, 0)[0]
+        pos = 4
+        out = Multirange[T]()
+        for i in range(nelems):
+            length = unpack_len(data, pos)[0]
+            pos += 4
+            out.append(load_range_binary(data[pos : pos + length], self._load))
+            pos += length
+
+        if pos != len(data):
+            raise e.DataError("unexpected trailing data in multirange")
+
+        return out
+
+
+# Text dumpers for builtin multirange type wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeDumper(MultirangeDumper):
+    oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeDumper(MultirangeDumper):
+    oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeDumper(MultirangeDumper):
+    oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeDumper(MultirangeDumper):
+    oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeDumper(MultirangeDumper):
+    oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeDumper(MultirangeDumper):
+    oid = postgres.types["tstzmultirange"].oid
+
+
+# Binary dumpers for builtin multirange type wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
+    oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
+    oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
+    oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
+    oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
+    oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
+    oid = postgres.types["tstzmultirange"].oid
+
+
+# Text loaders for builtin multirange types
+
+
+class Int4MultirangeLoader(MultirangeLoader[int]):
+    subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeLoader(MultirangeLoader[int]):
+    subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeLoader(MultirangeLoader[Decimal]):
+    subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeLoader(MultirangeLoader[date]):
+    subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeLoader(MultirangeLoader[datetime]):
+    subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
+    subtype_oid = postgres.types["timestamptz"].oid
+
+
+# Binary loaders for builtin multirange types
+
+
+class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+    subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+    subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
+    subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
+    subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+    subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+    subtype_oid = postgres.types["timestamptz"].oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+    adapters = context.adapters
+    adapters.register_dumper(Multirange, MultirangeBinaryDumper)
+    adapters.register_dumper(Multirange, MultirangeDumper)
+    adapters.register_dumper(Int4Multirange, Int4MultirangeDumper)
+    adapters.register_dumper(Int8Multirange, Int8MultirangeDumper)
+    adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
+    adapters.register_dumper(DateMultirange, DateMultirangeDumper)
+    adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
+    adapters.register_dumper(
+        TimestamptzMultirange, TimestamptzMultirangeDumper
+    )
+    adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
+    adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
+    adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
+    adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
+    adapters.register_dumper(
+        TimestampMultirange, TimestampMultirangeBinaryDumper
+    )
+    adapters.register_dumper(
+        TimestamptzMultirange, TimestamptzMultirangeBinaryDumper
+    )
+    adapters.register_loader("int4multirange", Int4MultirangeLoader)
+    adapters.register_loader("int8multirange", Int8MultirangeLoader)
+    adapters.register_loader("nummultirange", NumericMultirangeLoader)
+    adapters.register_loader("datemultirange", DateMultirangeLoader)
+    adapters.register_loader("tsmultirange", TimestampMultirangeLoader)
+    adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader)
+    adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader)
+    adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader)
+    adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
+    adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
+    adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
+    adapters.register_loader(
+        "tstzmultirange", TimestampTZMultirangeBinaryLoader
+    )
index 149ef68c218f5f9111b70898eec995caddbd647b..beb8a93d37ac4b9d2fea7c737d68ae7787e5b69e 100644 (file)
@@ -238,6 +238,12 @@ class Faker:
         for cls in dumpers.keys():
             if isinstance(cls, str):
                 cls = deep_import(cls)
+            if (
+                issubclass(cls, psycopg.types.multirange.Multirange)
+                and self.conn.info.server_version < 140000
+            ):
+                continue
+
             rv.add(cls)
 
         # check all the types are handled
diff --git a/tests/types/test_multirange.py b/tests/types/test_multirange.py
new file mode 100644 (file)
index 0000000..8416c1e
--- /dev/null
@@ -0,0 +1,20 @@
+import pytest
+
+from psycopg.adapt import PyFormat
+from psycopg.types.multirange import Multirange
+
+pytestmark = pytest.mark.pg(">= 14")
+
+mr_names = """int4multirange int8multirange nummultirange
+    datemultirange tsmultirange tstzmultirange""".split()
+
+mr_classes = """Int4Multirange Int8Multirange NumericMultirange
+    DateMultirange TimestampMultirange TimestamptzMultirange""".split()
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty(conn, pgtype, fmt_in):
+    mr = Multirange()
+    cur = conn.execute(f"select '{{}}'::{pgtype} = %{fmt_in}", (mr,))
+    assert cur.fetchone()[0] is True