From: Daniele Varrazzo Date: Tue, 8 Jun 2021 19:08:29 +0000 (+0100) Subject: Add range binary loaders X-Git-Tag: 3.0.dev0~18^2~8 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=48eed7a25652b7a54e22a6726936d475c559f5aa;p=thirdparty%2Fpsycopg.git Add range binary loaders --- diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py index e3d0ea443..821fe37a6 100644 --- a/psycopg3/psycopg3/types/__init__.py +++ b/psycopg3/psycopg3/types/__init__.py @@ -144,6 +144,13 @@ from .range import ( DateRangeLoader as DateRangeLoader, TimestampRangeLoader as TimestampRangeLoader, TimestampTZRangeLoader as TimestampTZRangeLoader, + RangeBinaryLoader as RangeBinaryLoader, + Int4RangeBinaryLoader as Int4RangeBinaryLoader, + Int8RangeBinaryLoader as Int8RangeBinaryLoader, + NumericRangeBinaryLoader as NumericRangeBinaryLoader, + DateRangeBinaryLoader as DateRangeBinaryLoader, + TimestampRangeBinaryLoader as TimestampRangeBinaryLoader, + TimestampTZRangeBinaryLoader as TimestampTZRangeBinaryLoader, ) from .array import ( ListDumper as ListDumper, @@ -284,6 +291,12 @@ def register_default_globals(ctx: AdaptContext) -> None: DateRangeLoader.register("daterange", ctx) TimestampRangeLoader.register("tsrange", ctx) TimestampTZRangeLoader.register("tstzrange", ctx) + Int4RangeBinaryLoader.register("int4range", ctx) + Int8RangeBinaryLoader.register("int8range", ctx) + NumericRangeBinaryLoader.register("numrange", ctx) + DateRangeBinaryLoader.register("daterange", ctx) + TimestampRangeBinaryLoader.register("tsrange", ctx) + TimestampTZRangeBinaryLoader.register("tstzrange", ctx) ListDumper.register(list, ctx) ListBinaryDumper.register(list, ctx) diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py index fdf996d2f..09ea3ccca 100644 --- a/psycopg3/psycopg3/types/range.py +++ b/psycopg3/psycopg3/types/range.py @@ -11,12 +11,19 @@ from datetime import date, datetime from ..pq import Format from ..oids import postgres_types as builtins, INVALID_OID -from ..adapt import Buffer, Dumper, Format as Pg3Format +from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format, Transformer from ..proto import AdaptContext +from .._struct import unpack_len from .._typeinfo import RangeInfo from .composite import SequenceDumper, BaseCompositeLoader +RANGE_EMPTY = 0x01 # range is empty +RANGE_LB_INC = 0x02 # lower bound is inclusive +RANGE_UB_INC = 0x04 # upper bound is inclusive +RANGE_LB_INF = 0x08 # lower bound is -infinity +RANGE_UB_INF = 0x10 # upper bound is +infinity + T = TypeVar("T") @@ -303,6 +310,44 @@ class RangeLoader(BaseCompositeLoader, Generic[T]): return Range(min, max, bounds) +class RangeBinaryLoader(Loader, Generic[T]): + + format = Format.BINARY + subtype_oid: int + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer(context) + + def load(self, data: Buffer) -> Range[T]: + head = data[0] + if head & RANGE_EMPTY: + return Range(empty=True) + + load = self._tx.get_loader(self.subtype_oid, format=Format.BINARY).load + lb = "[" if head & RANGE_LB_INC else "(" + ub = "]" if head & RANGE_UB_INC else ")" + + pos = 1 # after the head + if head & RANGE_LB_INF: + min = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + min = load(data[pos : pos + length]) + pos += length + + if head & RANGE_UB_INF: + max = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + max = load(data[pos : pos + length]) + pos += length + + return Range(min, max, lb + ub) + + _int2parens = {ord(c): c for c in "[]()"} @@ -317,8 +362,16 @@ def register_adapters( ) loader.register(info.oid, context=context) + # generate and register a customized binary loader + bloader: Type[RangeBinaryLoader[Any]] = type( + f"{info.name.title()}BinaryLoader", + (RangeBinaryLoader,), + {"subtype_oid": info.subtype_oid}, + ) + bloader.register(info.oid, context=context) -# Loaders for builtin range types + +# Text loaders for builtin range types class Int4RangeLoader(RangeLoader[int]): @@ -343,3 +396,30 @@ class TimestampRangeLoader(RangeLoader[datetime]): class TimestampTZRangeLoader(RangeLoader[datetime]): subtype_oid = builtins["timestamptz"].oid + + +# Binary loaders for builtin range types + + +class Int4RangeBinaryLoader(RangeBinaryLoader[int]): + subtype_oid = builtins["int4"].oid + + +class Int8RangeBinaryLoader(RangeBinaryLoader[int]): + subtype_oid = builtins["int8"].oid + + +class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]): + subtype_oid = builtins["numeric"].oid + + +class DateRangeBinaryLoader(RangeBinaryLoader[date]): + subtype_oid = builtins["date"].oid + + +class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]): + subtype_oid = builtins["timestamp"].oid + + +class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]): + subtype_oid = builtins["timestamptz"].oid diff --git a/tests/types/test_range.py b/tests/types/test_range.py index f170cbb9e..b8e161746 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -4,6 +4,7 @@ from decimal import Decimal import pytest +from psycopg3 import pq from psycopg3.sql import Identifier from psycopg3.types import Range, RangeInfo @@ -27,6 +28,8 @@ samples = [ ("int8range", 10, 20, "[)"), ("int8range", -(2 ** 63), (2 ** 63) - 1, "[)"), ("numrange", Decimal(-100), Decimal("100.123"), "(]"), + ("numrange", Decimal(100), None, "()"), + ("numrange", None, Decimal(100), "()"), ("daterange", dt.date(2000, 1, 1), dt.date(2020, 1, 1), "[)"), ( "tsrange", @@ -82,9 +85,11 @@ def test_dump_builtin_range(conn, pgtype, min, max, bounds): "pgtype", "int4range int8range numrange daterange tsrange tstzrange".split(), ) -def test_load_builtin_empty(conn, pgtype): +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_load_builtin_empty(conn, pgtype, fmt_out): r = Range(empty=True) - (got,) = conn.execute(f"select 'empty'::{pgtype}").fetchone() + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone() assert type(got) is Range assert got == r assert not got @@ -95,9 +100,11 @@ def test_load_builtin_empty(conn, pgtype): "pgtype", "int4range int8range numrange daterange tsrange tstzrange".split(), ) -def test_load_builtin_inf(conn, pgtype): +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_load_builtin_inf(conn, pgtype, fmt_out): r = Range(bounds="()") - (got,) = conn.execute(f"select '(,)'::{pgtype}").fetchone() + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select '(,)'::{pgtype}").fetchone() assert type(got) is Range assert got == r assert got @@ -110,20 +117,24 @@ def test_load_builtin_inf(conn, pgtype): "pgtype", "int4range int8range numrange daterange tsrange tstzrange".split(), ) -def test_load_builtin_array(conn, pgtype): +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_load_builtin_array(conn, pgtype, fmt_out): r1 = Range(empty=True) r2 = Range(bounds="()") - (got,) = conn.execute( + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute( f"select array['empty'::{pgtype}, '(,)'::{pgtype}]" ).fetchone() assert got == [r1, r2] @pytest.mark.parametrize("pgtype, min, max, bounds", samples) -def test_load_builtin_range(conn, pgtype, min, max, bounds): +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out): r = Range(min, max, bounds) sub = type2sub[pgtype] - cur = conn.execute( + cur = conn.cursor(binary=fmt_out) + cur.execute( f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds) ) # normalise discrete ranges @@ -210,19 +221,22 @@ def test_dump_quoting(conn, testrange): assert cur.fetchone()[0] is True -def test_load_custom_empty(conn, testrange): +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_load_custom_empty(conn, testrange, fmt_out): info = RangeInfo.fetch(conn, "testrange") info.register(conn) - (got,) = conn.execute("select 'empty'::testrange").fetchone() + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute("select 'empty'::testrange").fetchone() assert isinstance(got, Range) assert got.isempty -def test_load_quoting(conn, testrange): +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_load_quoting(conn, testrange, fmt_out): info = RangeInfo.fetch(conn, "testrange") info.register(conn) - cur = conn.cursor() + cur = conn.cursor(binary=fmt_out) for i in range(1, 254): cur.execute( "select testrange(chr(%(low)s::int), chr(%(up)s::int))",