]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix missing assert in test
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 21 Sep 2021 14:53:08 +0000 (15:53 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 21 Sep 2021 17:00:56 +0000 (18:00 +0100)
Add acopy() helper to write async tests more similar to sync ones

tests/test_copy_async.py

index 5fe8e1980f87eb211f26d1a28cc0241d671e76c4..6322f2289f2eb3d5a7b72efe97d55791c3977bc9 100644 (file)
@@ -61,7 +61,7 @@ async def test_copy_out_iter(aconn, format):
     async with cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
-        [row async for row in copy] == want
+        assert await alist(copy) == want
 
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
@@ -90,7 +90,7 @@ async def test_rows(aconn, format):
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
         copy.set_types("int4 int4 text".split())
-        rows = [row async for row in copy.rows()]
+        rows = await alist(copy.rows())
 
     assert rows == sample_records
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
@@ -101,14 +101,14 @@ async def test_set_custom_type(aconn, hstore):
     cur = aconn.cursor()
 
     async with cur.copy(command) as copy:
-        rows = [row async for row in copy.rows()]
+        rows = await alist(copy.rows())
 
     assert rows == [('"a"=>"1", "b"=>"2"',)]
 
     register_hstore(await TypeInfo.fetch(aconn, "hstore"), cur)
     async with cur.copy(command) as copy:
         copy.set_types(["hstore"])
-        rows = [row async for row in copy.rows()]
+        rows = await alist(copy.rows())
 
     assert rows == [({"a": "1", "b": "2"},)]
 
@@ -160,7 +160,7 @@ async def test_rows_notypes(aconn, format):
     async with cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
-        rows = [row async for row in copy.rows()]
+        rows = await alist(copy.rows())
     ref = [
         tuple(py_to_raw(i, format) for i in record)
         for record in sample_records
@@ -226,7 +226,7 @@ async def test_copy_bad_result(aconn):
 
     with pytest.raises(e.ProgrammingError):
         async with cur.copy("copy (select 1) to stdout; select 1") as copy:
-            [_ async for _ in copy]
+            await alist(copy)
 
     with pytest.raises(e.ProgrammingError):
         async with cur.copy("select 1; copy (select 1) to stdout"):
@@ -487,28 +487,24 @@ async def test_copy_query(aconn):
     async with cur.copy("copy (select 1) to stdout") as copy:
         assert cur._query.query == b"copy (select 1) to stdout"
         assert not cur._query.params
-        async for record in copy:
-            pass
+        await alist(copy)
 
 
 async def test_cant_reenter(aconn):
     cur = aconn.cursor()
     async with cur.copy("copy (select 1) to stdout") as copy:
-        async for record in copy:
-            pass
+        await alist(copy)
 
     with pytest.raises(TypeError):
         async with copy:
-            async for record in copy:
-                pass
+            await alist(copy)
 
 
 async def test_str(aconn):
     cur = aconn.cursor()
     async with cur.copy("copy (select 1) to stdout") as copy:
         assert "[ACTIVE]" in str(copy)
-        async for record in copy:
-            pass
+        await alist(copy)
 
     assert "[INTRANS]" in str(copy)
 
@@ -570,16 +566,14 @@ async def test_copy_to_leaks(dsn, faker, fmt, method, retries):
                             if not tmp:
                                 break
                     elif method == "iter":
-                        async for x in copy:
-                            pass
+                        await alist(copy)
                     elif method == "row":
                         while 1:
                             tmp = await copy.read_row()
                             if tmp is None:
                                 break
                     elif method == "rows":
-                        async for x in copy.rows():
-                            pass
+                        await alist(copy.rows())
 
     gc_collect()
     async for retry in retries:
@@ -693,3 +687,7 @@ class DataGenerator:
                 block = block.encode()
             m.update(block)
         return m.hexdigest()
+
+
+async def alist(it):
+    return [i async for i in it]