]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(numpy): consolidate numpy float tests
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Dec 2022 20:23:28 +0000 (20:23 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Aug 2023 14:21:30 +0000 (15:21 +0100)
tests/types/test_numpy.py

index 52dcb7f6fac6836134e02c67d74b630e351a31a2..0d283392ceae032527fea965bab27ad8c72bfeed 100644 (file)
@@ -1,3 +1,5 @@
+from math import isnan
+
 import pytest
 from psycopg.adapt import PyFormat
 
@@ -25,6 +27,28 @@ def test_classes_identities():
     assert np.int_ is np.int64
 
 
+@pytest.mark.parametrize(
+    "name, equiv",
+    [
+        ("inf", "inf"),
+        ("infty", "inf"),
+        ("NINF", "-inf"),
+        ("nan", "nan"),
+        ("NaN", "nan"),
+        ("NAN", "nan"),
+        ("PZERO", "0.0"),
+        ("NZERO", "-0.0"),
+    ],
+)
+def test_special_values(name, equiv):
+    obj = getattr(np, name)
+    assert isinstance(obj, float)
+    if equiv == "nan":
+        assert isnan(obj)
+    else:
+        assert obj == float(equiv)
+
+
 @pytest.mark.parametrize(
     "nptype, val, expr",
     [
@@ -69,80 +93,40 @@ def test_dump_int(conn, val, nptype, expr, fmt_in):
     assert cur.fetchone()[0] is True
 
 
-# Test float special values
 @pytest.mark.parametrize(
-    "val, expr",
+    "nptype, val, pgtype",
     [
-        (np.PZERO, "'0.0'::float8"),
-        (np.NZERO, "'-0.0'::float8"),
-        (np.nan, "'NaN'::float8"),
-        (np.inf, "'Infinity'::float8"),
-        (np.NINF, "'-Infinity'::float8"),
+        ("float16", "4e4", "float4"),
+        ("float16", "4e-4", "float4"),
+        ("float16", "4000.0", "float4"),
+        ("float16", "3.14", "float4"),
+        ("float32", "256e6", "float4"),
+        ("float32", "256e-6", "float4"),
+        ("float32", "2.7182817", "float4"),
+        ("float32", "3.1415927", "float4"),
+        ("float64", "256e12", "float8"),
+        ("float64", "256e-12", "float8"),
+        ("float64", "2.718281828459045", "float8"),
+        ("float64", "3.141592653589793", "float8"),
     ],
 )
 @pytest.mark.parametrize("fmt_in", PyFormat)
-def test_dump_special_values(conn, val, expr, fmt_in):
-
-    if val == np.nan:
-        assert np.nan == np.NAN == np.NaN
-
-    if val == np.inf:
-        assert np.inf == np.Inf == np.PINF == np.infty
-
-    assert isinstance(val, float)
-
-    cur = conn.cursor()
-    cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
-
-    assert cur.fetchone()[0] is True
-
-    cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
-    assert cur.fetchone()[0] is True
-
-
-@pytest.mark.parametrize("val", ["4e4", "4e-4", "4000.0", "3.14"])
-@pytest.mark.parametrize("fmt_in", PyFormat)
-def test_dump_numpy_float16(conn, val, fmt_in):
-
-    val = np.float16(val)
+def test_dump_float(conn, nptype, val, pgtype, fmt_in):
+    nptype = getattr(np, nptype)
+    val = nptype(val)
     cur = conn.cursor()
 
-    cur.execute(f"select pg_typeof({val}::float4) = pg_typeof(%{fmt_in.value})", (val,))
+    cur.execute(
+        f"select pg_typeof('{val}'::{pgtype}) = pg_typeof(%{fmt_in.value})", (val,)
+    )
     assert cur.fetchone()[0] is True
 
-    cur.execute(f"select {val}::float4, %(obj){fmt_in.value}", {"obj": val})
+    cur.execute(f"select '{val}'::{pgtype}, %(obj){fmt_in.value}", {"obj": val})
     rec = cur.fetchone()
-    assert rec[0] == pytest.approx(rec[1], 1e-3)
-
-
-@pytest.mark.parametrize("val", ["256e6", "256e-6", "2.7182817", "3.1415927"])
-@pytest.mark.parametrize("fmt_in", PyFormat)
-def test_dump_numpy_float32(conn, val, fmt_in):
-
-    val = np.float32(val)
-    cur = conn.cursor()
-
-    cur.execute(f"select pg_typeof({val}::float4) = pg_typeof(%{fmt_in.value})", (val,))
-    assert cur.fetchone()[0] is True
-
-    cur.execute(f"select {val}::float4 = %{fmt_in.value}", (val,))
-    assert cur.fetchone()[0] is True
-
-
-@pytest.mark.parametrize(
-    "val", ["256e12", "256e-12", "2.718281828459045", "3.141592653589793"]
-)
-@pytest.mark.parametrize("fmt_in", PyFormat)
-def test_dump_numpy_float64(conn, val, fmt_in):
-
-    val = np.float64(val)
-    cur = conn.cursor()
-
-    cur.execute(f"select pg_typeof({val}::float8) = pg_typeof(%{fmt_in.value})", (val,))
-    assert cur.fetchone()[0] is True
-
-    cur.execute(f"select {val}::float8 = %{fmt_in.value}", (val,))
-    assert cur.fetchone()[0] is True
+    if nptype is np.float16:
+        assert rec[0] == pytest.approx(rec[1], 1e-3)
+    else:
+        assert rec[0] == rec[1]
 
 
 @pytest.mark.slow