From 42d1c23a3f64a28cc9486eb440e956fb6b35ee4a Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sat, 2 Oct 2021 03:48:34 +0200 Subject: [PATCH] Refactor range adapters to make dumper and loader functions reusable --- psycopg/psycopg/types/composite.py | 2 +- psycopg/psycopg/types/range.py | 246 ++++++++++++++++++++--------- 2 files changed, 172 insertions(+), 76 deletions(-) diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py index eb966c75e..b69de7d6b 100644 --- a/psycopg/psycopg/types/composite.py +++ b/psycopg/psycopg/types/composite.py @@ -30,7 +30,7 @@ class SequenceDumper(RecursiveDumper): self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes ) -> bytes: if not obj: - return b"()" + return start + end parts = [start] diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py index f8429072a..34c6c8e53 100644 --- a/psycopg/psycopg/types/range.py +++ b/psycopg/psycopg/types/range.py @@ -5,11 +5,12 @@ Support for range types adaptation. # Copyright (C) 2020-2021 The Psycopg Team import re -from typing import Any, Dict, Generic, Optional, TypeVar, Type, Union +from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Type, Tuple from typing import cast from decimal import Decimal from datetime import date, datetime +from .. import errors as e from .. import postgres from ..pq import Format from ..abc import AdaptContext, Buffer, Dumper, DumperKey @@ -17,7 +18,6 @@ from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat from .._struct import pack_len, unpack_len from ..postgres import INVALID_OID, TEXT_OID from .._typeinfo import RangeInfo as RangeInfo # exported here -from .composite import SequenceDumper, BaseCompositeLoader RANGE_EMPTY = 0x01 # range is empty RANGE_LB_INC = 0x02 # lower bound is inclusive @@ -319,120 +319,216 @@ class BaseRangeDumper(RecursiveDumper): return info.oid if info else INVALID_OID -class RangeDumper(BaseRangeDumper, SequenceDumper): +class RangeDumper(BaseRangeDumper): """ Dumper for range types. The dumper can upgrade to one specific for a different range type. """ - def dump(self, obj: Range[Any]) -> bytes: - if not obj: - return b"empty" + def dump(self, obj: Range[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump else: - return self._dump_sequence( - (obj.lower, obj.upper), - b"[" if obj.lower_inc else b"(", - b"]" if obj.upper_inc else b")", - b",", - ) + dump = fail_dump - _re_needs_quotes = re.compile(br'[",\\\s()\[\]]') + return dump_range_text(obj, dump) -class RangeBinaryDumper(BaseRangeDumper): +def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer: + if obj.isempty: + return b"empty" - format = Format.BINARY + parts = [b"[" if obj.lower_inc else b"("] + + def dump_item(item: Any) -> Buffer: + ad = dump(item) + if not ad: + return b'""' + elif _re_needs_quotes.search(ad): + return b'"' + _re_esc.sub(br"\1\1", ad) + b'"' + else: + return ad + + if obj.lower is not None: + parts.append(dump_item(obj.lower)) + + parts.append(b",") + + if obj.upper is not None: + parts.append(dump_item(obj.upper)) - def dump(self, obj: Range[Any]) -> Union[bytes, bytearray]: - if not obj: - return _EMPTY_HEAD + parts.append(b"]" if obj.upper_inc else b")") - out = bytearray([0]) # will replace the head later + return b"".join(parts) - head = 0 - if obj.lower_inc: - head |= RANGE_LB_INC - if obj.upper_inc: - head |= RANGE_UB_INC +_re_needs_quotes = re.compile(br'[",\\\s()\[\]]') +_re_esc = re.compile(br"([\\\"])") + + +class RangeBinaryDumper(BaseRangeDumper): + + format = Format.BINARY + + def dump(self, obj: Range[Any]) -> Buffer: item = self._get_item(obj) if item is not None: dump = self._tx.get_dumper(item, self._adapt_format).dump - - if obj.lower is not None: - data = dump(obj.lower) - out += pack_len(len(data)) - out += data else: - head |= RANGE_LB_INF + dump = fail_dump - if obj.upper is not None: - data = dump(obj.upper) - out += pack_len(len(data)) - out += data - else: - head |= RANGE_UB_INF + return dump_range_binary(obj, dump) + + +def dump_range_binary( + obj: Range[Any], dump: Callable[[Any], Buffer] +) -> Buffer: + if not obj: + return _EMPTY_HEAD - out[0] = head - return out + out = bytearray([0]) # will replace the head later + head = 0 + if obj.lower_inc: + head |= RANGE_LB_INC + if obj.upper_inc: + head |= RANGE_UB_INC -class RangeLoader(BaseCompositeLoader, Generic[T]): + if obj.lower is not None: + data = dump(obj.lower) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_LB_INF + + if obj.upper is not None: + data = dump(obj.upper) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_UB_INF + + out[0] = head + return out + + +def fail_dump(obj: Any) -> Buffer: + raise e.InternalError("trying to dump a range element without information") + + +class BaseRangeLoader(RecursiveLoader, Generic[T]): """Generic loader for a range. - Subclasses shoud specify the oid of the subtype and the class to load. + Subclasses must specify the oid of the subtype and the class to load. """ subtype_oid: int + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._load = self._tx.get_loader( + self.subtype_oid, format=self.format + ).load + + +class RangeLoader(BaseRangeLoader[T]): def load(self, data: Buffer) -> Range[T]: - if data == b"empty": - return Range(empty=True) - - cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load - bounds = _int2parens[data[0]] + _int2parens[data[-1]] - min, max = ( - cast(token) if token is not None else None - for token in self._parse_record(data[1:-1]) + return load_range_text(data, self._load)[0] + + +def load_range_text( + data: Buffer, load: Callable[[Buffer], Any] +) -> Tuple[Range[Any], int]: + if data == b"empty": + return Range(empty=True), 5 + + m = _re_range.match(data) + if m is None: + raise e.DataError( + f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'" ) - return Range(min, max, bounds) + lower = None + item = m.group(3) + if item is None: + item = m.group(2) + if item is not None: + lower = load(_re_undouble.sub(rb"\1", item)) + else: + lower = load(item) + + upper = None + item = m.group(5) + if item is None: + item = m.group(4) + if item is not None: + upper = load(_re_undouble.sub(r"\1", item)) + else: + upper = load(item) -class RangeBinaryLoader(RecursiveLoader, Generic[T]): + bounds = (m.group(1) + m.group(6)).decode() - format = Format.BINARY - subtype_oid: int + return Range(lower, upper, bounds), m.end() - def load(self, data: Buffer) -> Range[T]: - head = data[0] - if head & RANGE_EMPTY: - return Range(empty=True) - load = self._tx.get_loader(self.subtype_oid, format=Format.BINARY).load - lb = "[" if head & RANGE_LB_INC else "(" - ub = "]" if head & RANGE_UB_INC else ")" +_re_range = re.compile( + rb""" + ( \(|\[ ) # lower bound flag + (?: # lower bound: + " ( (?: [^"] | "")* ) " # - a quoted string + | ( [^",]+ ) # - or an unquoted string + )? # - or empty (not catched) + , + (?: # upper bound: + " ( (?: [^"] | "")* ) " # - a quoted string + | ( [^"\)\]]+ ) # - or an unquoted string + )? # - or empty (not catched) + ( \)|\] ) # upper bound flag + """, + re.VERBOSE, +) - pos = 1 # after the head - if head & RANGE_LB_INF: - min = None - else: - length = unpack_len(data, pos)[0] - pos += 4 - min = load(data[pos : pos + length]) - pos += length +_re_undouble = re.compile(rb'(["\\])\1') - if head & RANGE_UB_INF: - max = None - else: - length = unpack_len(data, pos)[0] - pos += 4 - max = load(data[pos : pos + length]) - return Range(min, max, lb + ub) +class RangeBinaryLoader(BaseRangeLoader[T]): + format = Format.BINARY -_int2parens = {ord(c): c for c in "[]()"} + def load(self, data: Buffer) -> Range[T]: + return load_range_binary(data, self._load) + + +def load_range_binary( + data: Buffer, load: Callable[[Buffer], Any] +) -> Range[Any]: + head = data[0] + if head & RANGE_EMPTY: + return Range(empty=True) + + lb = "[" if head & RANGE_LB_INC else "(" + ub = "]" if head & RANGE_UB_INC else ")" + + pos = 1 # after the head + if head & RANGE_LB_INF: + min = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + min = load(data[pos : pos + length]) + pos += length + + if head & RANGE_UB_INF: + max = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + max = load(data[pos : pos + length]) + pos += length + + return Range(min, max, lb + ub) def register_range( -- 2.47.2