]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Refactor range adapters to make dumper and loader functions reusable
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Oct 2021 01:48:34 +0000 (03:48 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 4 Oct 2021 12:45:56 +0000 (14:45 +0200)
psycopg/psycopg/types/composite.py
psycopg/psycopg/types/range.py

index eb966c75e7beba79fea6ebbace4a55a7cf0afc9f..b69de7d6bb7afda0a6b38fec2bef557489255333 100644 (file)
@@ -30,7 +30,7 @@ class SequenceDumper(RecursiveDumper):
         self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
     ) -> bytes:
         if not obj:
-            return b"()"
+            return start + end
 
         parts = [start]
 
index f8429072ae999b085920780ab96a269b6e0008a5..34c6c8e533bf34cbf3de159fa45d721c0bc20003 100644 (file)
@@ -5,11 +5,12 @@ Support for range types adaptation.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import re
-from typing import Any, Dict, Generic, Optional, TypeVar, Type, Union
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Type, Tuple
 from typing import cast
 from decimal import Decimal
 from datetime import date, datetime
 
+from .. import errors as e
 from .. import postgres
 from ..pq import Format
 from ..abc import AdaptContext, Buffer, Dumper, DumperKey
@@ -17,7 +18,6 @@ from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
 from .._struct import pack_len, unpack_len
 from ..postgres import INVALID_OID, TEXT_OID
 from .._typeinfo import RangeInfo as RangeInfo  # exported here
-from .composite import SequenceDumper, BaseCompositeLoader
 
 RANGE_EMPTY = 0x01  # range is empty
 RANGE_LB_INC = 0x02  # lower bound is inclusive
@@ -319,120 +319,216 @@ class BaseRangeDumper(RecursiveDumper):
         return info.oid if info else INVALID_OID
 
 
-class RangeDumper(BaseRangeDumper, SequenceDumper):
+class RangeDumper(BaseRangeDumper):
     """
     Dumper for range types.
 
     The dumper can upgrade to one specific for a different range type.
     """
 
-    def dump(self, obj: Range[Any]) -> bytes:
-        if not obj:
-            return b"empty"
+    def dump(self, obj: Range[Any]) -> Buffer:
+        item = self._get_item(obj)
+        if item is not None:
+            dump = self._tx.get_dumper(item, self._adapt_format).dump
         else:
-            return self._dump_sequence(
-                (obj.lower, obj.upper),
-                b"[" if obj.lower_inc else b"(",
-                b"]" if obj.upper_inc else b")",
-                b",",
-            )
+            dump = fail_dump
 
-    _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+        return dump_range_text(obj, dump)
 
 
-class RangeBinaryDumper(BaseRangeDumper):
+def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
+    if obj.isempty:
+        return b"empty"
 
-    format = Format.BINARY
+    parts = [b"[" if obj.lower_inc else b"("]
+
+    def dump_item(item: Any) -> Buffer:
+        ad = dump(item)
+        if not ad:
+            return b'""'
+        elif _re_needs_quotes.search(ad):
+            return b'"' + _re_esc.sub(br"\1\1", ad) + b'"'
+        else:
+            return ad
+
+    if obj.lower is not None:
+        parts.append(dump_item(obj.lower))
+
+    parts.append(b",")
+
+    if obj.upper is not None:
+        parts.append(dump_item(obj.upper))
 
-    def dump(self, obj: Range[Any]) -> Union[bytes, bytearray]:
-        if not obj:
-            return _EMPTY_HEAD
+    parts.append(b"]" if obj.upper_inc else b")")
 
-        out = bytearray([0])  # will replace the head later
+    return b"".join(parts)
 
-        head = 0
-        if obj.lower_inc:
-            head |= RANGE_LB_INC
-        if obj.upper_inc:
-            head |= RANGE_UB_INC
 
+_re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+_re_esc = re.compile(br"([\\\"])")
+
+
+class RangeBinaryDumper(BaseRangeDumper):
+
+    format = Format.BINARY
+
+    def dump(self, obj: Range[Any]) -> Buffer:
         item = self._get_item(obj)
         if item is not None:
             dump = self._tx.get_dumper(item, self._adapt_format).dump
-
-        if obj.lower is not None:
-            data = dump(obj.lower)
-            out += pack_len(len(data))
-            out += data
         else:
-            head |= RANGE_LB_INF
+            dump = fail_dump
 
-        if obj.upper is not None:
-            data = dump(obj.upper)
-            out += pack_len(len(data))
-            out += data
-        else:
-            head |= RANGE_UB_INF
+        return dump_range_binary(obj, dump)
+
+
+def dump_range_binary(
+    obj: Range[Any], dump: Callable[[Any], Buffer]
+) -> Buffer:
+    if not obj:
+        return _EMPTY_HEAD
 
-        out[0] = head
-        return out
+    out = bytearray([0])  # will replace the head later
 
+    head = 0
+    if obj.lower_inc:
+        head |= RANGE_LB_INC
+    if obj.upper_inc:
+        head |= RANGE_UB_INC
 
