# Copyright (C) 2020-2021 The Psycopg Team
import re
-from typing import cast, Any, Dict, Generic, Optional, Tuple, TypeVar, Type
+from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Type, Union
+from typing import cast
from decimal import Decimal
from datetime import date, datetime
from ..pq import Format
from ..oids import postgres_types as builtins, INVALID_OID
-from ..adapt import Dumper, RecursiveLoader, Format as Pg3Format
+from ..adapt import Dumper, RecursiveDumper, RecursiveLoader
+from ..adapt import Format as Pg3Format
from ..proto import AdaptContext, Buffer
-from .._struct import unpack_len
+from .._struct import pack_len, unpack_len
from .._typeinfo import RangeInfo
from .composite import SequenceDumper, BaseCompositeLoader
RANGE_LB_INF = 0x08 # lower bound is -infinity
RANGE_UB_INF = 0x10 # upper bound is +infinity
+_EMPTY_HEAD = bytes([RANGE_EMPTY])
+
T = TypeVar("T")
setattr(self, slot, value)
-class RangeDumper(SequenceDumper):
- """
- Dumper for range types.
-
- The dumper can upgrade to one specific for a different range type.
- """
-
- format = Format.TEXT
-
+class BaseRangeDumper(RecursiveDumper):
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
self.sub_dumper: Optional[Dumper] = None
self._types = context.adapters.types if context else builtins
-
- def dump(self, obj: Range[Any]) -> bytes:
- if not obj:
- return b"empty"
- else:
- return self._dump_sequence(
- (obj.lower, obj.upper),
- b"[" if obj.lower_inc else b"(",
- b"]" if obj.upper_inc else b")",
- b",",
- )
-
- _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+ self._adapt_format = Pg3Format.from_pq(self.format)
def get_key(self, obj: Range[Any], format: Pg3Format) -> Tuple[type, ...]:
item = self._get_item(obj)
if item is not None:
- # TODO: binary range support
- sd = self._tx.get_dumper(item, Pg3Format.TEXT)
+ sd = self._tx.get_dumper(item, self._adapt_format)
return (self.cls, sd.cls)
else:
return (self.cls,)
- def upgrade(self, obj: Range[Any], format: Pg3Format) -> "RangeDumper":
+ def upgrade(self, obj: Range[Any], format: Pg3Format) -> "BaseRangeDumper":
item = self._get_item(obj)
if item is None:
return RangeDumper(self.cls)
- # TODO: binary range support
- sd = self._tx.get_dumper(item, Pg3Format.TEXT)
- dumper = type(self)(self.cls, self._tx)
- dumper.sub_dumper = sd
- if isinstance(item, int):
+ dumper: BaseRangeDumper
+ if type(item) is int:
# postgres won't cast int4range -> int8range so we must use
# text format and unknown oid here
+ sd = self._tx.get_dumper(item, Pg3Format.TEXT)
+ dumper = RangeDumper(self.cls, self._tx)
+ dumper.sub_dumper = sd
dumper.oid = INVALID_OID
- elif isinstance(item, str) and sd.oid == INVALID_OID:
+ return dumper
+
+ sd = self._tx.get_dumper(item, format)
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ if sd.oid == INVALID_OID and isinstance(item, str):
# Work around the normal mapping where text is dumped as unknown
dumper.oid = self._get_range_oid(self._types["text"].oid)
else:
dumper.oid = self._get_range_oid(sd.oid)
+
return dumper
def _get_item(self, obj: Range[Any]) -> Any:
return info.oid if info else INVALID_OID
+class RangeDumper(BaseRangeDumper, SequenceDumper):
+ """
+ Dumper for range types.
+
+ The dumper can upgrade to one specific for a different range type.
+ """
+
+ format = Format.TEXT
+
+ def dump(self, obj: Range[Any]) -> bytes:
+ if not obj:
+ return b"empty"
+ else:
+ return self._dump_sequence(
+ (obj.lower, obj.upper),
+ b"[" if obj.lower_inc else b"(",
+ b"]" if obj.upper_inc else b")",
+ b",",
+ )
+
+ _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+
+
+class BinaryRangeDumper(BaseRangeDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Range[Any]) -> Union[bytes, bytearray]:
+ if not obj:
+ return _EMPTY_HEAD
+
+ out = bytearray([0]) # will replace the head later
+
+ head = 0
+ if obj.lower_inc:
+ head |= RANGE_LB_INC
+ if obj.upper_inc:
+ head |= RANGE_UB_INC
+
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+
+ if obj.lower is not None:
+ data = dump(obj.lower)
+ out += pack_len(len(data))
+ out += data
+ else:
+ head |= RANGE_LB_INF
+
+ if obj.upper is not None:
+ data = dump(obj.upper)
+ out += pack_len(len(data))
+ out += data
+ else:
+ head |= RANGE_UB_INF
+
+ out[0] = head
+ return out
+
+
class RangeLoader(BaseCompositeLoader, Generic[T]):
"""Generic loader for a range.
length = unpack_len(data, pos)[0]
pos += 4
max = load(data[pos : pos + length])
- pos += length
return Range(min, max, lb + ub)
from psycopg3 import pq
from psycopg3.sql import Identifier
+from psycopg3.adapt import Format
from psycopg3.types import Range, RangeInfo
"pgtype",
"int4range int8range numrange daterange tsrange tstzrange".split(),
)
-def test_dump_builtin_empty(conn, pgtype):
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_empty(conn, pgtype, fmt_in):
r = Range(empty=True)
- cur = conn.execute(f"select 'empty'::{pgtype} = %s", (r,))
+ cur = conn.execute(f"select 'empty'::{pgtype} = %{fmt_in}", (r,))
assert cur.fetchone()[0] is True
"pgtype",
"int4range int8range numrange daterange tsrange tstzrange".split(),
)
-def test_dump_builtin_array(conn, pgtype):
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_array(conn, pgtype, fmt_in):
r1 = Range(empty=True)
r2 = Range(bounds="()")
cur = conn.execute(
- f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %s",
+ f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in}",
([r1, r2],),
)
assert cur.fetchone()[0] is True
@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
-def test_dump_builtin_range(conn, pgtype, min, max, bounds):
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in):
r = Range(min, max, bounds)
sub = type2sub[pgtype]
cur = conn.execute(
- f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s::{pgtype}",
+ f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %{fmt_in}",
(min, max, bounds, r),
)
assert cur.fetchone()[0] is True