git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added builtin ranges adaptation
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Dec 2020 18:41:26 +0000 (18:41 +0000)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 10 Dec 2020 01:43:50 +0000 (02:43 +0100)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/range.py [new file with mode: 0644]
tests/types/test_range.py [new file with mode: 0644]

index 80c0f3345af7ef99870189615bda4633d984bc69..357997ff36cedb6189179e042e5799724b57d867 100644 (file)
@@ -9,7 +9,7 @@ from ..oids import builtins
 
 # Register default adapters
 from . import array, composite, date, json, network, numeric  # noqa
-from . import singletons, text, uuid  # noqa
+from . import range, singletons, text, uuid  # noqa
 
 # Register associations with array oids
 array.register_all_arrays()
index 2dc4bf494ff8c2848ae260466d7e43bd5e50da05..f7a79f844c7c029305f77be03d03bb84bd0c14b0 100644 (file)
@@ -136,21 +136,22 @@ where t.oid = %(name)s::regtype
 """
 
 
-@Dumper.text(tuple)
-class TupleDumper(Dumper):
+class SequenceDumper(Dumper):
     def __init__(self, src: type, context: AdaptContext = None):
         super().__init__(src, context)
         self._tx = Transformer(context)
 
-    def dump(self, obj: Tuple[Any, ...]) -> bytes:
+    def _dump_sequence(
+        self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
+    ) -> bytes:
         if not obj:
             return b"()"
 
-        parts = [b"("]
+        parts = [start]
 
         for item in obj:
             if item is None:
-                parts.append(b",")
+                parts.append(sep)
                 continue
 
             dumper = self._tx.get_dumper(item, Format.TEXT)
@@ -161,9 +162,9 @@ class TupleDumper(Dumper):
                 ad = b'"' + self._re_escape.sub(br"\1\1", ad) + b'"'
 
             parts.append(ad)
-            parts.append(b",")
+            parts.append(sep)
 
-        parts[-1] = b")"
+        parts[-1] = end
 
         return b"".join(parts)
 
@@ -171,24 +172,17 @@ class TupleDumper(Dumper):
     _re_escape = re.compile(br"([\\\"])")
 
 
+@Dumper.text(tuple)
+class TupleDumper(SequenceDumper):
+    def dump(self, obj: Tuple[Any, ...]) -> bytes:
+        return self._dump_sequence(obj, b"(", b")", b",")
+
+
 class BaseCompositeLoader(Loader):
     def __init__(self, oid: int, context: AdaptContext = None):
         super().__init__(oid, context)
         self._tx = Transformer(context)
 
-
-@Loader.text(builtins["record"].oid)
-class RecordLoader(BaseCompositeLoader):
-    def load(self, data: bytes) -> Tuple[Any, ...]:
-        if data == b"()":
-            return ()
-
-        cast = self._tx.get_loader(TEXT_OID, format=Format.TEXT).load
-        return tuple(
-            cast(token) if token is not None else None
-            for token in self._parse_record(data[1:-1])
-        )
-
     def _parse_record(self, data: bytes) -> Iterator[Optional[bytes]]:
         """
         Split a non-empty representation of a composite type into components.
@@ -220,6 +214,19 @@ class RecordLoader(BaseCompositeLoader):
     _re_undouble = re.compile(br'(["\\])\1')
 
 
+@Loader.text(builtins["record"].oid)
+class RecordLoader(BaseCompositeLoader):
+    def load(self, data: bytes) -> Tuple[Any, ...]:
+        if data == b"()":
+            return ()
+
+        cast = self._tx.get_loader(TEXT_OID, format=Format.TEXT).load
+        return tuple(
+            cast(token) if token is not None else None
+            for token in self._parse_record(data[1:-1])
+        )
+
+
 _struct_len = struct.Struct("!i")
 _struct_oidlen = struct.Struct("!Ii")
 
diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py
new file mode 100644 (file)
index 0000000..3f21a36
--- /dev/null
@@ -0,0 +1,344 @@
+"""
+Support for range types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, cast, Dict, Generic, Optional, TypeVar, Type
+from decimal import Decimal
+from datetime import date, datetime
+
+from ..oids import builtins
+from ..adapt import Format, Dumper, Loader
+from .composite import SequenceDumper, BaseCompositeLoader
+
+T = TypeVar("T")
+
+
+class Range(Generic[T]):
+    """Python representation for a PostgreSQL |range|_ type.
+
+    :param lower: lower bound for the range. `!None` means unbound
+    :param upper: upper bound for the range. `!None` means unbound
+    :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``,
+        representing whether the lower or upper bounds are included
+    :param empty: if `!True`, the range is empty; other arguments are ignored
+
+    """
+
+    __slots__ = ("_lower", "_upper", "_bounds")
+
+    def __init__(
+        self,
+        lower: Optional[T] = None,
+        upper: Optional[T] = None,
+        bounds: str = "[)",
+        empty: bool = False,
+    ):
+        if not empty:
+            if bounds not in ("[)", "(]", "()", "[]"):
+                raise ValueError("bound flags not valid: %r" % bounds)
+
+            self._lower = lower
+            self._upper = upper
+            self._bounds = bounds
+        else:
+            self._lower = self._upper = None
+            self._bounds = ""  # empty string marks an empty range
+
+    def __repr__(self) -> str:
+        if not self._bounds:
+            return "%s(empty=True)" % self.__class__.__name__
+        else:
+            return "%s(%r, %r, %r)" % (
+                self.__class__.__name__,
+                self._lower,
+                self._upper,
+                self._bounds,
+            )
+
+    def __str__(self) -> str:
+        if not self._bounds:
+            return "empty"
+
+        items = [
+            self._bounds[0],
+            str(self._lower),
+            ", ",
+            str(self._upper),
+            self._bounds[1],
+        ]
+        return "".join(items)
+
+    @property
+    def lower(self) -> Optional[T]:
+        """The lower bound of the range. `!None` if empty or unbound."""
+        return self._lower
+
+    @property
+    def upper(self) -> Optional[T]:
+        """The upper bound of the range. `!None` if empty or unbound."""
+        return self._upper
+
+    @property
+    def isempty(self) -> bool:
+        """`!True` if the range is empty."""
+        return not self._bounds
+
+    @property
+    def lower_inf(self) -> bool:
+        """`!True` if the range doesn't have a lower bound."""
+        if not self._bounds:
+            return False
+        return self._lower is None
+
+    @property
+    def upper_inf(self) -> bool:
+        """`!True` if the range doesn't have an upper bound."""
+        if not self._bounds:
+            return False
+        return self._upper is None
+
+    @property
+    def lower_inc(self) -> bool:
+        """`!True` if the lower bound is included in the range."""
+        if not self._bounds or self._lower is None:
+            return False
+        return self._bounds[0] == "["
+
+    @property
+    def upper_inc(self) -> bool:
+        """`!True` if the upper bound is included in the range."""
+        if not self._bounds or self._upper is None:
+            return False
+        return self._bounds[1] == "]"
+
+    def __contains__(self, x: T) -> bool:
+        if not self._bounds:
+            return False
+
+        if self._lower is not None:
+            if self._bounds[0] == "[":
+                # It doesn't seem that Python has an ABC for ordered types.
+                if x < self._lower:  # type: ignore[operator]
+                    return False
+            else:
+                if x <= self._lower:  # type: ignore[operator]
+                    return False
+
+        if self._upper is not None:
+            if self._bounds[1] == "]":
+                if x > self._upper:  # type: ignore[operator]
+                    return False
+            else:
+                if x >= self._upper:  # type: ignore[operator]
+                    return False
+
+        return True
+
+    def __bool__(self) -> bool:
+        return bool(self._bounds)  # empty ranges are falsy
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, Range):
+            return False
+        return (
+            self._lower == other._lower
+            and self._upper == other._upper
+            and self._bounds == other._bounds
+        )
+
+    def __ne__(self, other: Any) -> bool:
+        return not self.__eq__(other)
+
+    def __hash__(self) -> int:
+        return hash((self._lower, self._upper, self._bounds))
+
+    # As the PostgreSQL docs describe for the server side, the ordering is
+    # rather arbitrary, but it is stable and consistent. An attribute that
+    # is None sorts before any other value.
+
+    def __lt__(self, other: Any) -> bool:
+        if not isinstance(other, Range):
+            return NotImplemented
+        for attr in ("_lower", "_upper", "_bounds"):
+            self_value = getattr(self, attr)
+            other_value = getattr(other, attr)
+            if self_value == other_value:
+                pass
+            elif self_value is None:
+                return True
+            elif other_value is None:
+                return False
+            else:
+                return cast(bool, self_value < other_value)
+        return False
+
+    def __le__(self, other: Any) -> bool:
+        if self == other:
+            return True
+        else:
+            return self.__lt__(other)
+
+    def __gt__(self, other: Any) -> bool:
+        if isinstance(other, Range):
+            return other.__lt__(self)
+        else:
+            return NotImplemented
+
+    def __ge__(self, other: Any) -> bool:
+        if self == other:
+            return True
+        else:
+            return self.__gt__(other)
+
+    def __getstate__(self) -> Dict[str, Any]:
+        return {
+            slot: getattr(self, slot)
+            for slot in self.__slots__
+            if hasattr(self, slot)  # skip slots never assigned
+        }
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        for slot, value in state.items():
+            setattr(self, slot, value)
+
+
+class RangeDumper(SequenceDumper):
+    """
+    Generic dumper for a range.
+
+    Subclasses should specify the type oid.
+    """
+
+    def dump(self, obj: Range[Any]) -> bytes:
+        if not obj:  # a falsy Range is an empty range
+            return b"empty"
+        else:
+            return self._dump_sequence(
+                (obj.lower, obj.upper),
+                b"[" if obj.lower_inc else b"(",
+                b"]" if obj.upper_inc else b")",
+                b",",
+            )
+
+
+class RangeLoader(BaseCompositeLoader, Generic[T]):
+    """Generic loader for a range.
+
+    Subclasses should specify the oid of the subtype and the class to load.
+    """
+
+    subtype_oid: int
+    cls: Type[Range[T]]
+
+    def load(self, data: bytes) -> Range[T]:
+        if data == b"empty":
+            return self.cls(empty=True)
+
+        conv = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load
+        bounds = (data[:1] + data[-1:]).decode("utf-8")  # e.g. "[)"
+        lower, upper = (
+            conv(token) if token is not None else None  # None: unbounded
+            for token in self._parse_record(data[1:-1])
+        )
+        return self.cls(lower, upper, bounds)
+
+
+# Python wrappers for builtin range types
+
+
+class Int4Range(Range[int]):
+    pass
+
+
+class Int8Range(Range[int]):
+    pass
+
+
+class DecimalRange(Range[Decimal]):
+    pass
+
+
+class DateRange(Range[date]):
+    pass
+
+
+class DateTimeRange(Range[datetime]):
+    pass
+
+
+class DateTimeTZRange(Range[datetime]):
+    pass
+
+
+# Dumpers for builtin range types
+
+
+@Dumper.text(Int4Range)
+class Int4RangeDumper(RangeDumper):
+    oid = builtins["int4range"].oid
+
+
+@Dumper.text(Int8Range)
+class Int8RangeDumper(RangeDumper):
+    oid = builtins["int8range"].oid
+
+
+@Dumper.text(DecimalRange)
+class NumRangeDumper(RangeDumper):
+    oid = builtins["numrange"].oid
+
+
+@Dumper.text(DateRange)
+class DateRangeDumper(RangeDumper):
+    oid = builtins["daterange"].oid
+
+
+@Dumper.text(DateTimeRange)
+class TimestampRangeDumper(RangeDumper):
+    oid = builtins["tsrange"].oid
+
+
+@Dumper.text(DateTimeTZRange)
+class TimestampTZRangeDumper(RangeDumper):
+    oid = builtins["tstzrange"].oid
+
+
+# Loaders for builtin range types
+
+
+@Loader.text(builtins["int4range"].oid)
+class Int4RangeLoader(RangeLoader[int]):
+    subtype_oid = builtins["int4"].oid
+    cls = Int4Range
+
+
+@Loader.text(builtins["int8range"].oid)
+class Int8RangeLoader(RangeLoader[int]):
+    subtype_oid = builtins["int8"].oid
+    cls = Int8Range
+
+
+@Loader.text(builtins["numrange"].oid)
+class NumericRangeLoader(RangeLoader[Decimal]):
+    subtype_oid = builtins["numeric"].oid
+    cls = DecimalRange
+
+
+@Loader.text(builtins["daterange"].oid)
+class DateRangeLoader(RangeLoader[date]):
+    subtype_oid = builtins["date"].oid  # daterange bounds are dates
+    cls = DateRange
+
+
+@Loader.text(builtins["tsrange"].oid)
+class TimestampRangeLoader(RangeLoader[datetime]):
+    subtype_oid = builtins["timestamp"].oid
+    cls = DateTimeRange
+
+
+@Loader.text(builtins["tstzrange"].oid)
+class TimestampTZRangeLoader(RangeLoader[datetime]):
+    subtype_oid = builtins["timestamptz"].oid
+    cls = DateTimeTZRange
diff --git a/tests/types/test_range.py b/tests/types/test_range.py
new file mode 100644 (file)
index 0000000..91eb669
--- /dev/null
@@ -0,0 +1,81 @@
+import pytest
+
+from psycopg3.types import range
+
+
+type2cls = {
+    "int4range": range.Int4Range,
+    "int8range": range.Int8Range,
+    "numrange": range.DecimalRange,
+    "daterange": range.DateRange,
+    "tsrange": range.DateTimeRange,
+    "tstzrange": range.DateTimeTZRange,
+}
+type2sub = {
+    "int4range": "int4",
+    "int8range": "int8",
+    "numrange": "numeric",
+    "daterange": "date",
+    "tsrange": "timestamp",
+    "tstzrange": "timestamptz",
+}
+
+samples = [
+    ("int4range", None, None, "()"),
+    ("int4range", 10, 20, "[]"),
+    ("int4range", -(2 ** 31), (2 ** 31) - 1, "[)"),
+    ("int8range", None, None, "()"),
+    ("int8range", 10, 20, "[)"),
+    ("int8range", -(2 ** 63), (2 ** 63) - 1, "[)"),
+    # TODO: complete samples
+]
+
+
+@pytest.mark.parametrize(
+    "pgtype",
+    "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_dump_builtin_range_empty(conn, pgtype):
+    r = type2cls[pgtype](empty=True)
+    cur = conn.cursor()
+    cur.execute(f"select 'empty'::{pgtype} = %s", (r,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+def test_dump_builtin_range(conn, pgtype, min, max, bounds):
+    r = type2cls[pgtype](min, max, bounds)
+    sub = type2sub[pgtype]
+    cur = conn.cursor()
+    cur.execute(
+        f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s",
+        (min, max, bounds, r),
+    )
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "pgtype",
+    "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_load_builtin_range_empty(conn, pgtype):
+    r = type2cls[pgtype](empty=True)
+    cur = conn.cursor()
+    (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
+    assert type(got) is type2cls[pgtype]
+    assert got == r
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+def test_load_builtin_range(conn, pgtype, min, max, bounds):
+    r = type2cls[pgtype](min, max, bounds)
+    sub = type2sub[pgtype]
+    cur = conn.cursor()
+    cur.execute(
+        f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)
+    )
+    # discrete ranges are expected back in canonical form (exclusive upper)
+    if r.upper_inc and isinstance(r.upper, int):
+        bounds = "[)" if r.lower_inc else "()"
+        r = type(r)(r.lower, r.upper + 1, bounds)
+    assert cur.fetchone()[0] == r