-class RangeLoader(BaseCompositeLoader, Generic[T]):
+    if obj.lower is not None:
+        data = dump(obj.lower)
+        out += pack_len(len(data))
+        out += data
+    else:
+        head |= RANGE_LB_INF
+
+    if obj.upper is not None:
+        data = dump(obj.upper)
+        out += pack_len(len(data))
+        out += data
+    else:
+        head |= RANGE_UB_INF
+
+    out[0] = head
+    return out
+
+
+def fail_dump(obj: Any) -> Buffer:
+    raise e.InternalError("trying to dump a range element without information")
+
+
+class BaseRangeLoader(RecursiveLoader, Generic[T]):
     """Generic loader for a range.
 
-    Subclasses shoud specify the oid of the subtype and the class to load.
+    Subclasses must specify the oid of the subtype and the class to load.
     """
 
     subtype_oid: int
 
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self._load = self._tx.get_loader(
+            self.subtype_oid, format=self.format
+        ).load
+
+
+class RangeLoader(BaseRangeLoader[T]):
     def load(self, data: Buffer) -> Range[T]:
-        if data == b"empty":
-            return Range(empty=True)
-
-        cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load
-        bounds = _int2parens[data[0]] + _int2parens[data[-1]]
-        min, max = (
-            cast(token) if token is not None else None
-            for token in self._parse_record(data[1:-1])
+        return load_range_text(data, self._load)[0]
+
+
+def load_range_text(
+    data: Buffer, load: Callable[[Buffer], Any]
+) -> Tuple[Range[Any], int]:
+    if data == b"empty":
+        return Range(empty=True), 5
+
+    m = _re_range.match(data)
+    if m is None:
+        raise e.DataError(
+            f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
         )
-        return Range(min, max, bounds)
 
+    lower = None
+    item = m.group(3)
+    if item is None:
+        item = m.group(2)
+        if item is not None:
+            lower = load(_re_undouble.sub(rb"\1", item))
+    else:
+        lower = load(item)
+
+    upper = None
+    item = m.group(5)
+    if item is None:
+        item = m.group(4)
+        if item is not None:
+            upper = load(_re_undouble.sub(r"\1", item))
+    else:
+        upper = load(item)
 
-class RangeBinaryLoader(RecursiveLoader, Generic[T]):
+    bounds = (m.group(1) + m.group(6)).decode()
 
-    format = Format.BINARY
-    subtype_oid: int
+    return Range(lower, upper, bounds), m.end()
 
-    def load(self, data: Buffer) -> Range[T]:
-        head = data[0]
-        if head & RANGE_EMPTY:
-            return Range(empty=True)
 
-        load = self._tx.get_loader(self.subtype_oid, format=Format.BINARY).load
-        lb = "[" if head & RANGE_LB_INC else "("
-        ub = "]" if head & RANGE_UB_INC else ")"
+_re_range = re.compile(
+    rb"""
+    ( \(|\[ )                   # lower bound flag
+    (?:                         # lower bound:
+      " ( (?: [^"] | "")* ) "   #   - a quoted string
+      | ( [^",]+ )              #   - or an unquoted string
+    )?                          #   - or empty (not catched)
+    ,
+    (?:                         # upper bound:
+      " ( (?: [^"] | "")* ) "   #   - a quoted string
+      | ( [^"\)\]]+ )           #   - or an unquoted string
+    )?                          #   - or empty (not catched)
+    ( \)|\] )                   # upper bound flag
+    """,
+    re.VERBOSE,
+)
 
-        pos = 1  # after the head
-        if head & RANGE_LB_INF:
-            min = None
-        else:
-            length = unpack_len(data, pos)[0]
-            pos += 4
-            min = load(data[pos : pos + length])
-            pos += length
+_re_undouble = re.compile(rb'(["\\])\1')
 
-        if head & RANGE_UB_INF:
-            max = None
-        else:
-            length = unpack_len(data, pos)[0]
-            pos += 4
-            max = load(data[pos : pos + length])
 
-        return Range(min, max, lb + ub)
+class RangeBinaryLoader(BaseRangeLoader[T]):
 
+    format = Format.BINARY
 
-_int2parens = {ord(c): c for c in "[]()"}
+    def load(self, data: Buffer) -> Range[T]:
+        return load_range_binary(data, self._load)
+
+
+def load_range_binary(
+    data: Buffer, load: Callable[[Buffer], Any]
+) -> Range[Any]:
+    head = data[0]
+    if head & RANGE_EMPTY:
+        return Range(empty=True)
+
+    lb = "[" if head & RANGE_LB_INC else "("
+    ub = "]" if head & RANGE_UB_INC else ")"
+
+    pos = 1  # after the head
+    if head & RANGE_LB_INF:
+        min = None
+    else:
+        length = unpack_len(data, pos)[0]
+        pos += 4
+        min = load(data[pos : pos + length])
+        pos += length
+
+    if head & RANGE_UB_INF:
+        max = None
+    else:
+        length = unpack_len(data, pos)[0]
+        pos += 4
+        max = load(data[pos : pos + length])
+        pos += length
+
+    return Range(min, max, lb + ub)
 
 
 def register_range(