]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add range binary dumpers
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 9 Jun 2021 00:31:48 +0000 (01:31 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Jun 2021 15:16:26 +0000 (16:16 +0100)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/range.py
tests/types/test_range.py

index 821fe37a631d34e3514a736e64772c2b1a8db166..7e1a0659e900854b9a32189e5d5a89906887ceb6 100644 (file)
@@ -137,6 +137,7 @@ from .network import (
 )
 from .range import (
     RangeDumper as RangeDumper,
+    BinaryRangeDumper as BinaryRangeDumper,
     RangeLoader as RangeLoader,
     Int4RangeLoader as Int4RangeLoader,
     Int8RangeLoader as Int8RangeLoader,
@@ -285,6 +286,7 @@ def register_default_globals(ctx: AdaptContext) -> None:
     CidrBinaryLoader.register("cidr", ctx)
 
     RangeDumper.register(Range, ctx)
+    BinaryRangeDumper.register(Range, ctx)
     Int4RangeLoader.register("int4range", ctx)
     Int8RangeLoader.register("int8range", ctx)
     NumericRangeLoader.register("numrange", ctx)
index bbc5c1e5da53f4496d31f6776dbd132aa75c1df9..f2819afa7170e319290163e3094ca39a3e6fde9c 100644 (file)
@@ -5,15 +5,17 @@ Support for range types adaptation.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import re
-from typing import cast, Any, Dict, Generic, Optional, Tuple, TypeVar, Type
+from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Type, Union
+from typing import cast
 from decimal import Decimal
 from datetime import date, datetime
 
 from ..pq import Format
 from ..oids import postgres_types as builtins, INVALID_OID
-from ..adapt import Dumper, RecursiveLoader, Format as Pg3Format
+from ..adapt import Dumper, RecursiveDumper, RecursiveLoader
+from ..adapt import Format as Pg3Format
 from ..proto import AdaptContext, Buffer
-from .._struct import unpack_len
+from .._struct import pack_len, unpack_len
 from .._typeinfo import RangeInfo
 
 from .composite import SequenceDumper, BaseCompositeLoader
@@ -24,6 +26,8 @@ RANGE_UB_INC = 0x04  # upper bound is inclusive
 RANGE_LB_INF = 0x08  # lower bound is -infinity
 RANGE_UB_INF = 0x10  # upper bound is +infinity
 
+_EMPTY_HEAD = bytes([RANGE_EMPTY])
+
 T = TypeVar("T")
 
 
@@ -221,60 +225,45 @@ class Range(Generic[T]):
             setattr(self, slot, value)
 
 
-class RangeDumper(SequenceDumper):
-    """
-    Dumper for range types.
-
-    The dumper can upgrade to one specific for a different range type.
-    """
-
-    format = Format.TEXT
-
+class BaseRangeDumper(RecursiveDumper):
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
         self.sub_dumper: Optional[Dumper] = None
         self._types = context.adapters.types if context else builtins
-
-    def dump(self, obj: Range[Any]) -> bytes:
-        if not obj:
-            return b"empty"
-        else:
-            return self._dump_sequence(
-                (obj.lower, obj.upper),
-                b"[" if obj.lower_inc else b"(",
-                b"]" if obj.upper_inc else b")",
-                b",",
-            )
-
-    _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+        self._adapt_format = Pg3Format.from_pq(self.format)
 
     def get_key(self, obj: Range[Any], format: Pg3Format) -> Tuple[type, ...]:
         item = self._get_item(obj)
         if item is not None:
-            # TODO: binary range support
-            sd = self._tx.get_dumper(item, Pg3Format.TEXT)
+            sd = self._tx.get_dumper(item, self._adapt_format)
             return (self.cls, sd.cls)
         else:
             return (self.cls,)
 
-    def upgrade(self, obj: Range[Any], format: Pg3Format) -> "RangeDumper":
+    def upgrade(self, obj: Range[Any], format: Pg3Format) -> "BaseRangeDumper":
         item = self._get_item(obj)
         if item is None:
             return RangeDumper(self.cls)
 
-        # TODO: binary range support
-        sd = self._tx.get_dumper(item, Pg3Format.TEXT)
-        dumper = type(self)(self.cls, self._tx)
-        dumper.sub_dumper = sd
-        if isinstance(item, int):
+        dumper: BaseRangeDumper
+        if type(item) is int:
             # postgres won't cast int4range -> int8range so we must use
             # text format and unknown oid here
+            sd = self._tx.get_dumper(item, Pg3Format.TEXT)
+            dumper = RangeDumper(self.cls, self._tx)
+            dumper.sub_dumper = sd
             dumper.oid = INVALID_OID
-        elif isinstance(item, str) and sd.oid == INVALID_OID:
+            return dumper
+
+        sd = self._tx.get_dumper(item, format)
+        dumper = type(self)(self.cls, self._tx)
+        dumper.sub_dumper = sd
+        if sd.oid == INVALID_OID and isinstance(item, str):
             # Work around the normal mapping where text is dumped as unknown
             dumper.oid = self._get_range_oid(self._types["text"].oid)
         else:
             dumper.oid = self._get_range_oid(sd.oid)
+
         return dumper
 
     def _get_item(self, obj: Range[Any]) -> Any:
@@ -294,6 +283,67 @@ class RangeDumper(SequenceDumper):
         return info.oid if info else INVALID_OID
 
 
+class RangeDumper(BaseRangeDumper, SequenceDumper):
+    """
+    Dumper for range types.
+
+    The dumper can upgrade to one specific for a different range type.
+    """
+
+    format = Format.TEXT
+
+    def dump(self, obj: Range[Any]) -> bytes:
+        if not obj:
+            return b"empty"
+        else:
+            return self._dump_sequence(
+                (obj.lower, obj.upper),
+                b"[" if obj.lower_inc else b"(",
+                b"]" if obj.upper_inc else b")",
+                b",",
+            )
+
+    _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+
+
+class BinaryRangeDumper(BaseRangeDumper):
+
+    format = Format.BINARY
+
+    def dump(self, obj: Range[Any]) -> Union[bytes, bytearray]:
+        if not obj:
+            return _EMPTY_HEAD
+
+        out = bytearray([0])  # will replace the head later
+
+        head = 0
+        if obj.lower_inc:
+            head |= RANGE_LB_INC
+        if obj.upper_inc:
+            head |= RANGE_UB_INC
+
+        item = self._get_item(obj)
+        if item is not None:
+            dump = self._tx.get_dumper(item, self._adapt_format).dump
+
+        if obj.lower is not None:
+            data = dump(obj.lower)
+            out += pack_len(len(data))
+            out += data
+        else:
+            head |= RANGE_LB_INF
+
+        if obj.upper is not None:
+            data = dump(obj.upper)
+            out += pack_len(len(data))
+            out += data
+        else:
+            head |= RANGE_UB_INF
+
+        out[0] = head
+        return out
+
+
 class RangeLoader(BaseCompositeLoader, Generic[T]):
     """Generic loader for a range.
 
@@ -344,7 +394,6 @@ class RangeBinaryLoader(RecursiveLoader, Generic[T]):
             length = unpack_len(data, pos)[0]
             pos += 4
             max = load(data[pos : pos + length])
-            pos += length
 
         return Range(min, max, lb + ub)
 
index 285d02a6278016bc219cc1e3367f0d7eaacd6723..8f624b65617c2e2bcadcf7f4bbd3d84d86dc6826 100644 (file)
@@ -6,6 +6,7 @@ import pytest
 
 from psycopg3 import pq
 from psycopg3.sql import Identifier
+from psycopg3.adapt import Format
 from psycopg3.types import Range, RangeInfo
 
 
@@ -50,9 +51,10 @@ samples = [
     "pgtype",
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
-def test_dump_builtin_empty(conn, pgtype):
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_empty(conn, pgtype, fmt_in):
     r = Range(empty=True)
-    cur = conn.execute(f"select 'empty'::{pgtype} = %s", (r,))
+    cur = conn.execute(f"select 'empty'::{pgtype} = %{fmt_in}", (r,))
     assert cur.fetchone()[0] is True
 
 
@@ -60,22 +62,24 @@ def test_dump_builtin_empty(conn, pgtype):
     "pgtype",
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
-def test_dump_builtin_array(conn, pgtype):
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_array(conn, pgtype, fmt_in):
     r1 = Range(empty=True)
     r2 = Range(bounds="()")
     cur = conn.execute(
-        f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %s",
+        f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in}",
         ([r1, r2],),
     )
     assert cur.fetchone()[0] is True
 
 
 @pytest.mark.parametrize("pgtype, min, max, bounds", samples)
-def test_dump_builtin_range(conn, pgtype, min, max, bounds):
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in):
     r = Range(min, max, bounds)
     sub = type2sub[pgtype]
     cur = conn.execute(
-        f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s::{pgtype}",
+        f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %{fmt_in}",
         (min, max, bounds, r),
     )
     assert cur.fetchone()[0] is True