# Copyright (C) 2020-2021 The Psycopg Team
import re
-from typing import Any, Dict, Generic, Optional, TypeVar, Type, Union
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Type, Tuple
from typing import cast
from decimal import Decimal
from datetime import date, datetime
+from .. import errors as e
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext, Buffer, Dumper, DumperKey
from .._struct import pack_len, unpack_len
from ..postgres import INVALID_OID, TEXT_OID
from .._typeinfo import RangeInfo as RangeInfo # exported here
-from .composite import SequenceDumper, BaseCompositeLoader
RANGE_EMPTY = 0x01 # range is empty
RANGE_LB_INC = 0x02 # lower bound is inclusive
return info.oid if info else INVALID_OID
-class RangeDumper(BaseRangeDumper, SequenceDumper):
class RangeDumper(BaseRangeDumper):
    """
    Dumper for range types, text format.

    The dumper can upgrade to one specific for a different range type.
    """

    def dump(self, obj: Range[Any]) -> Buffer:
        """Return the text representation of the range *obj*.

        The dumper for the bounds is looked up from the item returned by
        `_get_item()`. When no item is available, `fail_dump` is used, so
        an attempt to dump a bound raises instead of guessing a type.
        """
        sample = self._get_item(obj)
        if sample is None:
            dump_element = fail_dump
        else:
            element_dumper = self._tx.get_dumper(sample, self._adapt_format)
            dump_element = element_dumper.dump

        return dump_range_text(obj, dump_element)
-class RangeBinaryDumper(BaseRangeDumper):
def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
    """Return the text representation of the range *obj*.

    :param obj: the range to dump.
    :param dump: callable converting a single bound to its dumped form.
    :return: ``b"empty"`` for an empty range, otherwise the opening bracket,
        the lower bound, a comma, the upper bound and the closing bracket.
        A bound that is `None` is left out.
    """
    if obj.isempty:
        return b"empty"

    def quoted(bound: Any) -> Buffer:
        # An empty dump must be written as "" to tell it apart from a
        # missing bound; a dump containing special characters is wrapped
        # in double quotes with every match of _re_esc doubled.
        raw = dump(bound)
        if not raw:
            return b'""'
        if _re_needs_quotes.search(raw):
            return b'"' + _re_esc.sub(rb"\1\1", raw) + b'"'
        return raw

    opening = b"[" if obj.lower_inc else b"("
    lower = b"" if obj.lower is None else quoted(obj.lower)
    upper = b"" if obj.upper is None else quoted(obj.upper)
    closing = b"]" if obj.upper_inc else b")"

    return b"".join((opening, lower, b",", upper, closing))
- head = 0
- if obj.lower_inc:
- head |= RANGE_LB_INC
- if obj.upper_inc:
- head |= RANGE_UB_INC
+_re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+_re_esc = re.compile(br"([\\\"])")
+
+
class RangeBinaryDumper(BaseRangeDumper):
    """Dumper for range types, binary format."""

    format = Format.BINARY

    def dump(self, obj: Range[Any]) -> Buffer:
        """Return the binary representation of the range *obj*.

        The dumper for the bounds is looked up from the item returned by
        `_get_item()`. When no item is available, `fail_dump` is used, so
        an attempt to dump a bound raises instead of guessing a type.
        """
        sample = self._get_item(obj)
        if sample is None:
            dump_element = fail_dump
        else:
            element_dumper = self._tx.get_dumper(sample, self._adapt_format)
            dump_element = element_dumper.dump

        return dump_range_binary(obj, dump_element)
+
+
def dump_range_binary(
    obj: Range[Any], dump: Callable[[Any], Buffer]
) -> Buffer:
    """Return the binary representation of the range *obj*.

    :param obj: the range to dump.
    :param dump: callable converting a single bound to its dumped form.
    :return: `_EMPTY_HEAD` if *obj* is falsy; otherwise a `bytearray` made of
        one byte of flags followed, for each bound that is not `None`, by
        the `pack_len()` of the dumped bound and the dumped bound itself.
    """
    if not obj:
        return _EMPTY_HEAD

    flags = 0
    if obj.lower_inc:
        flags |= RANGE_LB_INC
    if obj.upper_inc:
        flags |= RANGE_UB_INC

    # Lower bound first, then upper bound; a missing bound is recorded
    # only as an "infinite" flag and contributes no payload.
    payload = bytearray()
    for bound, inf_flag in (
        (obj.lower, RANGE_LB_INF),
        (obj.upper, RANGE_UB_INF),
    ):
        if bound is None:
            flags |= inf_flag
            continue
        dumped = dump(bound)
        payload += pack_len(len(dumped))
        payload += dumped

    return bytearray([flags]) + payload
+
+
def fail_dump(obj: Any) -> Buffer:
    """Stand-in element dumper used when no element dumper is available.

    :raises InternalError: always; it never returns.
    """
    message = "trying to dump a range element without information"
    raise e.InternalError(message)
