]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add numeric binary loader
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 4 May 2021 21:08:45 +0000 (23:08 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 4 May 2021 21:47:49 +0000 (23:47 +0200)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/numeric.py
tests/types/test_numeric.py

index 4ebf711a067c3d38083a324c00a60d24cf33b877..5e879da02cabe8d6969846d164afb91e3152726d 100644 (file)
@@ -69,6 +69,7 @@ from .numeric import (
     Float4BinaryLoader as Float4BinaryLoader,
     Float8BinaryLoader as Float8BinaryLoader,
     NumericLoader as NumericLoader,
+    NumericBinaryLoader as NumericBinaryLoader,
 )
 from .singletons import (
     BoolDumper as BoolDumper,
@@ -190,6 +191,7 @@ def register_default_globals(ctx: AdaptContext) -> None:
     Float4BinaryLoader.register("float4", ctx)
     Float8BinaryLoader.register("float8", ctx)
     NumericLoader.register("numeric", ctx)
+    NumericBinaryLoader.register("numeric", ctx)
 
     BoolDumper.register(bool, ctx)
     BoolBinaryDumper.register(bool, ctx)
index 36b4e9cd4eae62ef3ee92b854476e096192c75d6..65ea5751ebe7e67e542b60ce93b23ffb7485fd7b 100644 (file)
@@ -5,9 +5,10 @@ Adapers for numeric types.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import struct
-from typing import Any, Callable, Dict, Tuple, cast
-from decimal import Decimal
+from typing import Any, Callable, DefaultDict, Dict, Tuple, cast
+from decimal import Decimal, DefaultContext, Context
 
+from .. import errors as e
 from ..pq import Format
 from ..oids import postgres_types as builtins
 from ..adapt import Buffer, Dumper, Loader
@@ -20,11 +21,13 @@ _UnpackInt = Callable[[bytes], Tuple[int]]
 _UnpackFloat = Callable[[bytes], Tuple[float]]
 
 _pack_int2 = cast(_PackInt, struct.Struct("!h").pack)
+_pack_uint2 = cast(_PackInt, struct.Struct("!H").pack)
 _pack_int4 = cast(_PackInt, struct.Struct("!i").pack)
 _pack_uint4 = cast(_PackInt, struct.Struct("!I").pack)
 _pack_int8 = cast(_PackInt, struct.Struct("!q").pack)
 _pack_float8 = cast(_PackFloat, struct.Struct("!d").pack)
 _unpack_int2 = cast(_UnpackInt, struct.Struct("!h").unpack)
+_unpack_uint2 = cast(_UnpackInt, struct.Struct("!H").unpack)
 _unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack)
 _unpack_uint4 = cast(_UnpackInt, struct.Struct("!I").unpack)
 _unpack_int8 = cast(_UnpackInt, struct.Struct("!q").unpack)
@@ -267,3 +270,69 @@ class NumericLoader(Loader):
         if isinstance(data, memoryview):
             data = bytes(data)
         return Decimal(data.decode("utf8"))
+
+
+NUMERIC_POS = 0x0000
+NUMERIC_NEG = 0x4000
+NUMERIC_NAN = 0xC000
+NUMERIC_PINF = 0xD000
+NUMERIC_NINF = 0xF000
+
+_numeric_special = {
+    NUMERIC_NAN: Decimal("NaN"),
+    NUMERIC_PINF: Decimal("Infinity"),
+    NUMERIC_NINF: Decimal("-Infinity"),
+}
+
+
+class _ContextMap(DefaultDict[int, Context]):
+    """
+    Cache for decimal contexts to use when the precision requires it.
+
+    Note: if the default context is used (prec=28) you can get an invalid
+    operation or a rounding to 0:
+
+    - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
+    - Decimal(1000).shift(25) = Decimal('0')
+    - Decimal(1000).shift(30) raises InvalidOperation
+    """
+
+    def __missing__(self, key: int) -> Context:
+        val = Context(prec=key)
+        self[key] = val
+        return val
+
+
+_contexts = _ContextMap()
+for i in range(DefaultContext.prec):
+    _contexts[i] = DefaultContext
+
+_unpack_numeric_head = cast(
+    Callable[[bytes], Tuple[int, int, int, int]],
+    struct.Struct("!HhHH").unpack_from,
+)
+
+
+class NumericBinaryLoader(Loader):
+
+    format = Format.BINARY
+
+    def load(self, data: Buffer) -> Decimal:
+        ndigits, weight, sign, dscale = _unpack_numeric_head(data)
+        if sign == NUMERIC_POS or sign == NUMERIC_NEG:
+            val = 0
+            for i in range(8, len(data), 2):
+                val = val * 10_000 + data[i] * 0x100 + data[i + 1]
+
+            shift = dscale - (ndigits - weight - 1) * 4
+            ctx = _contexts[weight * 4 + dscale + 8]
+            return (
+                Decimal(val if sign == NUMERIC_POS else -val)
+                .scaleb(-dscale, ctx)
+                .shift(shift, ctx)
+            )
+        else:
+            try:
+                return _numeric_special[sign]
+            except KeyError:
+                raise e.DataError(f"bad value for numeric sign: 0x{sign:X}")
index 886a3143943671b8201a52dd69a89b6cd854acd2..0ec5e193089d1a3d002e64f7119abdd77a46816f 100644 (file)
@@ -348,12 +348,78 @@ def test_dump_numeric_binary():
     tx.get_dumper(n, Format.BINARY).dump(n)
 
 
-@pytest.mark.xfail
-def test_load_numeric_binary(conn):
-    # TODO: numeric binary casting
+@pytest.mark.parametrize(
+    "expr",
+    ["nan", "0", "1", "-1", "0.0", "0.01"]
+    + [
+        "0.0000000",
+        "-1.00000000000000",
+        "-2.00000000000000",
+        "1000000000.12345",
+        "100.123456790000000000000000",
+        "1.0e-1000",
+        "1e1000",
+        "0.000000000000000000000000001",
+        "1.0000000000000000000000001",
+        "1000000000000000000000000.001",
+        "1000000000000000000000000000.001",
+        "9999999999999999999999999999.9",
+    ],
+)
+def test_load_numeric_binary(conn, expr):
+    cur = conn.cursor(binary=1)
+    res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
+    val = Decimal(expr)
+    if val.is_nan():
+        assert res.is_nan()
+    else:
+        assert res == val
+        if "e" not in expr:
+            assert str(res) == str(val)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_numeric_exhaustive(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out)
+
+    funcs = [
+        (lambda i: "1" + "0" * i),
+        (lambda i: "1" + "0" * i + "." + "0" * i),
+        (lambda i: "-1" + "0" * i),
+        (lambda i: "0." + "0" * i + "1"),
+        (lambda i: "-0." + "0" * i + "1"),
+        (lambda i: "1." + "0" * i + "1"),
+        (lambda i: "1." + "0" * i + "10"),
+        (lambda i: "1" + "0" * i + ".001"),
+        (lambda i: "9" + "9" * i),
+        (lambda i: "9" + "." + "9" * i),
+        (lambda i: "9" + "9" * i + ".9"),
+        (lambda i: "9" + "9" * i + "." + "9" * i),
+    ]
+
+    for i in range(100):
+        for f in funcs:
+            snum = f(i)
+            want = Decimal(snum)
+            got = cur.execute(f"select '{snum}'::decimal").fetchone()[0]
+            assert want == got
+            assert str(want) == str(got)
+
+
+@pytest.mark.pg(">= 14")
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        ("inf", "Infinity"),
+        ("-inf", "-Infinity"),
+    ],
+)
+def test_load_numeric_binary_inf(conn, val, expr):
     cur = conn.cursor(binary=1)
-    res = cur.execute("select 1::numeric").fetchone()[0]
-    assert res == Decimal(1)
+    res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
+    val = Decimal(val)
+    assert res == val
 
 
 @pytest.mark.parametrize(