From: Daniele Varrazzo Date: Wed, 9 Jun 2021 00:31:48 +0000 (+0100) Subject: Add range binary dumpers X-Git-Tag: 3.0.dev0~18^2~5 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=0cff350bf33a87db534500135831e986d4c74cbe;p=thirdparty%2Fpsycopg.git Add range binary dumpers --- diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py index 821fe37a6..7e1a0659e 100644 --- a/psycopg3/psycopg3/types/__init__.py +++ b/psycopg3/psycopg3/types/__init__.py @@ -137,6 +137,7 @@ from .network import ( ) from .range import ( RangeDumper as RangeDumper, + BinaryRangeDumper as BinaryRangeDumper, RangeLoader as RangeLoader, Int4RangeLoader as Int4RangeLoader, Int8RangeLoader as Int8RangeLoader, @@ -285,6 +286,7 @@ def register_default_globals(ctx: AdaptContext) -> None: CidrBinaryLoader.register("cidr", ctx) RangeDumper.register(Range, ctx) + BinaryRangeDumper.register(Range, ctx) Int4RangeLoader.register("int4range", ctx) Int8RangeLoader.register("int8range", ctx) NumericRangeLoader.register("numrange", ctx) diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py index bbc5c1e5d..f2819afa7 100644 --- a/psycopg3/psycopg3/types/range.py +++ b/psycopg3/psycopg3/types/range.py @@ -5,15 +5,17 @@ Support for range types adaptation. # Copyright (C) 2020-2021 The Psycopg Team import re -from typing import cast, Any, Dict, Generic, Optional, Tuple, TypeVar, Type +from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Type, Union +from typing import cast from decimal import Decimal from datetime import date, datetime from ..pq import Format from ..oids import postgres_types as builtins, INVALID_OID -from ..adapt import Dumper, RecursiveLoader, Format as Pg3Format +from ..adapt import Dumper, RecursiveDumper, RecursiveLoader +from ..adapt import Format as Pg3Format from ..proto import AdaptContext, Buffer -from .._struct import unpack_len +from .._struct import pack_len, unpack_len from .._typeinfo import RangeInfo from .composite import SequenceDumper, BaseCompositeLoader @@ -24,6 +26,8 @@ RANGE_UB_INC = 0x04 # upper bound is inclusive RANGE_LB_INF = 0x08 # lower bound is -infinity RANGE_UB_INF = 0x10 # upper bound is +infinity +_EMPTY_HEAD = bytes([RANGE_EMPTY]) + T = TypeVar("T") @@ -221,60 +225,45 @@ class Range(Generic[T]): setattr(self, slot, value) -class RangeDumper(SequenceDumper): - """ - Dumper for range types. - - The dumper can upgrade to one specific for a different range type. - """ - - format = Format.TEXT - +class BaseRangeDumper(RecursiveDumper): def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) self.sub_dumper: Optional[Dumper] = None self._types = context.adapters.types if context else builtins - - def dump(self, obj: Range[Any]) -> bytes: - if not obj: - return b"empty" - else: - return self._dump_sequence( - (obj.lower, obj.upper), - b"[" if obj.lower_inc else b"(", - b"]" if obj.upper_inc else b")", - b",", - ) - - _re_needs_quotes = re.compile(br'[",\\\s()\[\]]') + self._adapt_format = Pg3Format.from_pq(self.format) def get_key(self, obj: Range[Any], format: Pg3Format) -> Tuple[type, ...]: item = self._get_item(obj) if item is not None: - # TODO: binary range support - sd = self._tx.get_dumper(item, Pg3Format.TEXT) + sd = self._tx.get_dumper(item, self._adapt_format) return (self.cls, sd.cls) else: return (self.cls,) - def upgrade(self, obj: Range[Any], format: Pg3Format) -> "RangeDumper": + def upgrade(self, obj: Range[Any], format: Pg3Format) -> "BaseRangeDumper": item = self._get_item(obj) if item is None: return RangeDumper(self.cls) - # TODO: binary range support - sd = self._tx.get_dumper(item, Pg3Format.TEXT) - dumper = type(self)(self.cls, self._tx) - dumper.sub_dumper = sd - if isinstance(item, int): + dumper: BaseRangeDumper + if type(item) is int: # postgres won't cast int4range -> int8range so we must use # text format and unknown oid here + sd = self._tx.get_dumper(item, Pg3Format.TEXT) + dumper = RangeDumper(self.cls, self._tx) + dumper.sub_dumper = sd dumper.oid = INVALID_OID - elif isinstance(item, str) and sd.oid == INVALID_OID: + return dumper + + sd = self._tx.get_dumper(item, format) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + if sd.oid == INVALID_OID and isinstance(item, str): # Work around the normal mapping where text is dumped as unknown dumper.oid = self._get_range_oid(self._types["text"].oid) else: dumper.oid = self._get_range_oid(sd.oid) + return dumper def _get_item(self, obj: Range[Any]) -> Any: @@ -294,6 +283,67 @@ class RangeDumper(SequenceDumper): return info.oid if info else INVALID_OID +class RangeDumper(BaseRangeDumper, SequenceDumper): + """ + Dumper for range types. + + The dumper can upgrade to one specific for a different range type. + """ + + format = Format.TEXT + + def dump(self, obj: Range[Any]) -> bytes: + if not obj: + return b"empty" + else: + return self._dump_sequence( + (obj.lower, obj.upper), + b"[" if obj.lower_inc else b"(", + b"]" if obj.upper_inc else b")", + b",", + ) + + _re_needs_quotes = re.compile(br'[",\\\s()\[\]]') + + +class BinaryRangeDumper(BaseRangeDumper): + + format = Format.BINARY + + def dump(self, obj: Range[Any]) -> Union[bytes, bytearray]: + if not obj: + return _EMPTY_HEAD + + out = bytearray([0]) # will replace the head later + + head = 0 + if obj.lower_inc: + head |= RANGE_LB_INC + if obj.upper_inc: + head |= RANGE_UB_INC + + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + + if obj.lower is not None: + data = dump(obj.lower) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_LB_INF + + if obj.upper is not None: + data = dump(obj.upper) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_UB_INF + + out[0] = head + return out + + class RangeLoader(BaseCompositeLoader, Generic[T]): """Generic loader for a range. @@ -344,7 +394,6 @@ class RangeBinaryLoader(RecursiveLoader, Generic[T]): length = unpack_len(data, pos)[0] pos += 4 max = load(data[pos : pos + length]) - pos += length return Range(min, max, lb + ub) diff --git a/tests/types/test_range.py b/tests/types/test_range.py index 285d02a62..8f624b656 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -6,6 +6,7 @@ import pytest from psycopg3 import pq from psycopg3.sql import Identifier +from psycopg3.adapt import Format from psycopg3.types import Range, RangeInfo @@ -50,9 +51,10 @@ samples = [ "pgtype", "int4range int8range numrange daterange tsrange tstzrange".split(), ) -def test_dump_builtin_empty(conn, pgtype): +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_dump_builtin_empty(conn, pgtype, fmt_in): r = Range(empty=True) - cur = conn.execute(f"select 'empty'::{pgtype} = %s", (r,)) + cur = conn.execute(f"select 'empty'::{pgtype} = %{fmt_in}", (r,)) assert cur.fetchone()[0] is True @@ -60,22 +62,24 @@ def test_dump_builtin_empty(conn, pgtype): "pgtype", "int4range int8range numrange daterange tsrange tstzrange".split(), ) -def test_dump_builtin_array(conn, pgtype): +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_dump_builtin_array(conn, pgtype, fmt_in): r1 = Range(empty=True) r2 = Range(bounds="()") cur = conn.execute( - f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %s", + f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in}", ([r1, r2],), ) assert cur.fetchone()[0] is True @pytest.mark.parametrize("pgtype, min, max, bounds", samples) -def test_dump_builtin_range(conn, pgtype, min, max, bounds): +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in): r = Range(min, max, bounds) sub = type2sub[pgtype] cur = conn.execute( - f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s::{pgtype}", + f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %{fmt_in}", (min, max, bounds, r), ) assert cur.fetchone()[0] is True