From: Daniele Varrazzo Date: Thu, 10 Aug 2023 00:07:08 +0000 (+0100) Subject: refactor(tests): generate test_pipeline from async counterpart X-Git-Tag: pool-3.2.0~12^2~57 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=7d84fbe6aea6588e60640b66908c089b47da19e5;p=thirdparty%2Fpsycopg.git refactor(tests): generate test_pipeline from async counterpart --- diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 2de7fabb3..35028795f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,5 +1,7 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_pipeline_async.py' +# DO NOT CHANGE! Change the original file instead. import logging -import concurrent.futures from typing import Any from operator import attrgetter from itertools import groupby @@ -10,17 +12,19 @@ import psycopg from psycopg import pq from psycopg import errors as e +from .utils import is_async + pytestmark = [ pytest.mark.pipeline, pytest.mark.skipif("not psycopg.Pipeline.is_supported()"), ] - pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted") def test_repr(conn): with conn.pipeline() as p: - assert "psycopg.Pipeline" in repr(p) + name = "psycopg.AsyncPipeline" if is_async(conn) else "psycopg.Pipeline" + assert name in repr(p) assert "[IDLE, pipeline=ON]" in repr(p) conn.close() @@ -105,7 +109,7 @@ def test_pipeline_nested_sync_trace(conn, trace): def test_cursor_stream(conn): with conn.pipeline(), conn.cursor() as cur: with pytest.raises(psycopg.ProgrammingError): - cur.stream("select 1").__next__() + next(cur.stream("select 1")) def test_server_cursor(conn): @@ -124,8 +128,8 @@ def test_copy(conn): with conn.pipeline(): cur = conn.cursor() with pytest.raises(e.NotSupportedError): - with cur.copy("copy (select 1) to stdout"): - pass + with cur.copy("copy (select 1) to stdout") as copy: + copy.read() def test_pipeline_processed_at_exit(conn): @@ -318,8 +322,7 @@ def test_executemany(conn): conn.autocommit = True conn.execute("drop table if exists execmanypipeline") conn.execute( - "create unlogged table execmanypipeline (" - " id serial primary key, num integer)" + "create unlogged table execmanypipeline (id serial primary key, num integer)" ) with conn.pipeline(), conn.cursor() as cur: cur.executemany( @@ -339,8 +342,8 @@ def test_executemany_no_returning(conn): conn.autocommit = True conn.execute("drop table if exists execmanypipelinenoreturning") conn.execute( - "create unlogged table execmanypipelinenoreturning (" - " id serial primary key, num integer)" + """create unlogged table execmanypipelinenoreturning + (id serial primary key, num integer)""" ) with conn.pipeline(), conn.cursor() as cur: cur.executemany( @@ -369,7 +372,7 @@ def test_executemany_trace(conn, trace): items = list(t) assert items[-1].type == "Terminate" del items[-1] - roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + roundtrips = [k for (k, g) in groupby(items, key=attrgetter("direction"))] assert roundtrips == ["F", "B"] assert len([i for i in items if i.type == "Sync"]) == 1 @@ -391,7 +394,7 @@ def test_executemany_trace_returning(conn, trace): items = list(t) assert items[-1].type == "Terminate" del items[-1] - roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + roundtrips = [k for (k, g) in groupby(items, key=attrgetter("direction"))] assert roundtrips == ["F", "B"] * 3 assert items[-2].direction == "F" # last 2 items are F B assert len([i for i in items if i.type == "Sync"]) == 1 @@ -413,7 +416,6 @@ def test_prepared(conn): def test_auto_prepare(conn): - conn.autocommit = True conn.prepared_threshold = 5 with conn.pipeline(): cursors = [ @@ -433,6 +435,7 @@ def test_prepare_error(conn): An invalid prepared statement, in a pipeline, should be discarded at exit and not reused. """ + conn.autocommit = True stmt = "INSERT INTO nosuchtable(data) VALUES (%s)" with pytest.raises(psycopg.errors.UndefinedTable): @@ -564,10 +567,9 @@ def test_concurrency(conn): conn.execute("drop table if exists accessed") with conn.transaction(): conn.execute( - "create unlogged table pipeline_concurrency (" - " id serial primary key," - " value integer" - ")" + """create unlogged table pipeline_concurrency ( + id serial primary key, + value integer)""" ) conn.execute("create unlogged table accessed as (select now() as value)") @@ -585,11 +587,15 @@ def test_concurrency(conn): values = range(1, 10) with conn.pipeline(): - with concurrent.futures.ThreadPoolExecutor() as e: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor() as e: cursors = e.map(update, values, timeout=len(values)) - assert sum(cur.fetchone()[0] for cur in cursors) == sum(values) - (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone() + assert sum([cur.fetchone()[0] for cur in cursors]) == sum(values) + + cur = conn.execute("select sum(value) from pipeline_concurrency") + (s,) = cur.fetchone() assert s == sum(values) (after,) = conn.execute("select value from accessed").fetchone() assert after > before diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py index b20c1de7f..5dcdc685a 100644 --- a/tests/test_pipeline_async.py +++ b/tests/test_pipeline_async.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import Any from operator import attrgetter @@ -10,17 +9,19 @@ import psycopg from psycopg import pq from psycopg import errors as e -from .test_pipeline import pipeline_aborted +from .utils import is_async, anext pytestmark = [ pytest.mark.pipeline, - pytest.mark.skipif("not psycopg.AsyncPipeline.is_supported()"), + pytest.mark.skipif("not psycopg.Pipeline.is_supported()"), ] +pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted") async def test_repr(aconn): async with aconn.pipeline() as p: - assert "psycopg.AsyncPipeline" in repr(p) + name = "psycopg.AsyncPipeline" if is_async(aconn) else "psycopg.Pipeline" + assert name in repr(p) assert "[IDLE, pipeline=ON]" in repr(p) await aconn.close() @@ -105,7 +106,7 @@ async def test_pipeline_nested_sync_trace(aconn, trace): async def test_cursor_stream(aconn): async with aconn.pipeline(), aconn.cursor() as cur: with pytest.raises(psycopg.ProgrammingError): - await cur.stream("select 1").__anext__() + await anext(cur.stream("select 1")) async def test_server_cursor(aconn): @@ -318,8 +319,7 @@ async def test_executemany(aconn): await aconn.set_autocommit(True) await aconn.execute("drop table if exists execmanypipeline") await aconn.execute( - "create unlogged table execmanypipeline (" - " id serial primary key, num integer)" + "create unlogged table execmanypipeline (id serial primary key, num integer)" ) async with aconn.pipeline(), aconn.cursor() as cur: await cur.executemany( @@ -339,8 +339,8 @@ async def test_executemany_no_returning(aconn): await aconn.set_autocommit(True) await aconn.execute("drop table if exists execmanypipelinenoreturning") await aconn.execute( - "create unlogged table execmanypipelinenoreturning (" - " id serial primary key, num integer)" + """create unlogged table execmanypipelinenoreturning + (id serial primary key, num integer)""" ) async with aconn.pipeline(), aconn.cursor() as cur: await cur.executemany( @@ -567,10 +567,9 @@ async def test_concurrency(aconn): await aconn.execute("drop table if exists accessed") async with aconn.transaction(): await aconn.execute( - "create unlogged table pipeline_concurrency (" - " id serial primary key," - " value integer" - ")" + """create unlogged table pipeline_concurrency ( + id serial primary key, + value integer)""" ) await aconn.execute("create unlogged table accessed as (select now() as value)") @@ -588,16 +587,23 @@ async def test_concurrency(aconn): values = range(1, 10) async with aconn.pipeline(): - cursors = await asyncio.wait_for( - asyncio.gather(*[update(value) for value in values]), - timeout=len(values), - ) + if is_async(aconn): + import asyncio + + cursors = await asyncio.wait_for( + asyncio.gather(*[update(value) for value in values]), + timeout=len(values), + ) + else: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor() as e: + cursors = e.map(update, values, timeout=len(values)) - assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values) + assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values) - (s,) = await ( - await aconn.execute("select sum(value) from pipeline_concurrency") - ).fetchone() + cur = await aconn.execute("select sum(value) from pipeline_concurrency") + (s,) = await cur.fetchone() assert s == sum(values) (after,) = await (await aconn.execute("select value from accessed")).fetchone() assert after > before diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 02bc04539..25d627363 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -32,8 +32,8 @@ def main() -> int: def async_to_sync(tree: ast.AST) -> ast.AST: tree = BlanksInserter().visit(tree) - tree = AsyncToSync().visit(tree) tree = RenameAsyncToSync().visit(tree) + tree = AsyncToSync().visit(tree) tree = FixAsyncSetters().visit(tree) return tree @@ -80,10 +80,32 @@ class AsyncToSync(ast.NodeTransformer): self.visit(new_node) return new_node + def visit_If(self, node: ast.If) -> ast.AST: + # Drop `if is_async()` branch. + # + # Assume that the test guards an async object becoming sync and remove + # the async side, because it will likely contain `await` constructs + # illegal into a sync function. + if self._is_async_call(node.test): + for child in node.orelse: + self.visit(child) + return node.orelse + + self.generic_visit(node) + return node + + def _is_async_call(self, test: ast.AST) -> bool: + if not isinstance(test, ast.Call): + return False + if test.func.id != "is_async": + return False + return True + class RenameAsyncToSync(ast.NodeTransformer): names_map = { "AsyncClientCursor": "ClientCursor", + "AsyncConnection": "Connection", "AsyncCursor": "Cursor", "AsyncRawCursor": "RawCursor", "AsyncServerCursor": "ServerCursor", @@ -108,10 +130,20 @@ class RenameAsyncToSync(ast.NodeTransformer): self.generic_visit(node) return node - def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: + def visit_AsyncFunctionDef(self, node: ast.FunctionDef) -> ast.AST: node.name = self.names_map.get(node.name, node.name) for arg in node.args.args: arg.arg = self.names_map.get(arg.arg, arg.arg) + for arg in node.args.args: + ann = arg.annotation + if not ann: + continue + if isinstance(ann, ast.Subscript): + # Remove the [] from the type + ann = ann.value + if isinstance(ann, ast.Attribute): + ann.attr = self.names_map.get(ann.attr, ann.attr) + self.generic_visit(node) return node diff --git a/tools/convert_async_to_sync.sh b/tools/convert_async_to_sync.sh index e9b9031e8..8727c5406 100755 --- a/tools/convert_async_to_sync.sh +++ b/tools/convert_async_to_sync.sh @@ -7,7 +7,13 @@ set -euo pipefail dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "${dir}/.." -python "${dir}/async_to_sync.py" tests/test_connection_async.py > tests/test_connection.py -black -q tests/test_connection.py -python "${dir}/async_to_sync.py" tests/test_cursor_async.py > tests/test_cursor.py -black -q tests/test_cursor.py +for async in \ + tests/test_connection_async.py \ + tests/test_cursor_async.py \ + tests/test_pipeline_async.py +do + sync=${async/_async/} + echo "converting '${async}' -> '${sync}'" >&2 + python "${dir}/async_to_sync.py" ${async} > ${sync} + black -q ${sync} +done