]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Test random range objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 9 Jun 2021 15:44:34 +0000 (16:44 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Jun 2021 15:16:26 +0000 (16:16 +0100)
tests/fix_faker.py

index fb83a8080e2d295230b81e21f922ee8668515d8a..fb1c7bf74442b33b8ec83d266354e5e13976f31b 100644 (file)
@@ -12,6 +12,8 @@ import pytest
 import psycopg3
 from psycopg3 import sql
 from psycopg3.adapt import Format
+from psycopg3.types.range import Range
+from psycopg3.wrappers.numeric import Int4, Int8
 
 
 @pytest.fixture
@@ -28,6 +30,7 @@ class Faker:
     json_max_length = 10
     str_max_length = 100
     list_max_length = 20
+    tuple_max_length = 15
 
     def __init__(self, connection):
         self.conn = connection
@@ -35,6 +38,7 @@ class Faker:
         self.records = []
 
         self._schema = None
+        self._types = None
         self._types_names = None
         self._makers = {}
         self.table_name = sql.Identifier("fake_table")
@@ -64,6 +68,14 @@ class Faker:
     def fields_names(self):
         return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))]
 
+    @property
+    def types(self):
+        if not self._types:
+            self._types = sorted(
+                self.get_supported_types(), key=lambda cls: cls.__name__
+            )
+        return self._types
+
     @property
     def types_names(self):
         if self._types_names:
@@ -125,30 +137,8 @@ class Faker:
             fields, self.table_name
         )
 
-    def choose_schema(self, types=None, ncols=20):
-        if not types:
-            types = self.get_supported_types()
-
-        types_list = sorted(types, key=lambda cls: cls.__name__)
-        schema = [choice(types_list) for i in range(ncols)]
-        for i, cls in enumerate(schema):
-            # choose the type of the array
-            if cls is list:
-                while 1:
-                    scls = choice(types_list)
-                    if scls is not list:
-                        break
-                schema[i] = [scls]
-            elif cls is tuple:
-                schema[i] = tuple(self.choose_schema(types=types, ncols=ncols))
-            # Pick timezone yes/no
-            elif cls is dt.time:
-                if choice([True, False]):
-                    schema[i] = TimeTz
-            elif cls is dt.datetime:
-                if choice([True, False]):
-                    schema[i] = DateTimeTz
-
+    def choose_schema(self, ncols=20):
+        schema = [self.make_schema(choice(self.types)) for i in range(ncols)]
         return schema
 
     def make_records(self, nrecords):
@@ -184,9 +174,22 @@ class Faker:
 
         return rv
 
+    def make_schema(self, cls):
+        """Create a schema spec from a Python type.
+
+        A schema specifies what Postgres type to generate when a Python type
+        maps to more than one (e.g. tuple -> composite, list -> array[],
+        datetime -> timestamp[tz]).
+
+        A schema for a type is represented by a tuple (type, ...) which the
+        matching make_*() method can interpret, or just type if the type
+        doesn't require further specification.
+        """
+        meth = self._get_method("schema", cls)
+        return meth(cls) if meth else cls
+
     def get_maker(self, spec):
-        # convert a list or tuple into list or tuple
-        cls = spec if isinstance(spec, type) else type(spec)
+        cls = spec if isinstance(spec, type) else spec[0]
 
         try:
             return self._makers[cls]
@@ -203,8 +206,7 @@ class Faker:
             )
 
     def get_matcher(self, spec):
-        # convert a list or tuple into list or tuple
-        cls = spec if isinstance(spec, type) else type(spec)
+        cls = spec if isinstance(spec, type) else spec[0]
         meth = self._get_method("match", cls)
         return meth if meth else self.match_any
 
@@ -223,7 +225,7 @@ class Faker:
         return None
 
     def make(self, spec):
-        # spec can be a type or a list [type] or a tuple (spec, spec, ...)
+        # spec can be a type or a tuple (type, options)
         return self.get_maker(spec)(spec)
 
     def match_any(self, spec, got, want):
@@ -251,14 +253,16 @@ class Faker:
         day = randrange(dt.date.max.toordinal())
         return dt.date.fromordinal(day + 1)
 
+    def schema_datetime(self, cls):
+        return self.schema_time(cls)
+
     def make_datetime(self, spec):
         delta = dt.datetime.max - dt.datetime.min
         micros = randrange((delta.days + 1) * 24 * 60 * 60 * 1_000_000)
