]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Allow random testing with text format
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 22 Jun 2021 17:00:43 +0000 (18:00 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Jun 2021 15:16:26 +0000 (16:16 +0100)
Ranges are a bit complicated because upgrading empty ones only works in
text mode. This should be addressed, but hopefully it only affects
binary copy for now.

tests/fix_faker.py
tests/test_adapt.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/types/test_range.py

index fb1c7bf74442b33b8ec83d266354e5e13976f31b..9c20e3689ce2cf0b111ba2cd57bb0aed638925bd 100644 (file)
@@ -49,8 +49,6 @@ class Faker:
 
     @format.setter
     def format(self, format):
-        if format != Format.BINARY:
-            pytest.xfail("faker to extend to all text dumpers")
         self._format = format
 
     @property
@@ -83,20 +81,27 @@ class Faker:
 
         record = self.make_record(nulls=0)
         tx = psycopg3.adapt.Transformer(self.conn)
-        types = []
-        registry = self.conn.adapters.types
-        for value in record:
-            dumper = tx.get_dumper(value, self.format)
-            dumper.dump(value)  # load the oid if it's dynamic (e.g. array)
-            info = registry.get(dumper.oid) or registry.get("text")
-            if dumper.oid == info.array_oid:
-                types.append(sql.SQL("{}[]").format(sql.Identifier(info.name)))
-            else:
-                types.append(sql.Identifier(info.name))
-
+        types = [
+            self._get_type_name(tx, schema, value)
+            for schema, value in zip(self.schema, record)
+        ]
         self._types_names = types
         return types
 
+    def _get_type_name(self, tx, schema, value):
+        # Special case it as it is passed as unknown so is returned as text
+        if schema == (list, str):
+            return sql.SQL("text[]")
+
+        registry = self.conn.adapters.types
+        dumper = tx.get_dumper(value, self.format)
+        dumper.dump(value)  # load the oid if it's dynamic (e.g. array)
+        info = registry.get(dumper.oid) or registry.get("text")
+        if dumper.oid == info.array_oid:
+            return sql.SQL("{}[]").format(sql.Identifier(info.name))
+        else:
+            return sql.Identifier(info.name)
+
     @property
     def drop_stmt(self):
         return sql.SQL("drop table if exists {}").format(self.table_name)
@@ -138,7 +143,11 @@ class Faker:
         )
 
     def choose_schema(self, ncols=20):
-        schema = [self.make_schema(choice(self.types)) for i in range(ncols)]
+        schema = []
+        while len(schema) < ncols:
+            s = self.make_schema(choice(self.types))
+            if s is not None:
+                schema.append(s)
         return schema
 
     def make_records(self, nrecords):
@@ -184,6 +193,8 @@ class Faker:
         A schema for a type is represented by a tuple (type, ...) which the
         matching make_*() method can interpret, or just type if the type
         doesn't require further specification.
