]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add range binary loaders
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 19:08:29 +0000 (20:08 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Jun 2021 15:16:26 +0000 (16:16 +0100)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/range.py
tests/types/test_range.py

index e3d0ea4436e9fc2de4fb9122470209c0db382ffd..821fe37a631d34e3514a736e64772c2b1a8db166 100644 (file)
@@ -144,6 +144,13 @@ from .range import (
     DateRangeLoader as DateRangeLoader,
     TimestampRangeLoader as TimestampRangeLoader,
     TimestampTZRangeLoader as TimestampTZRangeLoader,
+    RangeBinaryLoader as RangeBinaryLoader,
+    Int4RangeBinaryLoader as Int4RangeBinaryLoader,
+    Int8RangeBinaryLoader as Int8RangeBinaryLoader,
+    NumericRangeBinaryLoader as NumericRangeBinaryLoader,
+    DateRangeBinaryLoader as DateRangeBinaryLoader,
+    TimestampRangeBinaryLoader as TimestampRangeBinaryLoader,
+    TimestampTZRangeBinaryLoader as TimestampTZRangeBinaryLoader,
 )
 from .array import (
     ListDumper as ListDumper,
@@ -284,6 +291,12 @@ def register_default_globals(ctx: AdaptContext) -> None:
     DateRangeLoader.register("daterange", ctx)
     TimestampRangeLoader.register("tsrange", ctx)
     TimestampTZRangeLoader.register("tstzrange", ctx)
+    Int4RangeBinaryLoader.register("int4range", ctx)
+    Int8RangeBinaryLoader.register("int8range", ctx)
+    NumericRangeBinaryLoader.register("numrange", ctx)
+    DateRangeBinaryLoader.register("daterange", ctx)
+    TimestampRangeBinaryLoader.register("tsrange", ctx)
+    TimestampTZRangeBinaryLoader.register("tstzrange", ctx)
 
     ListDumper.register(list, ctx)
     ListBinaryDumper.register(list, ctx)
index fdf996d2f2fbf7c5c4430910f5e86462ffe76f42..09ea3ccca08644dd9d42150d9aed6e33a826d43a 100644 (file)
@@ -11,12 +11,19 @@ from datetime import date, datetime
 
 from ..pq import Format
 from ..oids import postgres_types as builtins, INVALID_OID
-from ..adapt import Buffer, Dumper, Format as Pg3Format
+from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format, Transformer
 from ..proto import AdaptContext
+from .._struct import unpack_len
 from .._typeinfo import RangeInfo
 
 from .composite import SequenceDumper, BaseCompositeLoader
 
+RANGE_EMPTY = 0x01  # range is empty
+RANGE_LB_INC = 0x02  # lower bound is inclusive
+RANGE_UB_INC = 0x04  # upper bound is inclusive
+RANGE_LB_INF = 0x08  # lower bound is -infinity
+RANGE_UB_INF = 0x10  # upper bound is +infinity
+
 T = TypeVar("T")
 
 
@@ -303,6 +310,44 @@ class RangeLoader(BaseCompositeLoader, Generic[T]):
         return Range(min, max, bounds)
 
 
+class RangeBinaryLoader(Loader, Generic[T]):
+
+    format = Format.BINARY
+    subtype_oid: int
+
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self._tx = Transformer(context)
+
+    def load(self, data: Buffer) -> Range[T]:
+        head = data[0]
+        if head & RANGE_EMPTY:
+            return Range(empty=True)
+
+        load = self._tx.get_loader(self.subtype_oid, format=Format.BINARY).load
+        lb = "[" if head & RANGE_LB_INC else "("
+        ub = "]" if head & RANGE_UB_INC else ")"
+
+        pos = 1  # after the head
+        if head & RANGE_LB_INF:
+            min = None
+        else:
+            length = unpack_len(data, pos)[0]
+            pos += 4
+            min = load(data[pos : pos + length])
+            pos += length
+
+        if head & RANGE_UB_INF:
+            max = None
+        else:
+            length = unpack_len(data, pos)[0]
+            pos += 4
+            max = load(data[pos : pos + length])
+            pos += length
+
+        return Range(min, max, lb + ub)
+
+
 _int2parens = {ord(c): c for c in "[]()"}
 
 
@@ -317,8 +362,16 @@ def register_adapters(
     )
     loader.register(info.oid, context=context)
 
+    # generate and register a customized binary loader
+    bloader: Type[RangeBinaryLoader[Any]] = type(
+        f"{info.name.title()}BinaryLoader",
+        (RangeBinaryLoader,),
+        {"subtype_oid": info.subtype_oid},
+    )
+    bloader.register(info.oid, context=context)
 
-# Loaders for builtin range types
+
+# Text loaders for builtin range types
 
 
 class Int4RangeLoader(RangeLoader[int]):
@@ -343,3 +396,30 @@ class TimestampRangeLoader(RangeLoader[datetime]):
 
 class TimestampTZRangeLoader(RangeLoader[datetime]):
     subtype_oid = builtins["timestamptz"].oid
+
+
+# Binary loaders for builtin range types
+
+
+class Int4RangeBinaryLoader(RangeBinaryLoader[int]):
+    subtype_oid = builtins["int4"].oid
+
+
+class Int8RangeBinaryLoader(RangeBinaryLoader[int]):
+    subtype_oid = builtins["int8"].oid
+
+
+class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]):
+    subtype_oid = builtins["numeric"].oid
+
+
+class DateRangeBinaryLoader(RangeBinaryLoader[date]):
+    subtype_oid = builtins["date"].oid
+
+
+class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]):
+    subtype_oid = builtins["timestamp"].oid
+
+
+class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]):
+    subtype_oid = builtins["timestamptz"].oid
index f170cbb9e9bea286f72c29f6fdfd49b6c2bb4e26..b8e1617464192353ca07b61c2dca01c91d8bc8c3 100644 (file)
@@ -4,6 +4,7 @@ from decimal import Decimal
 
 import pytest
 