-        return dt.datetime.min + dt.timedelta(microseconds=micros)
-
-    def make_DateTimeTz(self, spec):
-        rv = self.make_datetime(spec)
-        return rv.replace(tzinfo=self._make_tz(spec))
+        rv = dt.datetime.min + dt.timedelta(microseconds=micros)
+        if spec[1]:
+            rv = rv.replace(tzinfo=self._make_tz(spec))
+        return rv
 
     def make_Decimal(self, spec):
         if random() >= 0.99:
@@ -362,15 +366,23 @@ class Faker:
             f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}"
         )
 
+    def schema_list(self, cls):
+        while 1:
+            scls = choice(self.types)
+            if scls is not cls:
+                break
+
+        return (cls, self.make_schema(scls))
+
     def make_list(self, spec):
         # don't make empty lists because they regularly fail cast
         length = randrange(1, self.list_max_length)
-        spec = spec[0]
+        spec = spec[1]
         return [self.make(spec) for i in range(length)]
 
     def match_list(self, spec, got, want):
         assert len(got) == len(want)
-        m = self.get_matcher(spec[0])
+        m = self.get_matcher(spec[1])
         for g, w in zip(got, want):
             m(spec, g, w)
 
@@ -383,6 +395,48 @@ class Faker:
     def make_Oid(self, spec):
         return spec(randrange(1 << 32))
 
+    def schema_Range(self, cls):
+        subtypes = [
+            Int4,
+            Int8,
+            Decimal,
+            dt.date,
+            (dt.datetime, True),
+            (dt.datetime, False),
+        ]
+        return (cls, choice(subtypes))
+
+    def make_Range(self, spec):
+        if random() < 0.02:
+            return Range(empty=True)
+
+        bounds = []
+        while len(bounds) < 2:
+            if random() < 0.05:
+                bounds.append(None)
+                continue
+
+            val = self.make(spec[1])
+            # NaN are allowed in a range, but comparison in Python get tricky.
+            if spec[1] is Decimal and val.is_nan():
+                continue
+
+            bounds.append(val)
+
+        if bounds[0] is not None and bounds[1] is not None:
+            if bounds[0] > bounds[1]:
+                bounds.reverse()
+
+        return Range(bounds[0], bounds[1], choice("[(") + choice("])"))
+
+    def match_Range(self, spec, got, want):
+        # normalise the bounds of unbounded ranges
+        if want.lower is None and want.lower_inc:
+            want = type(want)(want.lower, want.upper, "(" + want.bounds[1])
+        if want.upper is None and want.upper_inc:
+            want = type(want)(want.lower, want.upper, want.bounds[0] + ")")
+        return got == want
+
     def make_str(self, spec, length=0):
         if not length:
             length = randrange(self.str_max_length)
@@ -395,19 +449,24 @@ class Faker:
 
         return "".join(map(chr, rv))
 
+    def schema_time(self, cls):
+        # Choose timezone yes/no
+        return (cls, choice([True, False]))
+
     def make_time(self, spec):
         val = randrange(24 * 60 * 60 * 1_000_000)
         val, ms = divmod(val, 1_000_000)
         val, s = divmod(val, 60)
         h, m = divmod(val, 60)
-        return dt.time(h, m, s, ms)
+        tz = self._make_tz(spec) if spec[1] else None
+        return dt.time(h, m, s, ms, tz)
 
     def make_timedelta(self, spec):
         return choice([dt.timedelta.min, dt.timedelta.max]) * random()
 
-    def make_TimeTz(self, spec):
-        rv = self.make_time(spec)
-        return rv.replace(tzinfo=self._make_tz(spec))
+    def schema_tuple(self, cls):
+        length = randrange(1, self.tuple_max_length)
+        return (cls, self.choose_schema(ncols=length))
 
     def make_UUID(self, spec):
         return UUID(bytes=bytes([randrange(256) for i in range(16)]))
@@ -445,18 +504,6 @@ class JsonFloat:
     pass
 
 
-class TimeTz:
-    """
-    Placeholder to create time objects with tzinfo.
-    """
-
-
-class DateTimeTz:
-    """
-    Placeholder to create datetime objects with tzinfo.
-    """
-
-
 def deep_import(name):
     parts = deque(name.split("."))
     seen = []