From: Daniele Varrazzo Date: Thu, 6 May 2021 02:26:06 +0000 (+0200) Subject: Handle Decimal sNaN value X-Git-Tag: 3.0.dev0~48^2~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b111cf9480088fe75ccee6e296f3496bfcd6974d;p=thirdparty%2Fpsycopg.git Handle Decimal sNaN value --- diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index e07025037..fb9b6fea2 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -88,6 +88,13 @@ class DecimalDumper(SpecialValuesDumper): _oid = builtins["numeric"].oid + def dump(self, obj: Decimal) -> bytes: + if obj.is_nan(): + # cover NaN and sNaN + return b"NaN" + else: + return str(obj).encode("utf8") + _special = { b"Infinity": b"'Infinity'::numeric", b"-Infinity": b"'-Infinity'::numeric", @@ -355,7 +362,7 @@ class DecimalBinaryDumper(Dumper): def dump(self, obj: Decimal) -> Union[bytearray, bytes]: sign, digits, exp = obj.as_tuple() - if exp == "n": # type: ignore[comparison-overlap] + if exp == "n" or exp == "N": # type: ignore[comparison-overlap] return NUMERIC_NAN_BIN elif exp == "F": # type: ignore[comparison-overlap] return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN diff --git a/tests/fix_faker.py b/tests/fix_faker.py index 2555047b1..b5b6f20e5 100644 --- a/tests/fix_faker.py +++ b/tests/fix_faker.py @@ -229,9 +229,9 @@ class Faker: def make_Decimal(self, spec): if random() >= 0.99: if self.conn.info.server_version >= 140000: - return Decimal(choice(["NaN", "Inf", "-Inf"])) + return Decimal(choice(["NaN", "sNaN", "Inf", "-Inf"])) else: - return Decimal("NaN") + return Decimal(choice(["NaN", "sNaN"])) sign = choice("+-") num = choice(["0.zd", "d", "d.d"]) diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py index 9a6be1694..3f6f93f82 100644 --- a/tests/types/test_numeric.py +++ b/tests/types/test_numeric.py @@ -297,16 +297,20 @@ def test_load_float_approx(conn, expr, pgtype, want, fmt_out): "val", [ "0", + "-0", "0.0", "0.000000000000000000001", "-0.000000000000000000001", "nan", + "snan", ], ) -def test_roundtrip_numeric(conn, val): - cur = conn.cursor() +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_roundtrip_numeric(conn, val, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) val = Decimal(val) - cur.execute("select %s", (val,)) + cur.execute(f"select %{fmt_in}", (val,)) result = cur.fetchone()[0] assert isinstance(result, Decimal) if val.is_nan(): @@ -323,6 +327,7 @@ def test_roundtrip_numeric(conn, val): ("0.00000000000000001", b"1E-17"), ("-0.00000000000000001", b" -1E-17"), ("nan", b"'NaN'::numeric"), + ("snan", b"'NaN'::numeric"), ], ) def test_quote_numeric(conn, val, expr):