]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): generate test_pipeline from async counterpart
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 10 Aug 2023 00:07:08 +0000 (01:07 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
tests/test_pipeline.py
tests/test_pipeline_async.py
tools/async_to_sync.py
tools/convert_async_to_sync.sh

index 2de7fabb3ff2deafafb1af33d0c0c63324380019..35028795fbc42a766fa441cad8357bec978414d5 100644 (file)
@@ -1,5 +1,7 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_pipeline_async.py'
+# DO NOT CHANGE! Change the original file instead.
 import logging
-import concurrent.futures
 from typing import Any
 from operator import attrgetter
 from itertools import groupby
@@ -10,17 +12,19 @@ import psycopg
 from psycopg import pq
 from psycopg import errors as e
 
+from .utils import is_async
+
 pytestmark = [
     pytest.mark.pipeline,
     pytest.mark.skipif("not psycopg.Pipeline.is_supported()"),
 ]
-
 pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted")
 
 
 def test_repr(conn):
     with conn.pipeline() as p:
-        assert "psycopg.Pipeline" in repr(p)
+        name = "psycopg.AsyncPipeline" if is_async(conn) else "psycopg.Pipeline"
+        assert name in repr(p)
         assert "[IDLE, pipeline=ON]" in repr(p)
 
     conn.close()
@@ -105,7 +109,7 @@ def test_pipeline_nested_sync_trace(conn, trace):
 def test_cursor_stream(conn):
     with conn.pipeline(), conn.cursor() as cur:
         with pytest.raises(psycopg.ProgrammingError):
-            cur.stream("select 1").__next__()
+            next(cur.stream("select 1"))
 
 
 def test_server_cursor(conn):
@@ -124,8 +128,8 @@ def test_copy(conn):
     with conn.pipeline():
         cur = conn.cursor()
         with pytest.raises(e.NotSupportedError):
-            with cur.copy("copy (select 1) to stdout"):
-                pass
+            with cur.copy("copy (select 1) to stdout") as copy:
+                copy.read()
 
 
 def test_pipeline_processed_at_exit(conn):
@@ -318,8 +322,7 @@ def test_executemany(conn):
     conn.autocommit = True
     conn.execute("drop table if exists execmanypipeline")
     conn.execute(
-        "create unlogged table execmanypipeline ("
-        " id serial primary key, num integer)"
+        "create unlogged table execmanypipeline (id serial primary key, num integer)"
     )
     with conn.pipeline(), conn.cursor() as cur:
         cur.executemany(
@@ -339,8 +342,8 @@ def test_executemany_no_returning(conn):
     conn.autocommit = True
     conn.execute("drop table if exists execmanypipelinenoreturning")
     conn.execute(
-        "create unlogged table execmanypipelinenoreturning ("
-        " id serial primary key, num integer)"
+        """create unlogged table execmanypipelinenoreturning
+            (id serial primary key, num integer)"""
     )
     with conn.pipeline(), conn.cursor() as cur:
         cur.executemany(
@@ -369,7 +372,7 @@ def test_executemany_trace(conn, trace):
     items = list(t)
     assert items[-1].type == "Terminate"
     del items[-1]
-    roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+    roundtrips = [k for (k, g) in groupby(items, key=attrgetter("direction"))]
     assert roundtrips == ["F", "B"]
     assert len([i for i in items if i.type == "Sync"]) == 1
 
@@ -391,7 +394,7 @@ def test_executemany_trace_returning(conn, trace):
     items = list(t)
     assert items[-1].type == "Terminate"
     del items[-1]
-    roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+    roundtrips = [k for (k, g) in groupby(items, key=attrgetter("direction"))]
     assert roundtrips == ["F", "B"] * 3
     assert items[-2].direction == "F"  # last 2 items are F B
     assert len([i for i in items if i.type == "Sync"]) == 1
@@ -413,7 +416,6 @@ def test_prepared(conn):
 
 
 def test_auto_prepare(conn):
-    conn.autocommit = True
     conn.prepared_threshold = 5
     with conn.pipeline():
         cursors = [
@@ -433,6 +435,7 @@ def test_prepare_error(conn):
     An invalid prepared statement, in a pipeline, should be discarded at exit
     and not reused.
     """
+
     conn.autocommit = True
     stmt = "INSERT INTO nosuchtable(data) VALUES (%s)"
     with pytest.raises(psycopg.errors.UndefinedTable):
@@ -564,10 +567,9 @@ def test_concurrency(conn):
         conn.execute("drop table if exists accessed")
     with conn.transaction():
         conn.execute(
-            "create unlogged table pipeline_concurrency ("
-            " id serial primary key,"
-            " value integer"
-            ")"
+            """create unlogged table pipeline_concurrency (
+                id serial primary key,
+                value integer)"""
         )
         conn.execute("create unlogged table accessed as (select now() as value)")
 
@@ -585,11 +587,15 @@ def test_concurrency(conn):
 
     values = range(1, 10)
     with conn.pipeline():
-        with concurrent.futures.ThreadPoolExecutor() as e:
+        from concurrent.futures import ThreadPoolExecutor
+
+        with ThreadPoolExecutor() as e:
             cursors = e.map(update, values, timeout=len(values))
-            assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
 
-    (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone()
+        assert sum([cur.fetchone()[0] for cur in cursors]) == sum(values)
+
+    cur = conn.execute("select sum(value) from pipeline_concurrency")
+    (s,) = cur.fetchone()
     assert s == sum(values)
     (after,) = conn.execute("select value from accessed").fetchone()
     assert after > before
index b20c1de7f40f224a388a691eed7fba339a15fb4a..5dcdc685a0563aa86a100a49c021f66e70f613d3 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from typing import Any
 from operator import attrgetter
@@ -10,17 +9,19 @@ import psycopg
 from psycopg import pq
 from psycopg import errors as e
 
-from .test_pipeline import pipeline_aborted
+from .utils import is_async, anext
 
 pytestmark = [
     pytest.mark.pipeline,
-    pytest.mark.skipif("not psycopg.AsyncPipeline.is_supported()"),
+    pytest.mark.skipif("not psycopg.Pipeline.is_supported()"),
 ]
+pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted")
 
 
 async def test_repr(aconn):
     async with aconn.pipeline() as p:
-        assert "psycopg.AsyncPipeline" in repr(p)
+        name = "psycopg.AsyncPipeline" if is_async(aconn) else "psycopg.Pipeline"
+        assert name in repr(p)
         assert "[IDLE, pipeline=ON]" in repr(p)
 
     await aconn.close()
@@ -105,7 +106,7 @@ async def test_pipeline_nested_sync_trace(aconn, trace):
 async def test_cursor_stream(aconn):
     async with aconn.pipeline(), aconn.cursor() as cur:
         with pytest.raises(psycopg.ProgrammingError):
-            await cur.stream("select 1").__anext__()
+            await anext(cur.stream("select 1"))
 
 
 async def test_server_cursor(aconn):
@@ -318,8 +319,7 @@ async def test_executemany(aconn):
     await aconn.set_autocommit(True)
     await aconn.execute("drop table if exists execmanypipeline")
     await aconn.execute(
-        "create unlogged table execmanypipeline ("
-        " id serial primary key, num integer)"
+        "create unlogged table execmanypipeline (id serial primary key, num integer)"
     )
     async with aconn.pipeline(), aconn.cursor() as cur:
         await cur.executemany(
@@ -339,8 +339,8 @@ async def test_executemany_no_returning(aconn):
     await aconn.set_autocommit(True)
     await aconn.execute("drop table if exists execmanypipelinenoreturning")
     await aconn.execute(
-        "create unlogged table execmanypipelinenoreturning ("
-        " id serial primary key, num integer)"
+        """create unlogged table execmanypipelinenoreturning
+            (id serial primary key, num integer)"""
     )
     async with aconn.pipeline(), aconn.cursor() as cur:
         await cur.executemany(
@@ -567,10 +567,9 @@ async def test_concurrency(aconn):
         await aconn.execute("drop table if exists accessed")
     async with aconn.transaction():
         await aconn.execute(
-            "create unlogged table pipeline_concurrency ("
-            " id serial primary key,"
-            " value integer"
-            ")"
+            """create unlogged table pipeline_concurrency (
+                id serial primary key,
+                value integer)"""
         )
         await aconn.execute("create unlogged table accessed as (select now() as value)")
 
@@ -588,16 +587,23 @@ async def test_concurrency(aconn):
 
     values = range(1, 10)
     async with aconn.pipeline():
-        cursors = await asyncio.wait_for(
-            asyncio.gather(*[update(value) for value in values]),
-            timeout=len(values),
-        )
+        if is_async(aconn):
+            import asyncio
+
+            cursors = await asyncio.wait_for(
+                asyncio.gather(*[update(value) for value in values]),
+                timeout=len(values),
+            )
+        else:
+            from concurrent.futures import ThreadPoolExecutor
+
+            with ThreadPoolExecutor() as e:
+                cursors = e.map(update, values, timeout=len(values))
 
-    assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
+        assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
 
-    (s,) = await (
-        await aconn.execute("select sum(value) from pipeline_concurrency")
-    ).fetchone()
+    cur = await aconn.execute("select sum(value) from pipeline_concurrency")
+    (s,) = await cur.fetchone()
     assert s == sum(values)
     (after,) = await (await aconn.execute("select value from accessed")).fetchone()
     assert after > before
index 02bc04539850f87faf6a9eab08a1ef8e874f5612..25d6273639c266aca3a7293c2ba0aa11eb0617fd 100755 (executable)
@@ -32,8 +32,8 @@ def main() -> int:
 
 def async_to_sync(tree: ast.AST) -> ast.AST:
     tree = BlanksInserter().visit(tree)
-    tree = AsyncToSync().visit(tree)
     tree = RenameAsyncToSync().visit(tree)
+    tree = AsyncToSync().visit(tree)
     tree = FixAsyncSetters().visit(tree)
     return tree
 
@@ -80,10 +80,32 @@ class AsyncToSync(ast.NodeTransformer):
         self.visit(new_node)
         return new_node
 
+    def visit_If(self, node: ast.If) -> ast.AST:
+        # Drop `if is_async()` branch.
+        #
+        # Assume that the test guards an async object becoming sync and remove
+        # the async side, because it will likely contain `await` constructs
+        # illegal into a sync function.
+        if self._is_async_call(node.test):
+            for child in node.orelse:
+                self.visit(child)
+            return node.orelse
+
+        self.generic_visit(node)
+        return node
+
+    def _is_async_call(self, test: ast.AST) -> bool:
+        if not isinstance(test, ast.Call):
+            return False
+        if test.func.id != "is_async":
+            return False
+        return True
+
 
 class RenameAsyncToSync(ast.NodeTransformer):
     names_map = {
         "AsyncClientCursor": "ClientCursor",
+        "AsyncConnection": "Connection",
         "AsyncCursor": "Cursor",
         "AsyncRawCursor": "RawCursor",
         "AsyncServerCursor": "ServerCursor",
@@ -108,10 +130,20 @@ class RenameAsyncToSync(ast.NodeTransformer):
         self.generic_visit(node)
         return node
 
-    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
+    def visit_AsyncFunctionDef(self, node: ast.FunctionDef) -> ast.AST:
         node.name = self.names_map.get(node.name, node.name)
         for arg in node.args.args:
             arg.arg = self.names_map.get(arg.arg, arg.arg)
+        for arg in node.args.args:
+            ann = arg.annotation
+            if not ann:
+                continue
+            if isinstance(ann, ast.Subscript):
+                # Remove the [] from the type
+                ann = ann.value
+            if isinstance(ann, ast.Attribute):
+                ann.attr = self.names_map.get(ann.attr, ann.attr)
+
         self.generic_visit(node)
         return node
 
index e9b9031e829f8196692b469c6d5275c288b4fe65..8727c54064ee0fba9d4e0f0509ba583e7d72c190 100755 (executable)
@@ -7,7 +7,13 @@ set -euo pipefail
 dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 cd "${dir}/.."
 
-python "${dir}/async_to_sync.py" tests/test_connection_async.py > tests/test_connection.py
-black -q tests/test_connection.py
-python "${dir}/async_to_sync.py" tests/test_cursor_async.py > tests/test_cursor.py
-black -q tests/test_cursor.py
+for async in \
+    tests/test_connection_async.py \
+    tests/test_cursor_async.py \
+    tests/test_pipeline_async.py
+do
+    sync=${async/_async/}
+    echo "converting '${async}' -> '${sync}'" >&2
+    python "${dir}/async_to_sync.py" ${async} > ${sync}
+    black -q ${sync}
+done