+
+
class BaseRangeLoader(RecursiveLoader, Generic[T]):
    """Generic loader for a range.

    Subclasses must specify the oid of the subtype and the class to load.
    """

    subtype_oid: int

    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
        super().__init__(oid, context)
        # Resolve the subtype loader once, so that load() can reuse its
        # bound `load` method for every bound it parses.
        subtype_loader = self._tx.get_loader(
            self.subtype_oid, format=self.format
        )
        self._load = subtype_loader.load
+
+
class RangeLoader(BaseRangeLoader[T]):
    """Loader for range types, text format."""

    def load(self, data: Buffer) -> Range[T]:
        """Parse *data* and return the range it represents."""
        # load_range_text() also reports how much input it consumed;
        # only the parsed range is of interest here.
        parsed, _ = load_range_text(data, self._load)
        return parsed
+
+
def load_range_text(
    data: Buffer, load: Callable[[Buffer], Any]
) -> Tuple[Range[Any], int]:
    """Parse the text representation of a range.

    :param data: the buffer to parse; the range must start at position 0.
    :param load: callable converting the text of a single bound to a
        Python object.
    :return: a tuple ``(range, length)`` where *length* is the number of
        bytes of *data* consumed by the range (``5`` for ``b"empty"``,
        otherwise the end of the `_re_range` match).
    :raises DataError: if *data* doesn't match `_re_range`.
    """
    if data == b"empty":
        return Range(empty=True), 5

    m = _re_range.match(data)
    if m is None:
        raise e.DataError(
            f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
        )

    def load_bound(quoted: Optional[bytes], unquoted: Optional[bytes]) -> Any:
        # An unquoted bound is loaded verbatim. A quoted bound has its
        # doubled quotes/backslashes collapsed first. If neither group
        # matched the bound is missing and None is returned.
        if unquoted is not None:
            return load(unquoted)
        if quoted is not None:
            # The replacement must be bytes, like the pattern: a str
            # replacement makes re.sub() raise TypeError.
            return load(_re_undouble.sub(rb"\1", quoted))
        return None

    lower = load_bound(m.group(2), m.group(3))
    upper = load_bound(m.group(4), m.group(5))
    bounds = (m.group(1) + m.group(6)).decode()

    return Range(lower, upper, bounds), m.end()
- def load(self, data: Buffer) -> Range[T]:
- head = data[0]
- if head & RANGE_EMPTY:
- return Range(empty=True)
- load = self._tx.get_loader(self.subtype_oid, format=Format.BINARY).load
- lb = "[" if head & RANGE_LB_INC else "("
- ub = "]" if head & RANGE_UB_INC else ")"
+_re_range = re.compile(
+ rb"""
+ ( \(|\[ ) # lower bound flag
+ (?: # lower bound:
+ " ( (?: [^"] | "")* ) " # - a quoted string
+ | ( [^",]+ ) # - or an unquoted string
+ )? # - or empty (not catched)
+ ,
+ (?: # upper bound:
+ " ( (?: [^"] | "")* ) " # - a quoted string
+ | ( [^"\)\]]+ ) # - or an unquoted string
+ )? # - or empty (not catched)
+ ( \)|\] ) # upper bound flag
+ """,
+ re.VERBOSE,
+)
- pos = 1 # after the head
- if head & RANGE_LB_INF:
- min = None
- else:
- length = unpack_len(data, pos)[0]
- pos += 4
- min = load(data[pos : pos + length])
- pos += length
+_re_undouble = re.compile(rb'(["\\])\1')
- if head & RANGE_UB_INF:
- max = None
- else:
- length = unpack_len(data, pos)[0]
- pos += 4
- max = load(data[pos : pos + length])
- return Range(min, max, lb + ub)
class RangeBinaryLoader(BaseRangeLoader[T]):
    """Loader for range types, binary format."""

    format = Format.BINARY

    def load(self, data: Buffer) -> Range[T]:
        """Parse *data* and return the range it represents."""
        parsed = load_range_binary(data, self._load)
        return parsed
+
+
def load_range_binary(
    data: Buffer, load: Callable[[Buffer], Any]
) -> Range[Any]:
    """Parse the binary representation of a range.

    :param data: the buffer to parse. Its first byte holds the range flags;
        each bound not flagged as infinite follows as a length (read with
        `unpack_len()`, 4 bytes wide) and that many bytes of payload.
    :param load: callable converting the payload of a single bound to a
        Python object.
    :return: the parsed range; an empty range if `RANGE_EMPTY` is set.
    """
    head = data[0]
    if head & RANGE_EMPTY:
        return Range(empty=True)

    offset = 1  # skip the head byte
    values = []
    # Lower bound first, then upper bound.
    for inf_flag in (RANGE_LB_INF, RANGE_UB_INF):
        if head & inf_flag:
            values.append(None)
            continue
        size = unpack_len(data, offset)[0]
        start = offset + 4
        offset = start + size
        values.append(load(data[start:offset]))

    opening = "[" if head & RANGE_LB_INC else "("
    closing = "]" if head & RANGE_UB_INC else ")"

    return Range(values[0], values[1], opening + closing)
def register_range(