]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Handle Decimal sNaN value
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 6 May 2021 02:26:06 +0000 (04:26 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 6 May 2021 17:09:56 +0000 (19:09 +0200)
psycopg3/psycopg3/types/numeric.py
tests/fix_faker.py
tests/types/test_numeric.py

index e070250373ba5335228d328c2e01c9ce1006087a..fb9b6fea23ee620b8c1a41b9e78b6081d420bdf3 100644 (file)
@@ -88,6 +88,13 @@ class DecimalDumper(SpecialValuesDumper):
 
     _oid = builtins["numeric"].oid
 
+    def dump(self, obj: Decimal) -> bytes:
+        if obj.is_nan():
+            # cover NaN and sNaN
+            return b"NaN"
+        else:
+            return str(obj).encode("utf8")
+
     _special = {
         b"Infinity": b"'Infinity'::numeric",
         b"-Infinity": b"'-Infinity'::numeric",
@@ -355,7 +362,7 @@ class DecimalBinaryDumper(Dumper):
 
     def dump(self, obj: Decimal) -> Union[bytearray, bytes]:
         sign, digits, exp = obj.as_tuple()
-        if exp == "n":  # type: ignore[comparison-overlap]
+        if exp == "n" or exp == "N":  # type: ignore[comparison-overlap]
             return NUMERIC_NAN_BIN
         elif exp == "F":  # type: ignore[comparison-overlap]
             return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
index 2555047b1ac1dd1d967b2268654f9f17358efbf5..b5b6f20e50e021ff11d096ffe9b84c1c2ef9b82e 100644 (file)
@@ -229,9 +229,9 @@ class Faker:
     def make_Decimal(self, spec):
         if random() >= 0.99:
             if self.conn.info.server_version >= 140000:
-                return Decimal(choice(["NaN", "Inf", "-Inf"]))
+                return Decimal(choice(["NaN", "sNaN", "Inf", "-Inf"]))
             else:
-                return Decimal("NaN")
+                return Decimal(choice(["NaN", "sNaN"]))
 
         sign = choice("+-")
         num = choice(["0.zd", "d", "d.d"])
index 9a6be1694cff6f9b288b8bfde0511587b0301482..3f6f93f822b14bd5f7341accc4d22935ea4305f2 100644 (file)
@@ -297,16 +297,20 @@ def test_load_float_approx(conn, expr, pgtype, want, fmt_out):
     "val",
     [
         "0",
+        "-0",
         "0.0",
         "0.000000000000000000001",
         "-0.000000000000000000001",
         "nan",
+        "snan",
     ],
 )
-def test_roundtrip_numeric(conn, val):
-    cur = conn.cursor()
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_roundtrip_numeric(conn, val, fmt_in, fmt_out):
+    cur = conn.cursor(binary=fmt_out)
     val = Decimal(val)
-    cur.execute("select %s", (val,))
+    cur.execute(f"select %{fmt_in}", (val,))
     result = cur.fetchone()[0]
     assert isinstance(result, Decimal)
     if val.is_nan():
@@ -323,6 +327,7 @@ def test_roundtrip_numeric(conn, val):
         ("0.00000000000000001", b"1E-17"),
         ("-0.00000000000000001", b" -1E-17"),
         ("nan", b"'NaN'::numeric"),
+        ("snan", b"'NaN'::numeric"),
     ],
 )
 def test_quote_numeric(conn, val, expr):