From: Daniele Varrazzo
Date: Sat, 5 Dec 2020 18:41:26 +0000 (+0000)
Subject: Added builtin ranges adaptation
X-Git-Tag: 3.0.dev0~274^2~2
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4758fc05989b84e697e3cb9f8c5aaaf1b66e7f41;p=thirdparty%2Fpsycopg.git

Added builtin ranges adaptation
---

Review notes (not part of the original commit):
- composite.py: SequenceDumper._dump_sequence() returns `start + end` for an
  empty sequence instead of the hard-coded b"()", so the helper honours the
  delimiters it is given. TupleDumper output is unchanged.
- range.py: DateRangeLoader.subtype_oid is builtins["date"].oid. It was
  builtins["numeric"].oid, which made RangeLoader.load() look up the numeric
  loader for the bounds of a daterange.
- range.py: fixed the "shoud" typos in the RangeDumper and RangeLoader
  docstrings.
- The "index" blob ids below are those of the original commit; they do not
  describe the edited content.

diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py
index 80c0f3345..357997ff3 100644
--- a/psycopg3/psycopg3/types/__init__.py
+++ b/psycopg3/psycopg3/types/__init__.py
@@ -9,7 +9,7 @@ from ..oids import builtins
 
 # Register default adapters
 from . import array, composite, date, json, network, numeric  # noqa
-from . import singletons, text, uuid  # noqa
+from . import range, singletons, text, uuid  # noqa
 
 # Register associations with array oids
 array.register_all_arrays()
diff --git a/psycopg3/psycopg3/types/composite.py b/psycopg3/psycopg3/types/composite.py
index 2dc4bf494..f7a79f844 100644
--- a/psycopg3/psycopg3/types/composite.py
+++ b/psycopg3/psycopg3/types/composite.py
@@ -136,21 +136,22 @@ where t.oid = %(name)s::regtype
 """
 
 
-@Dumper.text(tuple)
-class TupleDumper(Dumper):
+class SequenceDumper(Dumper):
     def __init__(self, src: type, context: AdaptContext = None):
         super().__init__(src, context)
         self._tx = Transformer(context)
 
-    def dump(self, obj: Tuple[Any, ...]) -> bytes:
+    def _dump_sequence(
+        self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
+    ) -> bytes:
         if not obj:
-            return b"()"
+            return start + end
 
-        parts = [b"("]
+        parts = [start]
 
         for item in obj:
             if item is None:
-                parts.append(b",")
+                parts.append(sep)
                 continue
 
             dumper = self._tx.get_dumper(item, Format.TEXT)
@@ -161,9 +162,9 @@ class TupleDumper(Dumper):
                 ad = b'"' + self._re_escape.sub(br"\1\1", ad) + b'"'
 
             parts.append(ad)
-            parts.append(b",")
+            parts.append(sep)
 
-        parts[-1] = b")"
+        parts[-1] = end
 
         return b"".join(parts)
 
@@ -171,24 +172,17 @@ class TupleDumper(Dumper):
     _re_escape = re.compile(br"([\\\"])")
 
 
+@Dumper.text(tuple)
+class TupleDumper(SequenceDumper):
+    def dump(self, obj: Tuple[Any, ...]) -> bytes:
+        return self._dump_sequence(obj, b"(", b")", b",")
+
+
 class BaseCompositeLoader(Loader):
     def __init__(self, oid: int, context: AdaptContext = None):
         super().__init__(oid, context)
         self._tx = Transformer(context)
 
-
-@Loader.text(builtins["record"].oid)
-class RecordLoader(BaseCompositeLoader):
-    def load(self, data: bytes) -> Tuple[Any, ...]:
-        if data == b"()":
-            return ()
-
-        cast = self._tx.get_loader(TEXT_OID, format=Format.TEXT).load
-        return tuple(
-            cast(token) if token is not None else None
-            for token in self._parse_record(data[1:-1])
-        )
-
     def _parse_record(self, data: bytes) -> Iterator[Optional[bytes]]:
         """
         Split a non-empty representation of a composite type into components.
@@ -220,6 +214,19 @@ class RecordLoader(BaseCompositeLoader):
     _re_undouble = re.compile(br'(["\\])\1')
 
 
