import psycopg3
from psycopg3 import sql
from psycopg3.adapt import Format
+from psycopg3.types.range import Range
+from psycopg3.wrappers.numeric import Int4, Int8
@pytest.fixture
json_max_length = 10
str_max_length = 100
list_max_length = 20
+ tuple_max_length = 15
def __init__(self, connection):
self.conn = connection
self.records = []
self._schema = None
+ self._types = None
self._types_names = None
self._makers = {}
self.table_name = sql.Identifier("fake_table")
def fields_names(self):
return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))]
+ @property
+ def types(self):
+ if not self._types:
+ self._types = sorted(
+ self.get_supported_types(), key=lambda cls: cls.__name__
+ )
+ return self._types
+
@property
def types_names(self):
if self._types_names:
fields, self.table_name
)
- def choose_schema(self, types=None, ncols=20):
- if not types:
- types = self.get_supported_types()
-
- types_list = sorted(types, key=lambda cls: cls.__name__)
- schema = [choice(types_list) for i in range(ncols)]
- for i, cls in enumerate(schema):
- # choose the type of the array
- if cls is list:
- while 1:
- scls = choice(types_list)
- if scls is not list:
- break
- schema[i] = [scls]
- elif cls is tuple:
- schema[i] = tuple(self.choose_schema(types=types, ncols=ncols))
- # Pick timezone yes/no
- elif cls is dt.time:
- if choice([True, False]):
- schema[i] = TimeTz
- elif cls is dt.datetime:
- if choice([True, False]):
- schema[i] = DateTimeTz
-
+ def choose_schema(self, ncols=20):
+ schema = [self.make_schema(choice(self.types)) for i in range(ncols)]
return schema
def make_records(self, nrecords):
return rv
+ def make_schema(self, cls):
+ """Create a schema spec from a Python type.
+
+ A schema specifies what Postgres type to generate when a Python type
+ maps to more than one (e.g. tuple -> composite, list -> array[],
+ datetime -> timestamp[tz]).
+
+ A schema for a type is represented by a tuple (type, ...) which the
+ matching make_*() method can interpret, or just type if the type
+ doesn't require further specification.
+ """
+ meth = self._get_method("schema", cls)
+ return meth(cls) if meth else cls
+
def get_maker(self, spec):
- # convert a list or tuple into list or tuple
- cls = spec if isinstance(spec, type) else type(spec)
+ cls = spec if isinstance(spec, type) else spec[0]
try:
return self._makers[cls]
)
def get_matcher(self, spec):
- # convert a list or tuple into list or tuple
- cls = spec if isinstance(spec, type) else type(spec)
+ cls = spec if isinstance(spec, type) else spec[0]
meth = self._get_method("match", cls)
return meth if meth else self.match_any
return None
def make(self, spec):
- # spec can be a type or a list [type] or a tuple (spec, spec, ...)
+ # spec can be a type or a tuple (type, options)
return self.get_maker(spec)(spec)
def match_any(self, spec, got, want):
day = randrange(dt.date.max.toordinal())
return dt.date.fromordinal(day + 1)
+ def schema_datetime(self, cls):
+ return self.schema_time(cls)
+
def make_datetime(self, spec):
delta = dt.datetime.max - dt.datetime.min
micros = randrange((delta.days + 1) * 24 * 60 * 60 * 1_000_000)
- return dt.datetime.min + dt.timedelta(microseconds=micros)
-
- def make_DateTimeTz(self, spec):
- rv = self.make_datetime(spec)
- return rv.replace(tzinfo=self._make_tz(spec))
+ rv = dt.datetime.min + dt.timedelta(microseconds=micros)
+ if spec[1]:
+ rv = rv.replace(tzinfo=self._make_tz(spec))
+ return rv
def make_Decimal(self, spec):
if random() >= 0.99:
f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}"
)
+ def schema_list(self, cls):
+ while 1:
+ scls = choice(self.types)
+ if scls is not cls:
+ break
+
+ return (cls, self.make_schema(scls))
+
def make_list(self, spec):
# don't make empty lists because they regularly fail cast
length = randrange(1, self.list_max_length)
- spec = spec[0]
+ spec = spec[1]
return [self.make(spec) for i in range(length)]
def match_list(self, spec, got, want):
assert len(got) == len(want)
- m = self.get_matcher(spec[0])
+ m = self.get_matcher(spec[1])
for g, w in zip(got, want):
m(spec, g, w)
def make_Oid(self, spec):
return spec(randrange(1 << 32))
+ def schema_Range(self, cls):
+ subtypes = [
+ Int4,
+ Int8,
+ Decimal,
+ dt.date,
+ (dt.datetime, True),
+ (dt.datetime, False),
+ ]
+ return (cls, choice(subtypes))
+
+ def make_Range(self, spec):
+ if random() < 0.02:
+ return Range(empty=True)
+
+ bounds = []
+ while len(bounds) < 2:
+ if random() < 0.05:
+ bounds.append(None)
+ continue
+
+ val = self.make(spec[1])
+ # NaN are allowed in a range, but comparison in Python get tricky.
+ if spec[1] is Decimal and val.is_nan():
+ continue
+
+ bounds.append(val)
+
+ if bounds[0] is not None and bounds[1] is not None:
+ if bounds[0] > bounds[1]:
+ bounds.reverse()
+
+ return Range(bounds[0], bounds[1], choice("[(") + choice("])"))
+
+ def match_Range(self, spec, got, want):
+ # normalise the bounds of unbounded ranges
+ if want.lower is None and want.lower_inc:
+ want = type(want)(want.lower, want.upper, "(" + want.bounds[1])
+ if want.upper is None and want.upper_inc:
+ want = type(want)(want.lower, want.upper, want.bounds[0] + ")")
+ return got == want
+
def make_str(self, spec, length=0):
if not length:
length = randrange(self.str_max_length)
return "".join(map(chr, rv))
+ def schema_time(self, cls):
+ # Choose timezone yes/no
+ return (cls, choice([True, False]))
+
def make_time(self, spec):
val = randrange(24 * 60 * 60 * 1_000_000)
val, ms = divmod(val, 1_000_000)
val, s = divmod(val, 60)
h, m = divmod(val, 60)
- return dt.time(h, m, s, ms)
+ tz = self._make_tz(spec) if spec[1] else None
+ return dt.time(h, m, s, ms, tz)
def make_timedelta(self, spec):
return choice([dt.timedelta.min, dt.timedelta.max]) * random()
- def make_TimeTz(self, spec):
- rv = self.make_time(spec)
- return rv.replace(tzinfo=self._make_tz(spec))
+ def schema_tuple(self, cls):
+ length = randrange(1, self.tuple_max_length)
+ return (cls, self.choose_schema(ncols=length))
def make_UUID(self, spec):
return UUID(bytes=bytes([randrange(256) for i in range(16)]))
pass
-class TimeTz:
- """
- Placeholder to create time objects with tzinfo.
- """
-
-
-class DateTimeTz:
- """
- Placeholder to create datetime objects with tzinfo.
- """
-
-
def deep_import(name):
parts = deque(name.split("."))
seen = []