]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added reference leak test for other fetch methods
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jan 2021 22:47:39 +0000 (23:47 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 14 Jan 2021 15:08:48 +0000 (16:08 +0100)
tests/fix_faker.py
tests/test_cursor.py

index cb982345548cd7d45ba677d0c7ae795f6fa1f988..d6928ae8615c7a252ae35407aa7afe596d61e99f 100644 (file)
@@ -106,12 +106,12 @@ class Faker:
             fields, self.table_name
         )
 
-    def choose_schema(self, types=None, nfields=20):
+    def choose_schema(self, types=None, ncols=20):
         if not types:
             types = self.get_supported_types()
 
         types_list = sorted(types, key=lambda cls: cls.__name__)
-        schema = [choice(types_list) for i in range(nfields)]
+        schema = [choice(types_list) for i in range(ncols)]
         for i, cls in enumerate(schema):
             # choose the type of the array
             if cls is list:
@@ -121,9 +121,7 @@ class Faker:
                         break
                 schema[i] = [scls]
             elif cls is tuple:
-                schema[i] = tuple(
-                    self.choose_schema(types=types, nfields=nfields)
-                )
+                schema[i] = tuple(self.choose_schema(types=types, ncols=ncols))
 
         return schema
 
index c32dff95abcff228f8117683c0cec2c810709737..d9869f89196768b14202f1f56f8ae601e34937d3 100644 (file)
@@ -406,13 +406,14 @@ def test_str(conn):
 
 @pytest.mark.slow
 @pytest.mark.parametrize("fmt", [Format.AUTO, Format.TEXT, Format.BINARY])
-def test_leak_fetchall(dsn, faker, fmt):
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+def test_leak(dsn, faker, fmt, fetch):
     if fmt != Format.BINARY:
         pytest.xfail("faker to extend to all text dumpers")
 
     faker.format = fmt
-    faker.choose_schema()
-    faker.make_records(100)
+    faker.choose_schema(ncols=5)
+    faker.make_records(10)
 
     n = []
     for i in range(3):
@@ -422,8 +423,31 @@ def test_leak_fetchall(dsn, faker, fmt):
                 cur.execute(faker.create_stmt)
                 cur.executemany(faker.insert_stmt, faker.records)
                 cur.execute(faker.select_stmt)
-                for got, want in zip(cur.fetchall(), faker.records):
+
+                recs = []
+                if fetch == "one":
+                    while 1:
+                        tmp = cur.fetchone()
+                        if tmp is None:
+                            break
+                        recs.append(tmp)
+                elif fetch == "many":
+                    while 1:
+                        tmp = cur.fetchmany(3)
+                        if not tmp:
+                            break
+                        recs.extend(tmp)
+                elif fetch == "all":
+                    recs.extend(cur.fetchall())
+                elif fetch == "iter":
+                    for rec in cur:
+                        recs.append(rec)
+
+                for got, want in zip(recs, faker.records):
                     faker.assert_record(got, want)
+
+                recs = tmp = None
+
         del cur, conn
         gc.collect()
         gc.collect()