]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(crdb): fix random tests
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 6 Jun 2022 21:57:31 +0000 (23:57 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
- Keep into account the different supported interval range
- Allow for float representation rounding (pre-PG 12 behaviour)
- Run leak tests in transactions to allow to recreate the tables

tests/fix_faker.py
tests/test_client_cursor.py
tests/test_client_cursor_async.py
tests/test_cursor.py
tests/test_cursor_async.py

index d1603219c65061beb31b500c0f05f77ad8323d30..b076b374d37a04bfb271d8fa9f91e75025fcef82 100644 (file)
@@ -137,9 +137,9 @@ class Faker:
     def find_insert_problem(self, conn):
         """Context manager to help finding a problematic value."""
         try:
-            yield
+            with conn.transaction():
+                yield
         except psycopg.DatabaseError:
-            conn.rollback()
             cur = conn.cursor()
             # Repeat insert one field at time, until finding the wrong one
             cur.execute(self.drop_stmt)
@@ -162,9 +162,9 @@ class Faker:
     @asynccontextmanager
     async def find_insert_problem_async(self, aconn):
         try:
-            yield
+            async with aconn.transaction():
+                yield
         except psycopg.DatabaseError:
-            await aconn.rollback()
             acur = aconn.cursor()
             # Repeat insert one field at time, until finding the wrong one
             await acur.execute(self.drop_stmt)
@@ -390,14 +390,21 @@ class Faker:
     def match_float(self, spec, got, want, approx=False, rel=None):
         if got is not None and isnan(got):
             assert isnan(want)
+        else:
+            if approx or self._server_rounds():
+                assert got == pytest.approx(want, rel=rel)
+            else:
+                assert got == want
+
+    def _server_rounds(self):
+        """Return True if the connected server perform float rounding"""
+        if self.conn.info.vendor == "CockroachDB":
+            return True
         else:
             # Versions older than 12 make some rounding. e.g. in Postgres 10.4
             # select '-1.409006204063909e+112'::float8
             #      -> -1.40900620406391e+112
-            if not approx and self.conn.info.server_version >= 120000:
-                assert got == want
-            else:
-                assert got == pytest.approx(want, rel=rel)
+            return self.conn.info.server_version < 120000
 
     def make_Float4(self, spec):
         return spec(self.make_float(spec, double=False))
@@ -759,8 +766,15 @@ class Faker:
         tz = self._make_tz(spec) if spec[1] else None
         return dt.time(h, m, s, ms, tz)
 
+    CRDB_TIMEDELTA_MAX = dt.timedelta(days=1281239)
+
     def make_timedelta(self, spec):
-        return choice([dt.timedelta.min, dt.timedelta.max]) * random()
+        if self.conn.info.vendor == "CockroachDB":
+            rng = [-self.CRDB_TIMEDELTA_MAX, self.CRDB_TIMEDELTA_MAX]
+        else:
+            rng = [dt.timedelta.min, dt.timedelta.max]
+
+        return choice(rng) * random()
 
     def schema_tuple(self, cls):
         # TODO: this is a complicated matter as it would involve creating
index cc6944233f532ae8969d0f59987df8b5e4edd309..bc28df41f03a5ae49c6e40d78ba49c634cb73240 100644 (file)
@@ -761,7 +761,7 @@ def test_leak(dsn, faker, fetch, row_factory):
     row_factory = getattr(rows, row_factory)
 
     def work():
-        with psycopg.connect(dsn) as conn:
+        with psycopg.connect(dsn) as conn, conn.transaction(force_rollback=True):
             with psycopg.ClientCursor(conn, row_factory=row_factory) as cur:
                 cur.execute(faker.drop_stmt)
                 cur.execute(faker.create_stmt)
index 20393a617acdd358f1d97a006257f106985345e1..56d744fd42069b865d6875d00ec1ddf654dbc816 100644 (file)
@@ -631,7 +631,9 @@ async def test_leak(dsn, faker, fetch, row_factory):
     row_factory = getattr(rows, row_factory)
 
     async def work():
-        async with await psycopg.AsyncConnection.connect(dsn) as conn:
+        async with await psycopg.AsyncConnection.connect(dsn) as conn, conn.transaction(
+            force_rollback=True
+        ):
             async with psycopg.AsyncClientCursor(conn, row_factory=row_factory) as cur:
                 await cur.execute(faker.drop_stmt)
                 await cur.execute(faker.create_stmt)
index 29ee2ebbd6486fa404d58dea66e151f3f9b6b3b3..18b561c7a82b84df107b228bee7841ae6472aad9 100644 (file)
@@ -840,7 +840,7 @@ def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
     row_factory = getattr(rows, row_factory)
 
     def work():
-        with psycopg.connect(dsn) as conn:
+        with psycopg.connect(dsn) as conn, conn.transaction(force_rollback=True):
             with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
                 cur.execute(faker.drop_stmt)
                 cur.execute(faker.create_stmt)
index 2161a30d73175b5c669e8672ce58d4f9e6231db7..c2519edbe3941f2ae31b82638a90fdeb081c6f1f 100644 (file)
@@ -710,7 +710,9 @@ async def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
     row_factory = getattr(rows, row_factory)
 
     async def work():
-        async with await psycopg.AsyncConnection.connect(dsn) as conn:
+        async with await psycopg.AsyncConnection.connect(dsn) as conn, conn.transaction(
+            force_rollback=True
+        ):
             async with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
                 await cur.execute(faker.drop_stmt)
                 await cur.execute(faker.create_stmt)