]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add prototype for a decimal binary dumper
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 5 May 2021 17:15:02 +0000 (19:15 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 6 May 2021 01:49:23 +0000 (03:49 +0200)
The algorithm is pretty unwieldy, maybe it can be refactored. But tests
show that it should be correct.

psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/numeric.py
tests/types/test_numeric.py

index 5e879da02cabe8d6969846d164afb91e3152726d..3a14205b0da2f3fffe2a3839f7c6ee9f8f85eae4 100644 (file)
@@ -51,6 +51,7 @@ from .numeric import (
     FloatDumper as FloatDumper,
     FloatBinaryDumper as FloatBinaryDumper,
     DecimalDumper as DecimalDumper,
+    DecimalBinaryDumper as DecimalBinaryDumper,
     Int2Dumper as Int2Dumper,
     Int4Dumper as Int4Dumper,
     Int8Dumper as Int8Dumper,
@@ -168,6 +169,10 @@ def register_default_globals(ctx: AdaptContext) -> None:
     IntBinaryDumper.register(int, ctx)
     FloatDumper.register(float, ctx)
     FloatBinaryDumper.register(float, ctx)
+    # TODO: benchmark to work out if the binary dumper is faster
+    # (the binary format is usually larger)
+    # for now leaving the text format as default.
+    DecimalBinaryDumper.register("decimal.Decimal", ctx)
     DecimalDumper.register("decimal.Decimal", ctx)
     Int2Dumper.register(Int2, ctx)
     Int4Dumper.register(Int4, ctx)
index 65ea5751ebe7e67e542b60ce93b23ffb7485fd7b..4d07aaaa058e2c7bf847a9f2d558a5f1cfc07593 100644 (file)
@@ -5,7 +5,7 @@ Adapers for numeric types.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import struct
-from typing import Any, Callable, DefaultDict, Dict, Tuple, cast
+from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
 from decimal import Decimal, DefaultContext, Context
 
 from .. import errors as e
@@ -272,13 +272,14 @@ class NumericLoader(Loader):
         return Decimal(data.decode("utf8"))
 
 
+DEC_DIGITS = 4  # decimal digits per Postgres "digit"
 NUMERIC_POS = 0x0000
 NUMERIC_NEG = 0x4000
 NUMERIC_NAN = 0xC000
 NUMERIC_PINF = 0xD000
 NUMERIC_NINF = 0xF000
 
-_numeric_special = {
+_decimal_special = {
     NUMERIC_NAN: Decimal("NaN"),
     NUMERIC_PINF: Decimal("Infinity"),
     NUMERIC_NINF: Decimal("-Infinity"),
@@ -311,6 +312,10 @@ _unpack_numeric_head = cast(
     Callable[[bytes], Tuple[int, int, int, int]],
     struct.Struct("!HhHH").unpack_from,
 )
+_pack_numeric_head = cast(
+    Callable[[int, int, int, int], bytes],
+    struct.Struct("!HhHH").pack,
+)
 
 
 class NumericBinaryLoader(Loader):
@@ -324,8 +329,8 @@ class NumericBinaryLoader(Loader):
             for i in range(8, len(data), 2):
                 val = val * 10_000 + data[i] * 0x100 + data[i + 1]
 
-            shift = dscale - (ndigits - weight - 1) * 4
-            ctx = _contexts[weight * 4 + dscale + 8]
+            shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
+            ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale]
             return (
                 Decimal(val if sign == NUMERIC_POS else -val)
                 .scaleb(-dscale, ctx)
@@ -333,6 +338,78 @@ class NumericBinaryLoader(Loader):
             )
         else:
             try:
-                return _numeric_special[sign]
+                return _decimal_special[sign]
             except KeyError:
                 raise e.DataError(f"bad value for numeric sign: 0x{sign:X}")
+
+
+NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
+NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0)
+NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0)
+
+
+class DecimalBinaryDumper(Dumper):
+
+    format = Format.BINARY
+    _oid = builtins["numeric"].oid
+
+    def dump(self, obj: Decimal) -> Union[bytearray, bytes]:
+        sign, digits, exp = obj.as_tuple()
+        if exp == "n":  # type: ignore[comparison-overlap]
+            return NUMERIC_NAN_BIN
+        elif exp == "F":  # type: ignore[comparison-overlap]
+            return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
+
+        if exp < 0:
+            dscale = -exp
+
+            # left pad with 0 to align the py digits to the pg digits
+            tmp = len(digits) + exp
+            if tmp % DEC_DIGITS != 0:
+                pad = DEC_DIGITS - tmp % DEC_DIGITS
+                digits = (0,) * pad + digits
+                tmp += pad
+
+            weight = tmp // DEC_DIGITS - 1
+
+            # drop excessive trailing 0s
+            while digits and digits[-1] == 0:
+                digits = digits[:-1]
+            # but right pad with 0s to the last pg digit
+            if len(digits) % DEC_DIGITS != 0:
+                pad = DEC_DIGITS - len(digits) % DEC_DIGITS
+                digits += (0,) * pad
+
+        else:
+            dscale = 0
+
+            # align the py digits to the pg digits if there's some py exponent
+            if exp % DEC_DIGITS != 0:
+                digits = digits + (0,) * (exp % DEC_DIGITS)
+
+            # left pad with 0 to align the py digits to the pg digits
+            if len(digits) % DEC_DIGITS != 0:
+                pad = DEC_DIGITS - len(digits) % DEC_DIGITS
+                digits = (0,) * pad + digits
+
+            weight = len(digits) // DEC_DIGITS - 1 + exp // DEC_DIGITS
+
+        out = bytearray(
+            _pack_numeric_head(
+                len(digits) // DEC_DIGITS,
+                weight,
+                NUMERIC_NEG if sign else NUMERIC_POS,
+                dscale,
+            )
+        )
+
+        for i in range(0, len(digits), DEC_DIGITS):
+            digit = (
+                1000 * digits[i]
+                + 100 * digits[i + 1]
+                + 10 * digits[i + 2]
+                + digits[i + 3]
+            )
+            out += _pack_uint2(digit)
+
+        return out
index 0ec5e193089d1a3d002e64f7119abdd77a46816f..43aba1662a19dac306d784c887dc8ee1dbb9e403 100644 (file)
@@ -340,12 +340,78 @@ def test_quote_numeric(conn, val, expr):
         assert r == (val, -val)
 
 
-@pytest.mark.xfail
-def test_dump_numeric_binary():
-    # TODO: numeric binary adaptation
-    tx = Transformer()
-    n = Decimal(1)
-    tx.get_dumper(n, Format.BINARY).dump(n)
+@pytest.mark.parametrize(
+    "expr",
+    ["NaN", "1", "1.0", "-1", "0.0", "0.01", "11", "1.1", "1.01", "0", "0.00"]
+    + [
+        "0.0000000",
+        "0.00001",
+        "1.00001",
+        "-1.00000000000000",
+        "-2.00000000000000",
+        "1000000000.12345",
+        "100.123456790000000000000000",
+        "1.0e-1000",
+        "1e1000",
+        "0.000000000000000000000000001",
+        "1.0000000000000000000000001",
+        "1000000000000000000000000.001",
+        "1000000000000000000000000000.001",
+        "9999999999999999999999999999.9",
+    ],
+)
+def test_dump_numeric_binary(conn, expr):
+    cur = conn.cursor()
+    val = Decimal(expr)
+    cur.execute("select %b::text = %s::decimal::text", [val, expr])
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_dump_numeric_exhaustive(conn, fmt_in):
+    cur = conn.cursor()
+
+    funcs = [
+        (lambda i: "1" + "0" * i),
+        (lambda i: "1" + "0" * i + "." + "0" * i),
+        (lambda i: "-1" + "0" * i),
+        (lambda i: "0." + "0" * i + "1"),
+        (lambda i: "-0." + "0" * i + "1"),
+        (lambda i: "1." + "0" * i + "1"),
+        (lambda i: "1." + "0" * i + "10"),
+        (lambda i: "1" + "0" * i + ".001"),
+        (lambda i: "9" + "9" * i),
+        (lambda i: "9" + "." + "9" * i),
+        (lambda i: "9" + "9" * i + ".9"),
+        (lambda i: "9" + "9" * i + "." + "9" * i),
+        (lambda i: "1.1e%s" % i),
+        (lambda i: "1.1e-%s" % i),
+    ]
+
+    for i in range(100):
+        for f in funcs:
+            expr = f(i)
+            val = Decimal(expr)
+            # For Postgres, NaN = NaN. Shrodinger says it's fine.
+            cur.execute(
+                f"select %{fmt_in}::text = %s::decimal::text", [val, expr]
+            )
+            assert cur.fetchone()[0] is True
+
+
+@pytest.mark.pg(">= 14")
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        ("inf", "Infinity"),
+        ("-inf", "-Infinity"),
+    ],
+)
+def test_dump_numeric_binary_inf(conn, val, expr):
+    cur = conn.cursor()
+    val = Decimal(val)
+    cur.execute("select %b", [val])
 
 
 @pytest.mark.parametrize(