"""
-@Dumper.text(tuple)
-class TupleDumper(Dumper):
+class SequenceDumper(Dumper):
def __init__(self, src: type, context: AdaptContext = None):
super().__init__(src, context)
self._tx = Transformer(context)
- def dump(self, obj: Tuple[Any, ...]) -> bytes:
+ def _dump_sequence(
+ self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
+ ) -> bytes:
if not obj:
return b"()"
- parts = [b"("]
+ parts = [start]
for item in obj:
if item is None:
- parts.append(b",")
+ parts.append(sep)
continue
dumper = self._tx.get_dumper(item, Format.TEXT)
ad = b'"' + self._re_escape.sub(br"\1\1", ad) + b'"'
parts.append(ad)
- parts.append(b",")
+ parts.append(sep)
- parts[-1] = b")"
+ parts[-1] = end
return b"".join(parts)
_re_escape = re.compile(br"([\\\"])")
+@Dumper.text(tuple)
+class TupleDumper(SequenceDumper):
+ def dump(self, obj: Tuple[Any, ...]) -> bytes:
+ return self._dump_sequence(obj, b"(", b")", b",")
+
+
class BaseCompositeLoader(Loader):
def __init__(self, oid: int, context: AdaptContext = None):
super().__init__(oid, context)
self._tx = Transformer(context)
-
-@Loader.text(builtins["record"].oid)
-class RecordLoader(BaseCompositeLoader):
- def load(self, data: bytes) -> Tuple[Any, ...]:
- if data == b"()":
- return ()
-
- cast = self._tx.get_loader(TEXT_OID, format=Format.TEXT).load
- return tuple(
- cast(token) if token is not None else None
- for token in self._parse_record(data[1:-1])
- )
-
def _parse_record(self, data: bytes) -> Iterator[Optional[bytes]]:
"""
Split a non-empty representation of a composite type into components.
_re_undouble = re.compile(br'(["\\])\1')
+@Loader.text(builtins["record"].oid)
+class RecordLoader(BaseCompositeLoader):
+ def load(self, data: bytes) -> Tuple[Any, ...]:
+ if data == b"()":
+ return ()
+
+ cast = self._tx.get_loader(TEXT_OID, format=Format.TEXT).load
+ return tuple(
+ cast(token) if token is not None else None
+ for token in self._parse_record(data[1:-1])
+ )
+
+
_struct_len = struct.Struct("!i")
_struct_oidlen = struct.Struct("!Ii")
--- /dev/null
+"""
+Support for range types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, cast, Dict, Generic, Optional, TypeVar, Type
+from decimal import Decimal
+from datetime import date, datetime
+
+from ..oids import builtins
+from ..adapt import Format, Dumper, Loader
+from .composite import SequenceDumper, BaseCompositeLoader
+
+T = TypeVar("T")
+
+
+class Range(Generic[T]):
+ """Python representation for a PostgreSQL |range|_ type.
+
+ :param lower: lower bound for the range. `!None` means unbound
+ :param upper: upper bound for the range. `!None` means unbound
+ :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``,
+ representing whether the lower or upper bounds are included
+ :param empty: if `!True`, the range is empty
+
+ """
+
+ __slots__ = ("_lower", "_upper", "_bounds")
+
+ def __init__(
+ self,
+ lower: Optional[T] = None,
+ upper: Optional[T] = None,
+ bounds: str = "[)",
+ empty: bool = False,
+ ):
+ if not empty:
+ if bounds not in ("[)", "(]", "()", "[]"):
+ raise ValueError("bound flags not valid: %r" % bounds)
+
+ self._lower = lower
+ self._upper = upper
+ self._bounds = bounds
+ else:
+ self._lower = self._upper = None
+ self._bounds = ""
+
+ def __repr__(self) -> str:
+ if not self._bounds:
+ return "%s(empty=True)" % self.__class__.__name__
+ else:
+ return "%s(%r, %r, %r)" % (
+ self.__class__.__name__,
+ self._lower,
+ self._upper,
+ self._bounds,
+ )
+
+ def __str__(self) -> str:
+ if not self._bounds:
+ return "empty"
+
+ items = [
+ self._bounds[0],
+ str(self._lower),
+ ", ",
+ str(self._upper),
+ self._bounds[1],
+ ]
+ return "".join(items)
+
+ @property
+ def lower(self) -> Optional[T]:
+ """The lower bound of the range. `!None` if empty or unbound."""
+ return self._lower
+
+ @property
+ def upper(self) -> Optional[T]:
+ """The upper bound of the range. `!None` if empty or unbound."""
+ return self._upper
+
+ @property
+ def isempty(self) -> bool:
+ """`!True` if the range is empty."""
+ return not self._bounds
+
+ @property
+ def lower_inf(self) -> bool:
+ """`!True` if the range doesn't have a lower bound."""
+ if not self._bounds:
+ return False
+ return self._lower is None
+
+ @property
+ def upper_inf(self) -> bool:
+ """`!True` if the range doesn't have an upper bound."""
+ if not self._bounds:
+ return False
+ return self._upper is None
+
+ @property
+ def lower_inc(self) -> bool:
+ """`!True` if the lower bound is included in the range."""
+ if not self._bounds or self._lower is None:
+ return False
+ return self._bounds[0] == "["
+
+ @property
+ def upper_inc(self) -> bool:
+ """`!True` if the upper bound is included in the range."""
+ if not self._bounds or self._upper is None:
+ return False
+ return self._bounds[1] == "]"
+
+ def __contains__(self, x: T) -> bool:
+ if not self._bounds:
+ return False
+
+ if self._lower is not None:
+ if self._bounds[0] == "[":
+ # It doesn't seem that Python has an ABC for ordered types.
+ if x < self._lower: # type: ignore[operator]
+ return False
+ else:
+ if x <= self._lower: # type: ignore[operator]
+ return False
+
+ if self._upper is not None:
+ if self._bounds[1] == "]":
+ if x > self._upper: # type: ignore[operator]
+ return False
+ else:
+ if x >= self._upper: # type: ignore[operator]
+ return False
+
+ return True
+
+ def __bool__(self) -> bool:
+ return bool(self._bounds)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Range):
+ return False
+ return (
+ self._lower == other._lower
+ and self._upper == other._upper
+ and self._bounds == other._bounds
+ )
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+ def __hash__(self) -> int:
+ return hash((self._lower, self._upper, self._bounds))
+
+ # as the postgres docs describe for the server-side stuff,
+ # ordering is rather arbitrary, but will remain stable
+ # and consistent.
+
+ def __lt__(self, other: Any) -> bool:
+ if not isinstance(other, Range):
+ return NotImplemented
+ for attr in ("_lower", "_upper", "_bounds"):
+ self_value = getattr(self, attr)
+ other_value = getattr(other, attr)
+ if self_value == other_value:
+ pass
+ elif self_value is None:
+ return True
+ elif other_value is None:
+ return False
+ else:
+ return cast(bool, self_value < other_value)
+ return False
+
+ def __le__(self, other: Any) -> bool:
+ if self == other:
+ return True
+ else:
+ return self.__lt__(other)
+
+ def __gt__(self, other: Any) -> bool:
+ if isinstance(other, Range):
+ return other.__lt__(self)
+ else:
+ return NotImplemented
+
+ def __ge__(self, other: Any) -> bool:
+ if self == other:
+ return True
+ else:
+ return self.__gt__(other)
+
+ def __getstate__(self) -> Dict[str, Any]:
+ return {
+ slot: getattr(self, slot)
+ for slot in self.__slots__
+ if hasattr(self, slot)
+ }
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ for slot, value in state.items():
+ setattr(self, slot, value)
+
+
+class RangeDumper(SequenceDumper):
+ """
+ Generic dumper for a range.
+
+ Subclasses shoud specify the type oid.
+ """
+
+ def dump(self, obj: Range[Any]) -> bytes:
+ if not obj:
+ return b"empty"
+ else:
+ return self._dump_sequence(
+ (obj.lower, obj.upper),
+ b"[" if obj.lower_inc else b"(",
+ b"]" if obj.upper_inc else b")",
+ b",",
+ )
+
+
+class RangeLoader(BaseCompositeLoader, Generic[T]):
+ """Generic loader for a range.
+
+ Subclasses shoud specify the oid of the subtype and the class to load.
+ """
+
+ subtype_oid: int
+ cls: Type[Range[T]]
+
+ def load(self, data: bytes) -> Range[T]:
+ if data == b"empty":
+ return self.cls(empty=True)
+
+ cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load
+ bounds = (data[:1] + data[-1:]).decode("utf-8")
+ min, max = (
+ cast(token) if token is not None else None
+ for token in self._parse_record(data[1:-1])
+ )
+ return self.cls(min, max, bounds)
+
+
+# Python wrappers for builtin range types
+
+
+class Int4Range(Range[int]):
+ pass
+
+
+class Int8Range(Range[int]):
+ pass
+
+
+class DecimalRange(Range[Decimal]):
+ pass
+
+
+class DateRange(Range[date]):
+ pass
+
+
+class DateTimeRange(Range[datetime]):
+ pass
+
+
+class DateTimeTZRange(Range[datetime]):
+ pass
+
+
+# Dumpers for builtin range types
+
+
+@Dumper.text(Int4Range)
+class Int4RangeDumper(RangeDumper):
+ oid = builtins["int4range"].oid
+
+
+@Dumper.text(Int8Range)
+class Int8RangeDumper(RangeDumper):
+ oid = builtins["int8range"].oid
+
+
+@Dumper.text(DecimalRange)
+class NumRangeDumper(RangeDumper):
+ oid = builtins["numrange"].oid
+
+
+@Dumper.text(DateRange)
+class DateRangeDumper(RangeDumper):
+ oid = builtins["daterange"].oid
+
+
+@Dumper.text(DateTimeRange)
+class TimestampRangeDumper(RangeDumper):
+ oid = builtins["tsrange"].oid
+
+
+@Dumper.text(DateTimeTZRange)
+class TimestampTZRangeDumper(RangeDumper):
+ oid = builtins["tstzrange"].oid
+
+
+# Loaders for builtin range types
+
+
+@Loader.text(builtins["int4range"].oid)
+class Int4RangeLoader(RangeLoader[int]):
+ subtype_oid = builtins["int4"].oid
+ cls = Int4Range
+
+
+@Loader.text(builtins["int8range"].oid)
+class Int8RangeLoader(RangeLoader[int]):
+ subtype_oid = builtins["int8"].oid
+ cls = Int8Range
+
+
+@Loader.text(builtins["numrange"].oid)
+class NumericRangeLoader(RangeLoader[Decimal]):
+ subtype_oid = builtins["numeric"].oid
+ cls = DecimalRange
+
+
+@Loader.text(builtins["daterange"].oid)
+class DateRangeLoader(RangeLoader[date]):
+ subtype_oid = builtins["numeric"].oid
+ cls = DateRange
+
+
+@Loader.text(builtins["tsrange"].oid)
+class TimestampRangeLoader(RangeLoader[datetime]):
+ subtype_oid = builtins["timestamp"].oid
+ cls = DateTimeRange
+
+
+@Loader.text(builtins["tstzrange"].oid)
+class TimestampTZRangeLoader(RangeLoader[datetime]):
+ subtype_oid = builtins["timestamptz"].oid
+ cls = DateTimeTZRange
--- /dev/null
+import pytest
+
+from psycopg3.types import range
+
+
+type2cls = {
+ "int4range": range.Int4Range,
+ "int8range": range.Int8Range,
+ "numrange": range.DecimalRange,
+ "daterange": range.DateRange,
+ "tsrange": range.DateTimeRange,
+ "tstzrange": range.DateTimeTZRange,
+}
+type2sub = {
+ "int4range": "int4",
+ "int8range": "int8",
+ "numrange": "numeric",
+ "daterange": "date",
+ "tsrange": "timestamp",
+ "tstzrange": "timestamptz",
+}
+
+samples = [
+ ("int4range", None, None, "()"),
+ ("int4range", 10, 20, "[]"),
+ ("int4range", -(2 ** 31), (2 ** 31) - 1, "[)"),
+ ("int8range", None, None, "()"),
+ ("int8range", 10, 20, "[)"),
+ ("int8range", -(2 ** 63), (2 ** 63) - 1, "[)"),
+ # TODO: complete samples
+]
+
+
+@pytest.mark.parametrize(
+ "pgtype",
+ "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_dump_builtin_range_empty(conn, pgtype):
+ r = type2cls[pgtype](empty=True)
+ cur = conn.cursor()
+ cur.execute(f"select 'empty'::{pgtype} = %s", (r,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+def test_dump_builtin_range(conn, pgtype, min, max, bounds):
+ r = type2cls[pgtype](min, max, bounds)
+ sub = type2sub[pgtype]
+ cur = conn.cursor()
+ cur.execute(
+ f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s",
+ (min, max, bounds, r),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "pgtype",
+ "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_load_builtin_range_empty(conn, pgtype):
+ r = type2cls[pgtype](empty=True)
+ cur = conn.cursor()
+ (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
+ assert type(got) is type2cls[pgtype]
+ assert got == r
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+def test_load_builtin_range(conn, pgtype, min, max, bounds):
+ r = type2cls[pgtype](min, max, bounds)
+ sub = type2sub[pgtype]
+ cur = conn.cursor()
+ cur.execute(
+ f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)
+ )
+ # normalise discrete ranges
+ if r.upper_inc and isinstance(r.upper, int):
+ bounds = "[)" if r.lower_inc else "()"
+ r = type(r)(r.lower, r.upper + 1, bounds)
+ assert cur.fetchone()[0] == r