]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add faker context manager to help found problematic values
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 19 Jul 2021 16:44:41 +0000 (18:44 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 19 Jul 2021 18:25:23 +0000 (20:25 +0200)
tests/fix_faker.py
tests/test_adapt.py
tests/test_copy.py
tests/test_copy_async.py
tests/test_cursor.py
tests/test_cursor_async.py

index d04e8669955a3b1c7a03313b650d9b6a73148ebe..45cccb6aaab4b1b0565d1b192d7ceb8c7f64f00f 100644 (file)
@@ -5,6 +5,7 @@ from math import isnan
 from uuid import UUID
 from random import choice, random, randrange
 from decimal import Decimal
+from contextlib import contextmanager
 from collections import deque
 
 import pytest
@@ -12,6 +13,7 @@ import pytest
 import psycopg
 from psycopg import sql
 from psycopg.adapt import PyFormat
+from psycopg.compat import asynccontextmanager
 from psycopg.types.range import Range
 from psycopg.types.numeric import Int4, Int8
 
@@ -129,12 +131,6 @@ class Faker:
             sql.SQL(", ").join(phs),
         )
 
-    def insert_field_stmt(self, i):
-        ph = sql.Placeholder(format=self.format)
-        return sql.SQL("insert into {} ({}) values ({})").format(
-            self.table_name, self.fields_names[i], ph
-        )
-
     @property
     def select_stmt(self):
         fields = sql.SQL(", ").join(self.fields_names)
@@ -142,6 +138,65 @@ class Faker:
             fields, self.table_name
         )
 
+    @contextmanager
+    def find_insert_problem(self, conn):
+        """Context manager to help finding a problematic vaule."""
+        try:
+            yield
+        except psycopg.DatabaseError:
+            conn.rollback()
+            cur = conn.cursor()
+            # Repeat insert one field at time, until finding the wrong one
+            cur.execute(self.drop_stmt)
+            cur.execute(self.create_stmt)
+            for i, rec in enumerate(self.records):
+                for j, val in enumerate(rec):
+                    try:
+                        cur.execute(self._insert_field_stmt(j), (val,))
+                    except psycopg.DatabaseError as e:
+                        r = repr(val)
+                        if len(r) > 200:
+                            r = f"{r[:200]}... ({len(r)} chars)"
+                        raise Exception(
+                            f"value {r!r} at record {i} column0 {j}"
+                            f" failed insert: {e}"
+                        ) from None
+
+            # just in case, but hopefully we should have triggered the problem
+            raise
+
+    @asynccontextmanager
+    async def find_insert_problem_async(self, aconn):
+        try:
+            yield
+        except psycopg.DatabaseError:
+            await aconn.rollback()
+            acur = aconn.cursor()
+            # Repeat insert one field at time, until finding the wrong one
+            await acur.execute(self.drop_stmt)
+            await acur.execute(self.create_stmt)
+            for i, rec in enumerate(self.records):
+                for j, val in enumerate(rec):
+                    try:
+                        await acur.execute(self._insert_field_stmt(j), (val,))
+                    except psycopg.DatabaseError as e:
+                        r = repr(val)
+                        if len(r) > 200:
+                            r = f"{r[:200]}... ({len(r)} chars)"
+                        raise Exception(
+                            f"value {r!r} at record {i} column0 {j}"
+                            f" failed insert: {e}"
+                        ) from None
+
+            # just in case, but hopefully we should have triggered the problem
+            raise
+
+    def _insert_field_stmt(self, i):
+        ph = sql.Placeholder(format=self.format)
+        return sql.SQL("insert into {} ({}) values ({})").format(
+            self.table_name, self.fields_names[i], ph
+        )
+
     def choose_schema(self, ncols=20):
         schema = []
         while len(schema) < ncols:
index d0ab11c0ffb22c7385825ac872159c814155759b..ee836bffce39bef1a5a8baf9bfbc75391be75758 100644 (file)
@@ -431,19 +431,8 @@ def test_random(conn, faker, fmt, fmt_out):
     with conn.cursor(binary=fmt_out) as cur:
         cur.execute(faker.drop_stmt)
         cur.execute(faker.create_stmt)
-        try:
+        with faker.find_insert_problem(conn):
             cur.executemany(faker.insert_stmt, faker.records)
-        except psycopg.DatabaseError:
-            # Insert one by one to find problematic values
-            conn.rollback()
-            cur.execute(faker.drop_stmt)
-            cur.execute(faker.create_stmt)
-            for rec in faker.records:
-                for i, val in enumerate(rec):
-                    cur.execute(faker.insert_field_stmt(i), (val,))
-
-            # just in case, but hopefully we should have triggered the problem
-            raise
 
         cur.execute(faker.select_stmt)
         recs = cur.fetchall()
index 9017574df3538ee0aceabd779efe2b547fc58195..0f794584e9f696ea93bfe8df5e56b6779686e737 100644 (file)
@@ -505,7 +505,8 @@ def test_copy_to_leaks(dsn, faker, fmt, method, retries):
             with conn.cursor(binary=fmt) as cur:
                 cur.execute(faker.drop_stmt)
                 cur.execute(faker.create_stmt)
