]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(numpy): add random tests with numpy objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Dec 2022 19:11:19 +0000 (19:11 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Aug 2023 14:21:30 +0000 (15:21 +0100)
tests/fix_faker.py
tests/types/test_numpy.py

index 6d42ff8a5ecaea47a293f97f51ff73da29145d9c..95012fcd228dc93a66d469dff97462d7c5bab6ff 100644 (file)
@@ -4,7 +4,7 @@ import ipaddress
 from math import isnan
 from uuid import UUID
 from random import choice, random, randrange
-from typing import Any, List, Set, Tuple, Union
+from typing import Any, List, Optional, Set, Tuple, Union
 from decimal import Decimal
 from contextlib import contextmanager, asynccontextmanager
 
@@ -42,7 +42,7 @@ class Faker:
         self.records = []
 
         self._schema = None
-        self._types = None
+        self._types: Optional[List[type]] = None
         self._types_names = None
         self._makers = {}
         self.table_name = sql.Identifier("fake_table")
@@ -63,15 +63,20 @@ class Faker:
         return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))]
 
     @property
-    def types(self):
+    def types(self) -> List[type]:
         if not self._types:
 
             def key(cls: type) -> str:
-                return cls.__name__
+                return f"{cls.__module__}.{cls.__qualname__}"
 
             self._types = sorted(self.get_supported_types(), key=key)
+
         return self._types
 
+    @types.setter
+    def types(self, types: List[type]) -> None:
+        self._types = types
+
     @property
     def types_names_sql(self):
         if self._types_names:
@@ -280,7 +285,7 @@ class Faker:
             name = f"{cls.__module__}.{name}"
 
         parts = name.split(".")
-        for i in range(len(parts)):
+        for i in range(len(parts) - 1, -1, -1):
             mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}"
             meth = getattr(self, mname, None)
             if meth:
@@ -313,7 +318,7 @@ class Faker:
         return want.obj == got
 
     def make_bool(self, spec):
-        return choice((True, False))
+        return spec(choice((True, False)))
 
     def make_bytearray(self, spec):
         return self.make_bytes(spec)
@@ -394,13 +399,15 @@ class Faker:
     def make_float(self, spec, double=True):
         if random() <= 0.99:
             # These exponents should generate no inf
-            return float(
+            return spec(
                 f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}"
                 if double
                 else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}"
             )
         else:
-            return choice((0.0, -0.0, float("-inf"), float("inf"), float("nan")))
+            return choice(
+                (spec(0.0), spec(-0.0), spec("-inf"), spec("inf"), spec("nan"))
+            )
 
     def match_float(self, spec, got, want, approx=False, rel=None):
         if got is not None and isnan(got):
@@ -846,6 +853,69 @@ class Faker:
         minutes = randrange(-12 * 60, 12 * 60 + 1)
         return dt.timezone(dt.timedelta(minutes=minutes))
 
+    # numpy types support
+
+    def make_numpy_bool_(self, spec):
+        return self.make_bool(spec)
+
+    def make_numpy_int8(self, spec):
+        return spec(randrange(-(1 << 7), 1 << 7))
+
+    def make_numpy_int16(self, spec):
+        return self.make_Int2(spec)
+
+    def make_numpy_int32(self, spec):
+        return self.make_Int4(spec)
+
+    def make_numpy_int64(self, spec):
+        return self.make_Int8(spec)
+
+    def make_numpy_longlong(self, spec):
+        return self.make_numpy_int64(spec)
+
+    def make_numpy_uint8(self, spec):
+        return spec(randrange(0, 1 << 8))
+
+    def make_numpy_uint16(self, spec):
+        return spec(randrange(0, 1 << 16))
+
+    def make_numpy_uint32(self, spec):
+        return spec(randrange(0, 1 << 32))
+
+    def make_numpy_uint64(self, spec):
+        return spec(randrange(0, 1 << 64))
+
+    def make_numpy_ulonglong(self, spec):
+        return self.make_numpy_uint64(spec)
+
+    def make_numpy_float16(self, spec):
+        return self.make_Float4(spec)
+
+    def make_numpy_float32(self, spec):
+        return self.make_Float4(spec)
+
+    def make_numpy_float64(self, spec):
+        return self.make_Float8(spec)
+
+    def match_numpy_ulonglong(self, spec, got, want):
+        return self._match_numpy_with_decimal(spec, got, want)
+
+    def match_numpy_uint64(self, spec, got, want):
+        return self._match_numpy_with_decimal(spec, got, want)
+
+    def _match_numpy_with_decimal(self, spec, got, want):
+        assert isinstance(got, Decimal)
+        return self.match_any(spec, int(got), want)
+
+    def match_numpy_float16(self, spec, got, want):
+        return self.match_numpy_float32(spec, got, want)
+
+    def match_numpy_float32(self, spec, got, want):
+        return self.match_Float4(spec, got, want)
+
+    def match_numpy_float64(self, spec, got, want):
+        return self.match_Float8(spec, got, want)
+
 
 class JsonFloat:
     pass
index c0eb8a606d2c147c2e1e080b6ca92bdb56dfd105..52dcb7f6fac6836134e02c67d74b630e351a31a2 100644 (file)
@@ -143,3 +143,24 @@ def test_dump_numpy_float64(conn, val, fmt_in):
 
     cur.execute(f"select {val}::float8 = %{fmt_in.value}", (val,))
     assert cur.fetchone()[0] is True
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", PyFormat)
+def test_random(conn, faker, fmt):
+    faker.types = [t for t in faker.types if issubclass(t, np.generic)]
+    faker.format = fmt
+    faker.choose_schema(ncols=20)
+    faker.make_records(50)
+
+    with conn.cursor() as cur:
+        cur.execute(faker.drop_stmt)
+        cur.execute(faker.create_stmt)
+        with faker.find_insert_problem(conn):
+            cur.executemany(faker.insert_stmt, faker.records)
+
+        cur.execute(faker.select_stmt)
+        recs = cur.fetchall()
+
+    for got, want in zip(recs, faker.records):
+        faker.assert_record(got, want)