+@Loader.text(builtins["record"].oid)
+class RecordLoader(BaseCompositeLoader):
+    def load(self, data: bytes) -> Tuple[Any, ...]:
+        if data == b"()":
+            return ()
+
+        cast = self._tx.get_loader(TEXT_OID, format=Format.TEXT).load
+        return tuple(
+            cast(token) if token is not None else None
+            for token in self._parse_record(data[1:-1])
+        )
+
+
 _struct_len = struct.Struct("!i")
 _struct_oidlen = struct.Struct("!Ii")
 
diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py
new file mode 100644
index 000000000..3f21a3683
--- /dev/null
+++ b/psycopg3/psycopg3/types/range.py
@@ -0,0 +1,344 @@
+"""
+Support for range types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, cast, Dict, Generic, Optional, TypeVar, Type
+from decimal import Decimal
+from datetime import date, datetime
+
+from ..oids import builtins
+from ..adapt import Format, Dumper, Loader
+from .composite import SequenceDumper, BaseCompositeLoader
+
+T = TypeVar("T")
+
+
+class Range(Generic[T]):
+    """Python representation for a PostgreSQL |range|_ type.
+
+    :param lower: lower bound for the range. `!None` means unbound
+    :param upper: upper bound for the range. `!None` means unbound
+    :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``,
+        representing whether the lower or upper bounds are included
+    :param empty: if `!True`, the range is empty
+
+    """
+
+    __slots__ = ("_lower", "_upper", "_bounds")
+
+    def __init__(
+        self,
+        lower: Optional[T] = None,
+        upper: Optional[T] = None,
+        bounds: str = "[)",
+        empty: bool = False,
+    ):
+        if not empty:
+            if bounds not in ("[)", "(]", "()", "[]"):
+                raise ValueError("bound flags not valid: %r" % bounds)
+
+            self._lower = lower
+            self._upper = upper
+            self._bounds = bounds
+        else:
+            self._lower = self._upper = None
+            self._bounds = ""
+
+    def __repr__(self) -> str:
+        if not self._bounds:
+            return "%s(empty=True)" % self.__class__.__name__
+        else:
+            return "%s(%r, %r, %r)" % (
+                self.__class__.__name__,
+                self._lower,
+                self._upper,
+                self._bounds,
+            )
+
+    def __str__(self) -> str:
+        if not self._bounds:
+            return "empty"
+
+        items = [
+            self._bounds[0],
+            str(self._lower),
+            ", ",
+            str(self._upper),
+            self._bounds[1],
+        ]
+        return "".join(items)
+
+    @property
+    def lower(self) -> Optional[T]:
+        """The lower bound of the range. `!None` if empty or unbound."""
+        return self._lower
+
+    @property
+    def upper(self) -> Optional[T]:
+        """The upper bound of the range. `!None` if empty or unbound."""
+        return self._upper
+
+    @property
+    def isempty(self) -> bool:
+        """`!True` if the range is empty."""
+        return not self._bounds
+
+    @property
+    def lower_inf(self) -> bool:
+        """`!True` if the range doesn't have a lower bound."""
+        if not self._bounds:
+            return False
+        return self._lower is None
+
+    @property
+    def upper_inf(self) -> bool:
+        """`!True` if the range doesn't have an upper bound."""
+        if not self._bounds:
+            return False
+        return self._upper is None
+
+    @property
+    def lower_inc(self) -> bool:
+        """`!True` if the lower bound is included in the range."""
+        if not self._bounds or self._lower is None:
+            return False
+        return self._bounds[0] == "["
+
+    @property
+    def upper_inc(self) -> bool:
+        """`!True` if the upper bound is included in the range."""
+        if not self._bounds or self._upper is None:
+            return False
+        return self._bounds[1] == "]"
+
+    def __contains__(self, x: T) -> bool:
+        if not self._bounds:
+            return False
+
+        if self._lower is not None:
+            if self._bounds[0] == "[":
+                # It doesn't seem that Python has an ABC for ordered types.
+                if x < self._lower:  # type: ignore[operator]
+                    return False
+            else:
+                if x <= self._lower:  # type: ignore[operator]
+                    return False
+
+        if self._upper is not None:
+            if self._bounds[1] == "]":
+                if x > self._upper:  # type: ignore[operator]
+                    return False
+            else:
+                if x >= self._upper:  # type: ignore[operator]
+                    return False
+
+        return True
+
+    def __bool__(self) -> bool:
+        return bool(self._bounds)
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, Range):
+            return False
+        return (
+            self._lower == other._lower
+            and self._upper == other._upper
+            and self._bounds == other._bounds
+        )
+
+    def __ne__(self, other: Any) -> bool:
+        return not self.__eq__(other)
+
+    def __hash__(self) -> int:
+        return hash((self._lower, self._upper, self._bounds))
+
+    # as the postgres docs describe for the server-side stuff,
+    # ordering is rather arbitrary, but will remain stable
+    # and consistent.
+
+    def __lt__(self, other: Any) -> bool:
+        if not isinstance(other, Range):
+            return NotImplemented
+        for attr in ("_lower", "_upper", "_bounds"):
+            self_value = getattr(self, attr)
+            other_value = getattr(other, attr)
+            if self_value == other_value:
+                pass
+            elif self_value is None:
+                return True
+            elif other_value is None:
+                return False
+            else:
+                return cast(bool, self_value < other_value)
+        return False
+
+    def __le__(self, other: Any) -> bool:
+        if self == other:
+            return True
+        else:
+            return self.__lt__(other)
+
+    def __gt__(self, other: Any) -> bool:
+        if isinstance(other, Range):
+            return other.__lt__(self)
+        else:
+            return NotImplemented
+
+    def __ge__(self, other: Any) -> bool:
+        if self == other:
+            return True
+        else:
+            return self.__gt__(other)
+
+    def __getstate__(self) -> Dict[str, Any]:
+        return {
+            slot: getattr(self, slot)
+            for slot in self.__slots__
+            if hasattr(self, slot)
+        }
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        for slot, value in state.items():
+            setattr(self, slot, value)
+
+
+class RangeDumper(SequenceDumper):
+    """
+    Generic dumper for a range.
+
+    Subclasses should specify the type oid.
+    """
+
+    def dump(self, obj: Range[Any]) -> bytes:
+        if not obj:
+            return b"empty"
+        else:
+            return self._dump_sequence(
+                (obj.lower, obj.upper),
+                b"[" if obj.lower_inc else b"(",
+                b"]" if obj.upper_inc else b")",
+                b",",
+            )
+
+
+class RangeLoader(BaseCompositeLoader, Generic[T]):
+    """Generic loader for a range.
+
+    Subclasses should specify the oid of the subtype and the class to load.
+    """
+
+    subtype_oid: int
+    cls: Type[Range[T]]
+
+    def load(self, data: bytes) -> Range[T]:
+        if data == b"empty":
+            return self.cls(empty=True)
+
+        cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load
+        bounds = (data[:1] + data[-1:]).decode("utf-8")
+        min, max = (
+            cast(token) if token is not None else None
+            for token in self._parse_record(data[1:-1])
+        )
+        return self.cls(min, max, bounds)
+
+
+# Python wrappers for builtin range types
+
+
+class Int4Range(Range[int]):
+    pass
+
+
+class Int8Range(Range[int]):
+    pass
+
+
+class DecimalRange(Range[Decimal]):
+    pass
+
+
+class DateRange(Range[date]):
+    pass
+
+
+class DateTimeRange(Range[datetime]):
+    pass
+
+
+class DateTimeTZRange(Range[datetime]):
+    pass
+
+
+# Dumpers for builtin range types
+
+
+@Dumper.text(Int4Range)
+class Int4RangeDumper(RangeDumper):
+    oid = builtins["int4range"].oid
+
+
+@Dumper.text(Int8Range)
+class Int8RangeDumper(RangeDumper):
+    oid = builtins["int8range"].oid
+
+
+@Dumper.text(DecimalRange)
+class NumRangeDumper(RangeDumper):
+    oid = builtins["numrange"].oid
+
+
+@Dumper.text(DateRange)
+class DateRangeDumper(RangeDumper):
+    oid = builtins["daterange"].oid
+
+
+@Dumper.text(DateTimeRange)
+class TimestampRangeDumper(RangeDumper):
+    oid = builtins["tsrange"].oid
+
+
+@Dumper.text(DateTimeTZRange)
+class TimestampTZRangeDumper(RangeDumper):
+    oid = builtins["tstzrange"].oid
+
+
+# Loaders for builtin range types
+
+
+@Loader.text(builtins["int4range"].oid)
+class Int4RangeLoader(RangeLoader[int]):
+    subtype_oid = builtins["int4"].oid
+    cls = Int4Range
+
+
+@Loader.text(builtins["int8range"].oid)
+class Int8RangeLoader(RangeLoader[int]):
+    subtype_oid = builtins["int8"].oid
+    cls = Int8Range
+
+
+@Loader.text(builtins["numrange"].oid)
+class NumericRangeLoader(RangeLoader[Decimal]):
+    subtype_oid = builtins["numeric"].oid
+    cls = DecimalRange
+
+
+@Loader.text(builtins["daterange"].oid)
+class DateRangeLoader(RangeLoader[date]):
+    subtype_oid = builtins["date"].oid
+    cls = DateRange
+
+
+@Loader.text(builtins["tsrange"].oid)
+class TimestampRangeLoader(RangeLoader[datetime]):
+    subtype_oid = builtins["timestamp"].oid
+    cls = DateTimeRange
+
+
+@Loader.text(builtins["tstzrange"].oid)
+class TimestampTZRangeLoader(RangeLoader[datetime]):
+    subtype_oid = builtins["timestamptz"].oid
+    cls = DateTimeTZRange
diff --git a/tests/types/test_range.py b/tests/types/test_range.py
new file mode 100644
index 000000000..91eb669a6
--- /dev/null
+++ b/tests/types/test_range.py
@@ -0,0 +1,81 @@
+import pytest
+
+from psycopg3.types import range
+
+
+type2cls = {
+    "int4range": range.Int4Range,
+    "int8range": range.Int8Range,
+    "numrange": range.DecimalRange,
+    "daterange": range.DateRange,
+    "tsrange": range.DateTimeRange,
+    "tstzrange": range.DateTimeTZRange,
+}
+type2sub = {
+    "int4range": "int4",
+    "int8range": "int8",
+    "numrange": "numeric",
+    "daterange": "date",
+    "tsrange": "timestamp",
+    "tstzrange": "timestamptz",
+}
+
+samples = [
+    ("int4range", None, None, "()"),
+    ("int4range", 10, 20, "[]"),
+    ("int4range", -(2 ** 31), (2 ** 31) - 1, "[)"),
+    ("int8range", None, None, "()"),
+    ("int8range", 10, 20, "[)"),
+    ("int8range", -(2 ** 63), (2 ** 63) - 1, "[)"),
+    # TODO: complete samples
+]
+
+
+@pytest.mark.parametrize(
+    "pgtype",
+    "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_dump_builtin_range_empty(conn, pgtype):
+    r = type2cls[pgtype](empty=True)
+    cur = conn.cursor()
+    cur.execute(f"select 'empty'::{pgtype} = %s", (r,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+def test_dump_builtin_range(conn, pgtype, min, max, bounds):
+    r = type2cls[pgtype](min, max, bounds)
+    sub = type2sub[pgtype]
+    cur = conn.cursor()
+    cur.execute(
+        f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s",
+        (min, max, bounds, r),
+    )
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "pgtype",
+    "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_load_builtin_range_empty(conn, pgtype):
+    r = type2cls[pgtype](empty=True)
+    cur = conn.cursor()
+    (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
+    assert type(got) is type2cls[pgtype]
+    assert got == r
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+def test_load_builtin_range(conn, pgtype, min, max, bounds):
+    r = type2cls[pgtype](min, max, bounds)
+    sub = type2sub[pgtype]
+    cur = conn.cursor()
+    cur.execute(
+        f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)
+    )
+    # normalise discrete ranges
+    if r.upper_inc and isinstance(r.upper, int):
+        bounds = "[)" if r.lower_inc else "()"
+        r = type(r)(r.lower, r.upper + 1, bounds)
+    assert cur.fetchone()[0] == r