From acfab940a423b70e0b17ba50ea9bf0f11cf3f077 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 4 May 2021 23:08:45 +0200 Subject: [PATCH] Add numeric binary loader --- psycopg3/psycopg3/types/__init__.py | 2 + psycopg3/psycopg3/types/numeric.py | 73 ++++++++++++++++++++++++++- tests/types/test_numeric.py | 76 +++++++++++++++++++++++++++-- 3 files changed, 144 insertions(+), 7 deletions(-) diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py index 4ebf711a0..5e879da02 100644 --- a/psycopg3/psycopg3/types/__init__.py +++ b/psycopg3/psycopg3/types/__init__.py @@ -69,6 +69,7 @@ from .numeric import ( Float4BinaryLoader as Float4BinaryLoader, Float8BinaryLoader as Float8BinaryLoader, NumericLoader as NumericLoader, + NumericBinaryLoader as NumericBinaryLoader, ) from .singletons import ( BoolDumper as BoolDumper, @@ -190,6 +191,7 @@ def register_default_globals(ctx: AdaptContext) -> None: Float4BinaryLoader.register("float4", ctx) Float8BinaryLoader.register("float8", ctx) NumericLoader.register("numeric", ctx) + NumericBinaryLoader.register("numeric", ctx) BoolDumper.register(bool, ctx) BoolBinaryDumper.register(bool, ctx) diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index 36b4e9cd4..65ea5751e 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -5,9 +5,10 @@ Adapers for numeric types. # Copyright (C) 2020-2021 The Psycopg Team import struct -from typing import Any, Callable, Dict, Tuple, cast -from decimal import Decimal +from typing import Any, Callable, DefaultDict, Dict, Tuple, cast +from decimal import Decimal, DefaultContext, Context +from .. import errors as e from ..pq import Format from ..oids import postgres_types as builtins from ..adapt import Buffer, Dumper, Loader @@ -20,11 +21,13 @@ _UnpackInt = Callable[[bytes], Tuple[int]] _UnpackFloat = Callable[[bytes], Tuple[float]] _pack_int2 = cast(_PackInt, struct.Struct("!h").pack) +_pack_uint2 = cast(_PackInt, struct.Struct("!H").pack) _pack_int4 = cast(_PackInt, struct.Struct("!i").pack) _pack_uint4 = cast(_PackInt, struct.Struct("!I").pack) _pack_int8 = cast(_PackInt, struct.Struct("!q").pack) _pack_float8 = cast(_PackFloat, struct.Struct("!d").pack) _unpack_int2 = cast(_UnpackInt, struct.Struct("!h").unpack) +_unpack_uint2 = cast(_UnpackInt, struct.Struct("!H").unpack) _unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack) _unpack_uint4 = cast(_UnpackInt, struct.Struct("!I").unpack) _unpack_int8 = cast(_UnpackInt, struct.Struct("!q").unpack) @@ -267,3 +270,69 @@ class NumericLoader(Loader): if isinstance(data, memoryview): data = bytes(data) return Decimal(data.decode("utf8")) + + +NUMERIC_POS = 0x0000 +NUMERIC_NEG = 0x4000 +NUMERIC_NAN = 0xC000 +NUMERIC_PINF = 0xD000 +NUMERIC_NINF = 0xF000 + +_numeric_special = { + NUMERIC_NAN: Decimal("NaN"), + NUMERIC_PINF: Decimal("Infinity"), + NUMERIC_NINF: Decimal("-Infinity"), +} + + +class _ContextMap(DefaultDict[int, Context]): + """ + Cache for decimal contexts to use when the precision requires it. + + Note: if the default context is used (prec=28) you can get an invalid + operation or a rounding to 0: + + - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000') + - Decimal(1000).shift(25) = Decimal('0') + - Decimal(1000).shift(30) raises InvalidOperation + """ + + def __missing__(self, key: int) -> Context: + val = Context(prec=key) + self[key] = val + return val + + +_contexts = _ContextMap() +for i in range(DefaultContext.prec): + _contexts[i] = DefaultContext + +_unpack_numeric_head = cast( + Callable[[bytes], Tuple[int, int, int, int]], + struct.Struct("!HhHH").unpack_from, +) + + +class NumericBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Decimal: + ndigits, weight, sign, dscale = _unpack_numeric_head(data) + if sign == NUMERIC_POS or sign == NUMERIC_NEG: + val = 0 + for i in range(8, len(data), 2): + val = val * 10_000 + data[i] * 0x100 + data[i + 1] + + shift = dscale - (ndigits - weight - 1) * 4 + ctx = _contexts[weight * 4 + dscale + 8] + return ( + Decimal(val if sign == NUMERIC_POS else -val) + .scaleb(-dscale, ctx) + .shift(shift, ctx) + ) + else: + try: + return _numeric_special[sign] + except KeyError: + raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py index 886a31439..0ec5e1930 100644 --- a/tests/types/test_numeric.py +++ b/tests/types/test_numeric.py @@ -348,12 +348,78 @@ def test_dump_numeric_binary(): tx.get_dumper(n, Format.BINARY).dump(n) -@pytest.mark.xfail -def test_load_numeric_binary(conn): - # TODO: numeric binary casting +@pytest.mark.parametrize( + "expr", + ["nan", "0", "1", "-1", "0.0", "0.01"] + + [ + "0.0000000", + "-1.00000000000000", + "-2.00000000000000", + "1000000000.12345", + "100.123456790000000000000000", + "1.0e-1000", + "1e1000", + "0.000000000000000000000000001", + "1.0000000000000000000000001", + "1000000000000000000000000.001", + "1000000000000000000000000000.001", + "9999999999999999999999999999.9", + ], +) +def test_load_numeric_binary(conn, expr): + cur = conn.cursor(binary=1) + res = cur.execute(f"select '{expr}'::numeric").fetchone()[0] + val = Decimal(expr) + if val.is_nan(): + assert res.is_nan() + else: + assert res == val + if "e" not in expr: + assert str(res) == str(val) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_load_numeric_exhaustive(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + + funcs = [ + (lambda i: "1" + "0" * i), + (lambda i: "1" + "0" * i + "." + "0" * i), + (lambda i: "-1" + "0" * i), + (lambda i: "0." + "0" * i + "1"), + (lambda i: "-0." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "10"), + (lambda i: "1" + "0" * i + ".001"), + (lambda i: "9" + "9" * i), + (lambda i: "9" + "." + "9" * i), + (lambda i: "9" + "9" * i + ".9"), + (lambda i: "9" + "9" * i + "." + "9" * i), + ] + + for i in range(100): + for f in funcs: + snum = f(i) + want = Decimal(snum) + got = cur.execute(f"select '{snum}'::decimal").fetchone()[0] + assert want == got + assert str(want) == str(got) + + +@pytest.mark.pg(">= 14") +@pytest.mark.parametrize( + "val, expr", + [ + ("inf", "Infinity"), + ("-inf", "-Infinity"), + ], +) +def test_load_numeric_binary_inf(conn, val, expr): cur = conn.cursor(binary=1) - res = cur.execute("select 1::numeric").fetchone()[0] - assert res == Decimal(1) + res = cur.execute(f"select '{expr}'::numeric").fetchone()[0] + val = Decimal(val) + assert res == val @pytest.mark.parametrize( -- 2.47.2