From: Daniele Varrazzo Date: Mon, 7 Aug 2023 09:11:24 +0000 (+0100) Subject: refactor(tests): auto-generate the test_cursor module from async_test_cursor X-Git-Tag: pool-3.2.0~12^2~62 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fac75e476c054447b25b44fa2894bee00a2c6a5b;p=thirdparty%2Fpsycopg.git refactor(tests): auto-generate the test_cursor module from async_test_cursor Use a script to translate an async module to sync. The module is to be extended to cover the entire test suite and then possibly the rest of the code. --- diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 326f699fe..18376125e 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,3 +1,6 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_cursor_async.py' +# DO NOT CHANGE! Change the original file instead. """ Tests common to psycopg.Cursor and its subclasses. """ @@ -23,7 +26,7 @@ execmany = execmany # avoid F811 underneath @pytest.fixture(params=[psycopg.Cursor, psycopg.ClientCursor, psycopg.RawCursor]) -def conn(conn, request): +def conn(conn, request, anyio_backend): conn.cursor_factory = request.param return conn @@ -167,7 +170,7 @@ def test_query_parse_cache_size(conn): cache.cache_clear() ci = cache.cache_info() - h0, m0 = ci.hits, ci.misses + (h0, m0) = (ci.hits, ci.misses) tests = [ (f"select 1 -- {'x' * 3500}", (), h0, m0 + 1), (f"select 1 -- {'x' * 3500}", (), h0 + 1, m0 + 1), @@ -229,18 +232,18 @@ def test_execute_type_change(conn): cur = conn.cursor() sql = ph(cur, "insert into bug_112 (num) values (%s)") cur.execute(sql, (1,)) - cur.execute(sql, (100_000,)) + cur.execute(sql, (100000,)) cur.execute("select num from bug_112 order by num") - assert cur.fetchall() == [(1,), (100_000,)] + assert cur.fetchall() == [(1,), (100000,)] def test_executemany_type_change(conn): conn.execute("create table bug_112 (num integer)") cur = conn.cursor() sql = ph(cur, "insert into bug_112 (num) values (%s)") - cur.executemany(sql, [(1,), (100_000,)]) + cur.executemany(sql, [(1,), (100000,)]) cur.execute("select num from bug_112 order by num") - assert cur.fetchall() == [(1,), (100_000,)] + assert cur.fetchall() == [(1,), (100000,)] @pytest.mark.parametrize( @@ -304,8 +307,9 @@ def test_binary_cursor_text_override(conn): def test_query_encode(conn, encoding): conn.execute(f"set client_encoding to {encoding}") cur = conn.cursor() - (res,) = cur.execute("select '\u20ac'").fetchone() - assert res == "\u20ac" + cur.execute("select '€'") + (res,) = cur.fetchone() + assert res == "€" @pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) @@ -313,7 +317,7 @@ def test_query_badenc(conn, encoding): conn.execute(f"set client_encoding to {encoding}") cur = conn.cursor() with pytest.raises(UnicodeEncodeError): - cur.execute("select '\u20ac'") + cur.execute("select '€'") def test_executemany(conn, execmany): @@ -323,7 +327,8 @@ def test_executemany(conn, execmany): [(10, "hello"), (20, "world")], ) cur.execute("select num, data from execmany order by 1") - assert cur.fetchall() == [(10, "hello"), (20, "world")] + rv = cur.fetchall() + assert rv == [(10, "hello"), (20, "world")] def test_executemany_name(conn, execmany): @@ -333,7 +338,8 @@ def test_executemany_name(conn, execmany): [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], ) cur.execute("select num, data from execmany order by 1") - assert cur.fetchall() == [(11, "hello"), (21, "world")] + rv = cur.fetchall() + assert rv == [(11, "hello"), (21, "world")] def test_executemany_no_data(conn, execmany): @@ -433,7 +439,7 @@ def test_executemany_null_first(conn, fmt_in): ) with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): cur.executemany( - f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + ph(cur, f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})"), [[1, ""], [3, 4]], ) @@ -490,13 +496,12 @@ def test_rownumber_none(conn, query): def test_rownumber_mixed(conn): cur = conn.cursor() - cur.execute( - """ -select x from generate_series(1, 3) x; -set timezone to utc; -select x from generate_series(4, 6) x; -""" - ) + queries = [ + "select x from generate_series(1, 3) x", + "set timezone to utc", + "select x from generate_series(4, 6) x", + ] + cur.execute(";\n".join(queries)) assert cur.rownumber == 0 assert cur.fetchone() == (1,) assert cur.rownumber == 1 @@ -555,7 +560,8 @@ def test_row_factory(conn): def test_row_factory_none(conn): cur = conn.cursor(row_factory=None) assert cur.row_factory is rows.tuple_row - r = cur.execute("select 1 as a, 2 as b").fetchone() + cur.execute("select 1 as a, 2 as b") + r = cur.fetchone() assert type(r) is tuple assert r == (1, 2) @@ -684,12 +690,7 @@ def test_stream_no_col(conn): @pytest.mark.parametrize( - "query", - [ - "create table test_stream_badq ()", - "copy (select 1) to stdout", - "wat?", - ], + "query", ["create table test_stream_badq ()", "copy (select 1) to stdout", "wat?"] ) def test_stream_badquery(conn, query): cur = conn.cursor() @@ -733,6 +734,7 @@ def test_stream_error_python_consumed(conn): gen = cur.stream("select 1") for rec in gen: 1 / 0 + gen.close() assert conn.info.transaction_status == conn.TransactionStatus.INTRANS diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py new file mode 100755 index 000000000..2e6f98dc1 --- /dev/null +++ b/tools/async_to_sync.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python +"""Convert an async module to a sync module. +""" + +from __future__ import annotations + +import os +import sys +from argparse import ArgumentParser, Namespace + +import ast_comments as ast + + +def main() -> int: + opt = parse_cmdline() + with open(opt.filename) as f: + source = f.read() + + tree = ast.parse(source, filename=opt.filename) + tree = async_to_sync(tree) + output = tree_to_str(tree, opt.filename) + + if opt.output: + with open(opt.output, "w") as f: + print(output, file=f) + else: + print(output) + + return 0 + + +def async_to_sync(tree: ast.AST) -> ast.AST: + tree = BlanksInserter().visit(tree) + tree = AsyncToSync().visit(tree) + tree = RenameAsyncToSync().visit(tree) + tree = FixSetAutocommit().visit(tree) + return tree + + +def tree_to_str(tree: ast.AST, filename: str) -> str: + rv = f"""\ +# WARNING: this file is auto-generated by '{os.path.basename(sys.argv[0])}' +# from the original file '{os.path.basename(filename)}' +# DO NOT CHANGE! Change the original file instead. +""" + rv += ast.unparse(tree) + return rv + + +class AsyncToSync(ast.NodeTransformer): + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: + new_node = ast.FunctionDef( + name=node.name, + args=node.args, + body=node.body, + decorator_list=node.decorator_list, + returns=node.returns, + ) + ast.copy_location(new_node, node) + self.visit(new_node) + return new_node + + def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AST: + new_node = ast.For( + target=node.target, iter=node.iter, body=node.body, orelse=node.orelse + ) + ast.copy_location(new_node, node) + self.visit(new_node) + return new_node + + def visit_AsyncWith(self, node: ast.AsyncWith) -> ast.AST: + new_node = ast.With(items=node.items, body=node.body) + ast.copy_location(new_node, node) + self.visit(new_node) + return new_node + + def visit_Await(self, node: ast.Await) -> ast.AST: + new_node = node.value + self.visit(new_node) + return new_node + + +class RenameAsyncToSync(ast.NodeTransformer): + names_map = { + "AsyncClientCursor": "ClientCursor", + "AsyncCursor": "Cursor", + "AsyncRawCursor": "RawCursor", + "aclose": "close", + "aclosing": "closing", + "aconn": "conn", + "alist": "list", + "anext": "next", + } + + def visit_Module(self, node: ast.Module) -> ast.AST: + # Replace the content of the module docstring. + if ( + node.body + and isinstance(node.body[0], ast.Expr) + and isinstance(node.body[0].value, ast.Constant) + ): + node.body[0].value.value = node.body[0].value.value.replace("Async", "") + + self.generic_visit(node) + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: + node.name = self.names_map.get(node.name, node.name) + for arg in node.args.args: + arg.arg = self.names_map.get(arg.arg, arg.arg) + self.generic_visit(node) + return node + + _skip_imports = {"alist", "anext"} + + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None: + # Remove import of async utils eclypsing builtings + if node.module == "utils": + if {n.name for n in node.names} <= self._skip_imports: + return None + + for n in node.names: + n.name = self.names_map.get(n.name, n.name) + return node + + def visit_Name(self, node: ast.Name) -> ast.AST: + if node.id in self.names_map: + node.id = self.names_map[node.id] + return node + + def visit_Attribute(self, node: ast.Attribute) -> ast.AST: + if node.attr in self.names_map: + node.attr = self.names_map[node.attr] + self.generic_visit(node) + return node + + +class FixSetAutocommit(ast.NodeTransformer): + def visit_Call(self, node: ast.Call) -> ast.AST: + new_node = self._fix_autocommit(node) + if new_node: + return new_node + + self.generic_visit(node) + return node + + def _fix_autocommit(self, node: ast.Call) -> ast.AST | None: + if not isinstance(node.func, ast.Attribute): + return None + if node.func.attr != "set_autocommit": + return None + obj = node.func.value + arg = node.args[0] + new_node = ast.Assign( + targets=[ast.Attribute(value=obj, attr="autocommit")], + value=arg, + ) + ast.copy_location(new_node, node) + return new_node + + +class BlanksInserter(ast.NodeTransformer): + """ + Restore the missing spaces in the source (or something similar) + """ + + def generic_visit(self, node: ast.AST) -> ast.AST: + if isinstance(getattr(node, "body", None), list): + node.body = self._inject_blanks(node.body) + super().generic_visit(node) + return node + + def _inject_blanks(self, body: list[ast.Node]) -> list[ast.AST]: + if not body: + return body + + new_body = [] + before = body[0] + new_body.append(before) + for i in range(1, len(body)): + after = body[i] + nblanks = after.lineno - before.end_lineno - 1 + if nblanks > 0: + # Inserting one blank is enough. + blank = ast.Comment( + value="", + inline=False, + lineno=before.end_lineno + 1, + end_lineno=before.end_lineno + 1, + col_offset=0, + end_col_offset=0, + ) + new_body.append(blank) + new_body.append(after) + before = after + + return new_body + + +def parse_cmdline() -> Namespace: + parser = ArgumentParser(description=__doc__) + parser.add_argument("filename", metavar="FILE", help="the file to process") + parser.add_argument( + "output", metavar="OUTPUT", nargs="?", help="file where to write (or stdout)" + ) + opt = parser.parse_args() + + return opt + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/convert_async_to_sync.sh b/tools/convert_async_to_sync.sh new file mode 100755 index 000000000..93bbc6548 --- /dev/null +++ b/tools/convert_async_to_sync.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# Convert all the auto-generated sync files from their async counterparts. + +set -euo pipefail + +dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "${dir}/.." + +python "${dir}/async_to_sync.py" tests/test_cursor_async.py > tests/test_cursor.py +black -q tests/test_cursor.py