Float4BinaryLoader as Float4BinaryLoader,
Float8BinaryLoader as Float8BinaryLoader,
NumericLoader as NumericLoader,
+ NumericBinaryLoader as NumericBinaryLoader,
)
from .singletons import (
BoolDumper as BoolDumper,
Float4BinaryLoader.register("float4", ctx)
Float8BinaryLoader.register("float8", ctx)
NumericLoader.register("numeric", ctx)
+ NumericBinaryLoader.register("numeric", ctx)
BoolDumper.register(bool, ctx)
BoolBinaryDumper.register(bool, ctx)
# Copyright (C) 2020-2021 The Psycopg Team
import struct
-from typing import Any, Callable, Dict, Tuple, cast
-from decimal import Decimal
+from typing import Any, Callable, DefaultDict, Dict, Tuple, cast
+from decimal import Decimal, DefaultContext, Context
+from .. import errors as e
from ..pq import Format
from ..oids import postgres_types as builtins
from ..adapt import Buffer, Dumper, Loader
_UnpackFloat = Callable[[bytes], Tuple[float]]
_pack_int2 = cast(_PackInt, struct.Struct("!h").pack)
+_pack_uint2 = cast(_PackInt, struct.Struct("!H").pack)
_pack_int4 = cast(_PackInt, struct.Struct("!i").pack)
_pack_uint4 = cast(_PackInt, struct.Struct("!I").pack)
_pack_int8 = cast(_PackInt, struct.Struct("!q").pack)
_pack_float8 = cast(_PackFloat, struct.Struct("!d").pack)
_unpack_int2 = cast(_UnpackInt, struct.Struct("!h").unpack)
+_unpack_uint2 = cast(_UnpackInt, struct.Struct("!H").unpack)
_unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack)
_unpack_uint4 = cast(_UnpackInt, struct.Struct("!I").unpack)
_unpack_int8 = cast(_UnpackInt, struct.Struct("!q").unpack)
if isinstance(data, memoryview):
data = bytes(data)
return Decimal(data.decode("utf8"))
+
+
+NUMERIC_POS = 0x0000
+NUMERIC_NEG = 0x4000
+NUMERIC_NAN = 0xC000
+NUMERIC_PINF = 0xD000
+NUMERIC_NINF = 0xF000
+
+_numeric_special = {
+ NUMERIC_NAN: Decimal("NaN"),
+ NUMERIC_PINF: Decimal("Infinity"),
+ NUMERIC_NINF: Decimal("-Infinity"),
+}
+
+
+class _ContextMap(DefaultDict[int, Context]):
+ """
+ Cache for decimal contexts to use when the precision requires it.
+
+ Note: if the default context is used (prec=28) you can get an invalid
+ operation or a rounding to 0:
+
+ - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
+ - Decimal(1000).shift(25) = Decimal('0')
+ - Decimal(1000).shift(30) raises InvalidOperation
+ """
+
+ def __missing__(self, key: int) -> Context:
+ val = Context(prec=key)
+ self[key] = val
+ return val
+
+
+_contexts = _ContextMap()
+for i in range(DefaultContext.prec):
+ _contexts[i] = DefaultContext
+
+_unpack_numeric_head = cast(
+ Callable[[bytes], Tuple[int, int, int, int]],
+ struct.Struct("!HhHH").unpack_from,
+)
+
+
+class NumericBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Decimal:
+ ndigits, weight, sign, dscale = _unpack_numeric_head(data)
+ if sign == NUMERIC_POS or sign == NUMERIC_NEG:
+ val = 0
+ for i in range(8, len(data), 2):
+ val = val * 10_000 + data[i] * 0x100 + data[i + 1]
+
+ shift = dscale - (ndigits - weight - 1) * 4
+ ctx = _contexts[weight * 4 + dscale + 8]
+ return (
+ Decimal(val if sign == NUMERIC_POS else -val)
+ .scaleb(-dscale, ctx)
+ .shift(shift, ctx)
+ )
+ else:
+ try:
+ return _numeric_special[sign]
+ except KeyError:
+ raise e.DataError(f"bad value for numeric sign: 0x{sign:X}")
tx.get_dumper(n, Format.BINARY).dump(n)
-@pytest.mark.xfail
-def test_load_numeric_binary(conn):
- # TODO: numeric binary casting
+@pytest.mark.parametrize(
+ "expr",
+ ["nan", "0", "1", "-1", "0.0", "0.01"]
+ + [
+ "0.0000000",
+ "-1.00000000000000",
+ "-2.00000000000000",
+ "1000000000.12345",
+ "100.123456790000000000000000",
+ "1.0e-1000",
+ "1e1000",
+ "0.000000000000000000000000001",
+ "1.0000000000000000000000001",
+ "1000000000000000000000000.001",
+ "1000000000000000000000000000.001",
+ "9999999999999999999999999999.9",
+ ],
+)
+def test_load_numeric_binary(conn, expr):
+ cur = conn.cursor(binary=1)
+ res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
+ val = Decimal(expr)
+ if val.is_nan():
+ assert res.is_nan()
+ else:
+ assert res == val
+ if "e" not in expr:
+ assert str(res) == str(val)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_numeric_exhaustive(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+
+ funcs = [
+ (lambda i: "1" + "0" * i),
+ (lambda i: "1" + "0" * i + "." + "0" * i),
+ (lambda i: "-1" + "0" * i),
+ (lambda i: "0." + "0" * i + "1"),
+ (lambda i: "-0." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "10"),
+ (lambda i: "1" + "0" * i + ".001"),
+ (lambda i: "9" + "9" * i),
+ (lambda i: "9" + "." + "9" * i),
+ (lambda i: "9" + "9" * i + ".9"),
+ (lambda i: "9" + "9" * i + "." + "9" * i),
+ ]
+
+ for i in range(100):
+ for f in funcs:
+ snum = f(i)
+ want = Decimal(snum)
+ got = cur.execute(f"select '{snum}'::decimal").fetchone()[0]
+ assert want == got
+ assert str(want) == str(got)
+
+
+@pytest.mark.pg(">= 14")
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("inf", "Infinity"),
+ ("-inf", "-Infinity"),
+ ],
+)
+def test_load_numeric_binary_inf(conn, val, expr):
cur = conn.cursor(binary=1)
- res = cur.execute("select 1::numeric").fetchone()[0]
- assert res == Decimal(1)
+ res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
+ val = Decimal(val)
+ assert res == val
@pytest.mark.parametrize(