From: Daniele Varrazzo Date: Wed, 5 May 2021 18:20:52 +0000 (+0200) Subject: Decimal binary dump algorithm streamlined X-Git-Tag: 3.0.dev0~48^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=52cca81d0c881bc312a9b412089a69e54e89d078;p=thirdparty%2Fpsycopg.git Decimal binary dump algorithm streamlined --- diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index 4d07aaaa0..0ea7f2c30 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -364,46 +364,36 @@ class DecimalBinaryDumper(Dumper): dscale = -exp # left pad with 0 to align the py digits to the pg digits - tmp = len(digits) + exp - if tmp % DEC_DIGITS != 0: - pad = DEC_DIGITS - tmp % DEC_DIGITS - digits = (0,) * pad + digits - tmp += pad - - weight = tmp // DEC_DIGITS - 1 - - # drop excessive trailing 0s - while digits and digits[-1] == 0: - digits = digits[:-1] - # but right pad with 0s to the last pg digit - if len(digits) % DEC_DIGITS != 0: - pad = DEC_DIGITS - len(digits) % DEC_DIGITS - digits += (0,) * pad + mod = (len(digits) - dscale) % DEC_DIGITS + if mod: + digits = (0,) * (DEC_DIGITS - mod) + digits else: dscale = 0 # align the py digits to the pg digits if there's some py exponent if exp % DEC_DIGITS != 0: - digits = digits + (0,) * (exp % DEC_DIGITS) + digits += (0,) * (exp % DEC_DIGITS) # left pad with 0 to align the py digits to the pg digits - if len(digits) % DEC_DIGITS != 0: - pad = DEC_DIGITS - len(digits) % DEC_DIGITS - digits = (0,) * pad + digits + mod = len(digits) % DEC_DIGITS + if mod: + digits = (0,) * (DEC_DIGITS - mod) + digits - weight = len(digits) // DEC_DIGITS - 1 + exp // DEC_DIGITS + weight = (len(digits) + exp) // DEC_DIGITS - 1 + mod = len(digits) % DEC_DIGITS out = bytearray( _pack_numeric_head( - len(digits) // DEC_DIGITS, + len(digits) // DEC_DIGITS + (mod and 1), weight, NUMERIC_NEG if sign else NUMERIC_POS, dscale, ) ) - for i in range(0, len(digits), DEC_DIGITS): + i = 0 + while i + 3 < len(digits): digit = ( 1000 * digits[i] + 100 * digits[i + 1] @@ -411,5 +401,17 @@ class DecimalBinaryDumper(Dumper): + digits[i + 3] ) out += _pack_uint2(digit) + i += DEC_DIGITS + + if mod: + if mod == 1: + digit = 1000 * digits[i] + elif mod == 2: + digit = 1000 * digits[i] + 100 * digits[i + 1] + elif mod == 3: + digit = ( + 1000 * digits[i] + 100 * digits[i + 1] + 10 * digits[i + 2] + ) + out += _pack_uint2(digit) return out diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py index 43aba1662..9a6be1694 100644 --- a/tests/types/test_numeric.py +++ b/tests/types/test_numeric.py @@ -363,8 +363,9 @@ def test_quote_numeric(conn, val, expr): def test_dump_numeric_binary(conn, expr): cur = conn.cursor() val = Decimal(expr) - cur.execute("select %b::text = %s::decimal::text", [val, expr]) - assert cur.fetchone()[0] is True + cur.execute("select %b::text, %s::decimal::text", [val, expr]) + want, got = cur.fetchone() + assert got == want @pytest.mark.slow @@ -393,11 +394,11 @@ def test_dump_numeric_exhaustive(conn, fmt_in): for f in funcs: expr = f(i) val = Decimal(expr) - # For Postgres, NaN = NaN. Shrodinger says it's fine. cur.execute( - f"select %{fmt_in}::text = %s::decimal::text", [val, expr] + f"select %{fmt_in}::text, %s::decimal::text", [val, expr] ) - assert cur.fetchone()[0] is True + want, got = cur.fetchone() + assert got == want @pytest.mark.pg(">= 14")