-                cur.executemany(faker.insert_stmt, faker.records)
+                with faker.find_insert_problem(conn):
+                    cur.executemany(faker.insert_stmt, faker.records)
 
                 stmt = sql.SQL(
                     "copy (select {} from {} order by id) to stdout (format {})"
index 8a512a73d2dcabe937ac18b2e5cdff8db737838e..7c755ba0e044cc48d44dd78251e7b0f99af4e1ef 100644 (file)
@@ -481,7 +481,8 @@ async def test_copy_to_leaks(dsn, faker, fmt, method, retries):
             async with conn.cursor(binary=fmt) as cur:
                 await cur.execute(faker.drop_stmt)
                 await cur.execute(faker.create_stmt)
-                await cur.executemany(faker.insert_stmt, faker.records)
+                async with faker.find_insert_problem_async(conn):
+                    await cur.executemany(faker.insert_stmt, faker.records)
 
                 stmt = sql.SQL(
                     "copy (select {} from {} order by id) to stdout (format {})"
index af5ffdc9037a818d3650673e36310c08a148f01d..c008a530c9f755976da2475b7a721dbe8021906b 100644 (file)
@@ -549,42 +549,40 @@ def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory, retries):
     faker.make_records(10)
     row_factory = getattr(rows, row_factory)
 
+    def work():
+        with psycopg.connect(dsn) as conn:
+            with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
+                cur.execute(faker.drop_stmt)
+                cur.execute(faker.create_stmt)
+                with faker.find_insert_problem(conn):
+                    cur.executemany(faker.insert_stmt, faker.records)
+
+                cur.execute(faker.select_stmt)
+
+                if fetch == "one":
+                    while 1:
+                        tmp = cur.fetchone()
+                        if tmp is None:
+                            break
+                elif fetch == "many":
+                    while 1:
+                        tmp = cur.fetchmany(3)
+                        if not tmp:
+                            break
+                elif fetch == "all":
+                    cur.fetchall()
+                elif fetch == "iter":
+                    for rec in cur:
+                        pass
+
     for retry in retries:
         with retry:
             n = []
             gc_collect()
             for i in range(3):
-                with psycopg.connect(dsn) as conn:
-                    with conn.cursor(
-                        binary=fmt_out, row_factory=row_factory
-                    ) as cur:
-                        cur.execute(faker.drop_stmt)
-                        cur.execute(faker.create_stmt)
-                        cur.executemany(faker.insert_stmt, faker.records)
-                        cur.execute(faker.select_stmt)
-
-                        if fetch == "one":
-                            while 1:
-                                tmp = cur.fetchone()
-                                if tmp is None:
-                                    break
-                        elif fetch == "many":
-                            while 1:
-                                tmp = cur.fetchmany(3)
-                                if not tmp:
-                                    break
-                        elif fetch == "all":
-                            cur.fetchall()
-                        elif fetch == "iter":
-                            for rec in cur:
-                                pass
-
-                        tmp = None
-
-                del cur, conn
+                work()
                 gc_collect()
                 n.append(len(gc.get_objects()))
-
             assert (
                 n[0] == n[1] == n[2]
             ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
index 7d9dbdd9a0f05d5f28532fbca4a9e2a7641b1a82..f4f35c38ad123b73275562031cee193c663494a4 100644 (file)
@@ -463,39 +463,39 @@ async def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory, retries):
     faker.make_records(10)
     row_factory = getattr(rows, row_factory)
 
+    async def work():
+        async with await psycopg.AsyncConnection.connect(dsn) as conn:
+            async with conn.cursor(
+                binary=fmt_out, row_factory=row_factory
+            ) as cur:
+                await cur.execute(faker.drop_stmt)
+                await cur.execute(faker.create_stmt)
+                async with faker.find_insert_problem_async(conn):
+                    await cur.executemany(faker.insert_stmt, faker.records)
+                await cur.execute(faker.select_stmt)
+
+                if fetch == "one":
+                    while 1:
+                        tmp = await cur.fetchone()
+                        if tmp is None:
+                            break
+                elif fetch == "many":
+                    while 1:
+                        tmp = await cur.fetchmany(3)
+                        if not tmp:
+                            break
+                elif fetch == "all":
+                    await cur.fetchall()
+                elif fetch == "iter":
+                    async for rec in cur:
+                        pass
+
     async for retry in retries:
         with retry:
             n = []
             gc_collect()
             for i in range(3):
-                async with await psycopg.AsyncConnection.connect(dsn) as conn:
-                    async with conn.cursor(
-                        binary=fmt_out, row_factory=row_factory
-                    ) as cur:
-                        await cur.execute(faker.drop_stmt)
-                        await cur.execute(faker.create_stmt)
-                        await cur.executemany(faker.insert_stmt, faker.records)
-                        await cur.execute(faker.select_stmt)
-
-                        if fetch == "one":
-                            while 1:
-                                tmp = await cur.fetchone()
-                                if tmp is None:
-                                    break
-                        elif fetch == "many":
-                            while 1:
-                                tmp = await cur.fetchmany(3)
-                                if not tmp:
-                                    break
-                        elif fetch == "all":
-                            await cur.fetchall()
-                        elif fetch == "iter":
-                            async for rec in cur:
-                                pass
-
-                        tmp = None
-
-                del cur, conn
+                await work()
                 gc_collect()
                 n.append(len(gc.get_objects()))