]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Decimal binary dump algorithm streamlined
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 5 May 2021 18:20:52 +0000 (20:20 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 6 May 2021 17:09:56 +0000 (19:09 +0200)
psycopg3/psycopg3/types/numeric.py
tests/types/test_numeric.py

index 4d07aaaa058e2c7bf847a9f2d558a5f1cfc07593..0ea7f2c3096bb4e399dc7befd0a851eeb26357bb 100644 (file)
@@ -364,46 +364,36 @@ class DecimalBinaryDumper(Dumper):
             dscale = -exp
 
             # left pad with 0 to align the py digits to the pg digits
-            tmp = len(digits) + exp
-            if tmp % DEC_DIGITS != 0:
-                pad = DEC_DIGITS - tmp % DEC_DIGITS
-                digits = (0,) * pad + digits
-                tmp += pad
-
-            weight = tmp // DEC_DIGITS - 1
-
-            # drop excessive trailing 0s
-            while digits and digits[-1] == 0:
-                digits = digits[:-1]
-            # but right pad with 0s to the last pg digit
-            if len(digits) % DEC_DIGITS != 0:
-                pad = DEC_DIGITS - len(digits) % DEC_DIGITS
-                digits += (0,) * pad
+            mod = (len(digits) - dscale) % DEC_DIGITS
+            if mod:
+                digits = (0,) * (DEC_DIGITS - mod) + digits
 
         else:
             dscale = 0
 
             # align the py digits to the pg digits if there's some py exponent
             if exp % DEC_DIGITS != 0:
-                digits = digits + (0,) * (exp % DEC_DIGITS)
+                digits += (0,) * (exp % DEC_DIGITS)
 
             # left pad with 0 to align the py digits to the pg digits
-            if len(digits) % DEC_DIGITS != 0:
-                pad = DEC_DIGITS - len(digits) % DEC_DIGITS
-                digits = (0,) * pad + digits
+            mod = len(digits) % DEC_DIGITS
+            if mod:
+                digits = (0,) * (DEC_DIGITS - mod) + digits
 
-            weight = len(digits) // DEC_DIGITS - 1 + exp // DEC_DIGITS
+        weight = (len(digits) + exp) // DEC_DIGITS - 1
+        mod = len(digits) % DEC_DIGITS
 
         out = bytearray(
             _pack_numeric_head(
-                len(digits) // DEC_DIGITS,
+                len(digits) // DEC_DIGITS + (mod and 1),
                 weight,
                 NUMERIC_NEG if sign else NUMERIC_POS,
                 dscale,
             )
         )
 
-        for i in range(0, len(digits), DEC_DIGITS):
+        i = 0
+        while i + 3 < len(digits):
             digit = (
                 1000 * digits[i]
                 + 100 * digits[i + 1]
@@ -411,5 +401,17 @@ class DecimalBinaryDumper(Dumper):
                 + digits[i + 3]
             )
             out += _pack_uint2(digit)
+            i += DEC_DIGITS
+
+        if mod:
+            if mod == 1:
+                digit = 1000 * digits[i]
+            elif mod == 2:
+                digit = 1000 * digits[i] + 100 * digits[i + 1]
+            elif mod == 3:
+                digit = (
+                    1000 * digits[i] + 100 * digits[i + 1] + 10 * digits[i + 2]
+                )
+            out += _pack_uint2(digit)
 
         return out
index 43aba1662a19dac306d784c887dc8ee1dbb9e403..9a6be1694cff6f9b288b8bfde0511587b0301482 100644 (file)
@@ -363,8 +363,9 @@ def test_quote_numeric(conn, val, expr):
 def test_dump_numeric_binary(conn, expr):
     cur = conn.cursor()
     val = Decimal(expr)
-    cur.execute("select %b::text = %s::decimal::text", [val, expr])
-    assert cur.fetchone()[0] is True
+    cur.execute("select %b::text, %s::decimal::text", [val, expr])
+    want, got = cur.fetchone()
+    assert got == want
 
 
 @pytest.mark.slow
@@ -393,11 +394,11 @@ def test_dump_numeric_exhaustive(conn, fmt_in):
         for f in funcs:
             expr = f(i)
             val = Decimal(expr)
-            # For Postgres, NaN = NaN. Shrodinger says it's fine.
             cur.execute(
-                f"select %{fmt_in}::text = %s::decimal::text", [val, expr]
+                f"select %{fmt_in}::text, %s::decimal::text", [val, expr]
             )
-            assert cur.fetchone()[0] is True
+            want, got = cur.fetchone()
+            assert got == want
 
 
 @pytest.mark.pg(">= 14")