+
+        A `None` means that the type is not supported.
         """
         meth = self._get_method("schema", cls)
         return meth(cls) if meth else cls
@@ -321,6 +332,9 @@ class Faker:
     def make_Int8(self, spec):
         return spec(randrange(-(1 << 63), 1 << 63))
 
+    def make_IntNumeric(self, spec):
+        return spec(randrange(-(1 << 100), 1 << 100))
+
     def make_IPv4Address(self, spec):
         return ipaddress.IPv4Address(bytes(randrange(256) for _ in range(4)))
 
@@ -367,12 +381,15 @@ class Faker:
         )
 
     def schema_list(self, cls):
-        while 1:
+        while True:
             scls = choice(self.types)
-            if scls is not cls:
+            if scls is cls:
+                continue
+            schema = self.make_schema(scls)
+            if schema is not None:
                 break
 
-        return (cls, self.make_schema(scls))
+        return (cls, schema)
 
     def make_list(self, spec):
         # don't make empty lists because they regularly fail cast
@@ -389,6 +406,9 @@ class Faker:
     def make_memoryview(self, spec):
         return self.make_bytes(spec)
 
+    def schema_NoneType(self, spec):
+        return None
+
     def make_NoneType(self, spec):
         return None
 
@@ -397,37 +417,50 @@ class Faker:
 
     def schema_Range(self, cls):
         subtypes = [
-            Int4,
-            Int8,
             Decimal,
             dt.date,
             (dt.datetime, True),
             (dt.datetime, False),
         ]
+        # TODO: learn to dump numeric ranges in binary
+        if self.format != Format.BINARY:
+            subtypes.extend([Int4, Int8])
+
         return (cls, choice(subtypes))
 
     def make_Range(self, spec):
-        if random() < 0.02:
+        # TODO: drop format check after fixing binary dumping of empty ranges
+        if random() < 0.02 and self.format == Format.TEXT:
             return Range(empty=True)
 
-        bounds = []
-        while len(bounds) < 2:
-            if random() < 0.05:
-                bounds.append(None)
-                continue
+        while True:
+            bounds = []
+            while len(bounds) < 2:
+                if random() < 0.05:
+                    bounds.append(None)
+                    continue
 
-            val = self.make(spec[1])
-            # NaN are allowed in a range, but comparison in Python get tricky.
-            if spec[1] is Decimal and val.is_nan():
-                continue
+                val = self.make(spec[1])
+                # NaN are allowed in a range, but comparison in Python get tricky.
+                if spec[1] is Decimal and val.is_nan():
+                    continue
+
+                bounds.append(val)
 
-            bounds.append(val)
+            if bounds[0] is not None and bounds[1] is not None:
+                if bounds[0] > bounds[1]:
+                    bounds.reverse()
 
-        if bounds[0] is not None and bounds[1] is not None:
-            if bounds[0] > bounds[1]:
-                bounds.reverse()
+            # avoid generating ranges with no type info if dumping in binary
+            # TODO: lift this limitation after test_copy_in_empty xfail is fixed
+            if self.format == Format.BINARY:
+                if bounds[0] is bounds[1] is None:
+                    continue
 
-        return Range(bounds[0], bounds[1], choice("[(") + choice("])"))
+            break
+
+        r = Range(bounds[0], bounds[1], choice("[(") + choice("])"))
+        return r
 
     def match_Range(self, spec, got, want):
         # normalise the bounds of unbounded ranges
@@ -465,8 +498,23 @@ class Faker:
         return choice([dt.timedelta.min, dt.timedelta.max]) * random()
 
     def schema_tuple(self, cls):
-        length = randrange(1, self.tuple_max_length)
-        return (cls, self.choose_schema(ncols=length))
+        # TODO: this is a complicated matter as it would involve creating
+        # temporary composite types.
+        # length = randrange(1, self.tuple_max_length)
+        # return (cls, self.choose_schema(ncols=length))
+        return None
+
+    def make_tuple(self, spec):
+        return tuple(self.make(s) for s in spec[1])
+
+    def match_tuple(self, spec, got, want):
+        assert len(got) == len(want) == len(spec[1])
+        for g, w, s in zip(got, want, spec):
+            if g is None or w is None:
+                assert g is w
+            else:
+                m = self.get_matcher(s)
+                m(s, g, w)
 
     def make_UUID(self, spec):
         return UUID(bytes=bytes([randrange(256) for i in range(16)]))
index 69e05b0747640591c536d74c1951f54b1e3d100c..37452156818b2af6e4349593afe13d0c68b72f82 100644 (file)
@@ -345,15 +345,29 @@ def test_optimised_adapters():
 
 @pytest.mark.slow
 @pytest.mark.parametrize("fmt", [Format.AUTO, Format.TEXT, Format.BINARY])
-def test_random(conn, faker, fmt):
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_random(conn, faker, fmt, fmt_out):
     faker.format = fmt
     faker.choose_schema(ncols=20)
     faker.make_records(50)
 
-    with conn.cursor(binary=Format.as_pq(fmt)) as cur:
+    with conn.cursor(binary=fmt_out) as cur:
         cur.execute(faker.drop_stmt)
         cur.execute(faker.create_stmt)
-        cur.executemany(faker.insert_stmt, faker.records)
+        try:
+            cur.executemany(faker.insert_stmt, faker.records)
+        except psycopg3.DatabaseError:
+            # Insert one by one to find problematic values
+            conn.rollback()
+            cur.execute(faker.drop_stmt)
+            cur.execute(faker.create_stmt)
+            for rec in faker.records:
+                for i, val in enumerate(rec):
+                    cur.execute(faker.insert_field_stmt(i), (val,))
+
+            # just in case, but hopefully we should have triggered the problem
+            raise
+
         cur.execute(faker.select_stmt)
         recs = cur.fetchall()
 
index cb89f849998b19834177d2c632fd4b0ef8c221e6..e3caa85a1441bbb018149f34b3efe4e6f0b44205 100644 (file)
@@ -6,7 +6,7 @@ import datetime as dt
 import pytest
 
 import psycopg3
-from psycopg3 import sql, rows
+from psycopg3 import pq, sql, rows
 from psycopg3.oids import postgres_types as builtins
 from psycopg3.adapt import Format
 
@@ -538,11 +538,12 @@ def test_str(conn):
 
 @pytest.mark.slow
 @pytest.mark.parametrize("fmt", [Format.AUTO, Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
 @pytest.mark.parametrize(
     "row_factory", ["tuple_row", "dict_row", "namedtuple_row"]
 )
-def test_leak(dsn, faker, fmt, fetch, row_factory):
+def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
     faker.format = fmt
     faker.choose_schema(ncols=5)
     faker.make_records(10)
@@ -551,9 +552,7 @@ def test_leak(dsn, faker, fmt, fetch, row_factory):
     n = []
     for i in range(3):
         with psycopg3.connect(dsn) as conn:
-            with conn.cursor(
-                binary=fmt == Format.BINARY, row_factory=row_factory
-            ) as cur:
+            with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
                 cur.execute(faker.drop_stmt)
                 cur.execute(faker.create_stmt)
                 cur.executemany(faker.insert_stmt, faker.records)
index 0a507dda177a69e6167a521ac7814d2f83d5e3a9..c915231964130796a732accc7cb584ca519e8a22 100644 (file)
@@ -4,7 +4,7 @@ import weakref
 import datetime as dt
 
 import psycopg3
-from psycopg3 import sql, rows
+from psycopg3 import pq, sql, rows
 from psycopg3.adapt import Format
 
 from .utils import gc_collect
@@ -452,11 +452,12 @@ async def test_str(aconn):
 
 @pytest.mark.slow
 @pytest.mark.parametrize("fmt", [Format.AUTO, Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
 @pytest.mark.parametrize(
     "row_factory", ["tuple_row", "dict_row", "namedtuple_row"]
 )
-async def test_leak(dsn, faker, fmt, fetch, row_factory):
+async def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
     faker.format = fmt
     faker.choose_schema(ncols=5)
     faker.make_records(10)
@@ -466,7 +467,7 @@ async def test_leak(dsn, faker, fmt, fetch, row_factory):
     for i in range(3):
         async with await psycopg3.AsyncConnection.connect(dsn) as conn:
             async with conn.cursor(
-                binary=fmt == Format.BINARY, row_factory=row_factory
+                binary=fmt_out, row_factory=row_factory
             ) as cur:
                 await cur.execute(faker.drop_stmt)
                 await cur.execute(faker.create_stmt)
index 8f624b65617c2e2bcadcf7f4bbd3d84d86dc6826..884867a1e1e5b8cbd4c068c71a9d6c56063e3e57 100644 (file)
@@ -4,6 +4,7 @@ from decimal import Decimal
 
 import pytest
 
+import psycopg3.errors
 from psycopg3 import pq
 from psycopg3.sql import Identifier
 from psycopg3.adapt import Format
@@ -148,6 +149,45 @@ def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out):
     assert cur.fetchone()[0] == r
 
 
+@pytest.mark.parametrize(
+    "min, max, bounds",
+    [
+        ("2000,1,1", "2001,1,1", "[)"),
+        ("2000,1,1", None, "[)"),
+        (None, "2001,1,1", "()"),
+        (None, None, "()"),
+        (None, None, "empty"),
+    ],
+)
+@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY])
+def test_copy_in_empty(conn, min, max, bounds, format):
+    cur = conn.cursor()
+    cur.execute("create table copyrange (id serial primary key, r daterange)")
+
+    if bounds != "empty":
+        min = dt.date(*map(int, min.split(","))) if min else None
+        max = dt.date(*map(int, max.split(","))) if max else None
+        r = Range(min, max, bounds)
+    else:
+        r = Range(empty=True)
+
+    try:
+        with cur.copy(
+            f"copy copyrange (r) from stdin (format {format.name})"
+        ) as copy:
+            copy.write_row([r])
+    except psycopg3.errors.ProtocolViolation:
+        if not min and not max and format == pq.Format.BINARY:
+            pytest.xfail(
+                "TODO: add annotation to dump array with no type info"
+            )
+        else:
+            raise
+
+    rec = cur.execute("select r from copyrange order by id").fetchone()
+    assert rec[0] == r
+
+
 @pytest.fixture(scope="session")
 def testrange(svcconn):
     svcconn.execute(