+from psycopg3 import pq
 from psycopg3.sql import Identifier
 from psycopg3.types import Range, RangeInfo
 
@@ -27,6 +28,8 @@ samples = [
     ("int8range", 10, 20, "[)"),
     ("int8range", -(2 ** 63), (2 ** 63) - 1, "[)"),
     ("numrange", Decimal(-100), Decimal("100.123"), "(]"),
+    ("numrange", Decimal(100), None, "()"),
+    ("numrange", None, Decimal(100), "()"),
     ("daterange", dt.date(2000, 1, 1), dt.date(2020, 1, 1), "[)"),
     (
         "tsrange",
@@ -82,9 +85,11 @@ def test_dump_builtin_range(conn, pgtype, min, max, bounds):
     "pgtype",
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
-def test_load_builtin_empty(conn, pgtype):
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_builtin_empty(conn, pgtype, fmt_out):
     r = Range(empty=True)
-    (got,) = conn.execute(f"select 'empty'::{pgtype}").fetchone()
+    cur = conn.cursor(binary=fmt_out)
+    (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
     assert type(got) is Range
     assert got == r
     assert not got
@@ -95,9 +100,11 @@ def test_load_builtin_empty(conn, pgtype):
     "pgtype",
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
-def test_load_builtin_inf(conn, pgtype):
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_builtin_inf(conn, pgtype, fmt_out):
     r = Range(bounds="()")
-    (got,) = conn.execute(f"select '(,)'::{pgtype}").fetchone()
+    cur = conn.cursor(binary=fmt_out)
+    (got,) = cur.execute(f"select '(,)'::{pgtype}").fetchone()
     assert type(got) is Range
     assert got == r
     assert got
@@ -110,20 +117,24 @@ def test_load_builtin_inf(conn, pgtype):
     "pgtype",
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
-def test_load_builtin_array(conn, pgtype):
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_builtin_array(conn, pgtype, fmt_out):
     r1 = Range(empty=True)
     r2 = Range(bounds="()")
-    (got,) = conn.execute(
+    cur = conn.cursor(binary=fmt_out)
+    (got,) = cur.execute(
         f"select array['empty'::{pgtype}, '(,)'::{pgtype}]"
     ).fetchone()
     assert got == [r1, r2]
 
 
 @pytest.mark.parametrize("pgtype, min, max, bounds", samples)
-def test_load_builtin_range(conn, pgtype, min, max, bounds):
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out):
     r = Range(min, max, bounds)
     sub = type2sub[pgtype]
-    cur = conn.execute(
+    cur = conn.cursor(binary=fmt_out)
+    cur.execute(
         f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)
     )
     # normalise discrete ranges
@@ -210,19 +221,22 @@ def test_dump_quoting(conn, testrange):
         assert cur.fetchone()[0] is True
 
 
-def test_load_custom_empty(conn, testrange):
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_custom_empty(conn, testrange, fmt_out):
     info = RangeInfo.fetch(conn, "testrange")
     info.register(conn)
 
-    (got,) = conn.execute("select 'empty'::testrange").fetchone()
+    cur = conn.cursor(binary=fmt_out)
+    (got,) = cur.execute("select 'empty'::testrange").fetchone()
     assert isinstance(got, Range)
     assert got.isempty
 
 
-def test_load_quoting(conn, testrange):
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_quoting(conn, testrange, fmt_out):
     info = RangeInfo.fetch(conn, "testrange")
     info.register(conn)
-    cur = conn.cursor()
+    cur = conn.cursor(binary=fmt_out)
     for i in range(1, 254):
         cur.execute(
             "select testrange(chr(%(low)s::int), chr(%(up)s::int))",