]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): auto-generate the test_cursor module from async_test_cursor
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 7 Aug 2023 09:11:24 +0000 (10:11 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
Use a script to translate an async module to sync. The module is to be
extended to cover the entire test suite and then possibly the rest of
the code.

tests/test_cursor.py
tools/async_to_sync.py [new file with mode: 0755]
tools/convert_async_to_sync.sh [new file with mode: 0755]

index 326f699fe261d0c3afed3499d018897228e8e4d3..18376125edc9b18c73e5c322116ebcbb7f5489e6 100644 (file)
@@ -1,3 +1,6 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_cursor_async.py'
+# DO NOT CHANGE! Change the original file instead.
 """
 Tests common to psycopg.Cursor and its subclasses.
 """
@@ -23,7 +26,7 @@ execmany = execmany  # avoid F811 underneath
 
 
 @pytest.fixture(params=[psycopg.Cursor, psycopg.ClientCursor, psycopg.RawCursor])
-def conn(conn, request):
+def conn(conn, request, anyio_backend):
     conn.cursor_factory = request.param
     return conn
 
@@ -167,7 +170,7 @@ def test_query_parse_cache_size(conn):
 
     cache.cache_clear()
     ci = cache.cache_info()
-    h0, m0 = ci.hits, ci.misses
+    (h0, m0) = (ci.hits, ci.misses)
     tests = [
         (f"select 1 -- {'x' * 3500}", (), h0, m0 + 1),
         (f"select 1 -- {'x' * 3500}", (), h0 + 1, m0 + 1),
@@ -229,18 +232,18 @@ def test_execute_type_change(conn):
     cur = conn.cursor()
     sql = ph(cur, "insert into bug_112 (num) values (%s)")
     cur.execute(sql, (1,))
-    cur.execute(sql, (100_000,))
+    cur.execute(sql, (100000,))
     cur.execute("select num from bug_112 order by num")
-    assert cur.fetchall() == [(1,), (100_000,)]
+    assert cur.fetchall() == [(1,), (100000,)]
 
 
 def test_executemany_type_change(conn):
     conn.execute("create table bug_112 (num integer)")
     cur = conn.cursor()
     sql = ph(cur, "insert into bug_112 (num) values (%s)")
-    cur.executemany(sql, [(1,), (100_000,)])
+    cur.executemany(sql, [(1,), (100000,)])
     cur.execute("select num from bug_112 order by num")
-    assert cur.fetchall() == [(1,), (100_000,)]
+    assert cur.fetchall() == [(1,), (100000,)]
 
 
 @pytest.mark.parametrize(
@@ -304,8 +307,9 @@ def test_binary_cursor_text_override(conn):
 def test_query_encode(conn, encoding):
     conn.execute(f"set client_encoding to {encoding}")
     cur = conn.cursor()
-    (res,) = cur.execute("select '\u20ac'").fetchone()
-    assert res == "\u20ac"
+    cur.execute("select '€'")
+    (res,) = cur.fetchone()
+    assert res == "€"
 
 
 @pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
@@ -313,7 +317,7 @@ def test_query_badenc(conn, encoding):
     conn.execute(f"set client_encoding to {encoding}")
     cur = conn.cursor()
     with pytest.raises(UnicodeEncodeError):
-        cur.execute("select '\u20ac'")
+        cur.execute("select ''")
 
 
 def test_executemany(conn, execmany):
@@ -323,7 +327,8 @@ def test_executemany(conn, execmany):
         [(10, "hello"), (20, "world")],
     )
     cur.execute("select num, data from execmany order by 1")
-    assert cur.fetchall() == [(10, "hello"), (20, "world")]
+    rv = cur.fetchall()
+    assert rv == [(10, "hello"), (20, "world")]
 
 
 def test_executemany_name(conn, execmany):
@@ -333,7 +338,8 @@ def test_executemany_name(conn, execmany):
         [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
     )
     cur.execute("select num, data from execmany order by 1")
-    assert cur.fetchall() == [(11, "hello"), (21, "world")]
+    rv = cur.fetchall()
+    assert rv == [(11, "hello"), (21, "world")]
 
 
 def test_executemany_no_data(conn, execmany):
@@ -433,7 +439,7 @@ def test_executemany_null_first(conn, fmt_in):
     )
     with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
         cur.executemany(
-            f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+            ph(cur, f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})"),
             [[1, ""], [3, 4]],
         )
 
@@ -490,13 +496,12 @@ def test_rownumber_none(conn, query):
 
 def test_rownumber_mixed(conn):
     cur = conn.cursor()
-    cur.execute(
-        """
-select x from generate_series(1, 3) x;
-set timezone to utc;
-select x from generate_series(4, 6) x;
-"""
-    )
+    queries = [
+        "select x from generate_series(1, 3) x",
+        "set timezone to utc",
+        "select x from generate_series(4, 6) x",
+    ]
+    cur.execute(";\n".join(queries))
     assert cur.rownumber == 0
     assert cur.fetchone() == (1,)
     assert cur.rownumber == 1
@@ -555,7 +560,8 @@ def test_row_factory(conn):
 def test_row_factory_none(conn):
     cur = conn.cursor(row_factory=None)
     assert cur.row_factory is rows.tuple_row
-    r = cur.execute("select 1 as a, 2 as b").fetchone()
+    cur.execute("select 1 as a, 2 as b")
+    r = cur.fetchone()
     assert type(r) is tuple
     assert r == (1, 2)
 
@@ -684,12 +690,7 @@ def test_stream_no_col(conn):
 
 
 @pytest.mark.parametrize(
-    "query",
-    [
-        "create table test_stream_badq ()",
-        "copy (select 1) to stdout",
-        "wat?",
-    ],
+    "query", ["create table test_stream_badq ()", "copy (select 1) to stdout", "wat?"]
 )
 def test_stream_badquery(conn, query):
     cur = conn.cursor()
@@ -733,6 +734,7 @@ def test_stream_error_python_consumed(conn):
         gen = cur.stream("select 1")
         for rec in gen:
             1 / 0
+
     gen.close()
     assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
 
diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py
new file mode 100755 (executable)
index 0000000..2e6f98d
--- /dev/null
@@ -0,0 +1,212 @@
+#!/usr/bin/env python
+"""Convert an async module to a sync module.
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+from argparse import ArgumentParser, Namespace
+
+import ast_comments as ast
+
+
+def main() -> int:
+    opt = parse_cmdline()
+    with open(opt.filename) as f:
+        source = f.read()
+
+    tree = ast.parse(source, filename=opt.filename)
+    tree = async_to_sync(tree)
+    output = tree_to_str(tree, opt.filename)
+
+    if opt.output:
+        with open(opt.output, "w") as f:
+            print(output, file=f)
+    else:
+        print(output)
+
+    return 0
+
+
+def async_to_sync(tree: ast.AST) -> ast.AST:
+    tree = BlanksInserter().visit(tree)
+    tree = AsyncToSync().visit(tree)
+    tree = RenameAsyncToSync().visit(tree)
+    tree = FixSetAutocommit().visit(tree)
+    return tree
+
+
+def tree_to_str(tree: ast.AST, filename: str) -> str:
+    rv = f"""\
+# WARNING: this file is auto-generated by '{os.path.basename(sys.argv[0])}'
+# from the original file '{os.path.basename(filename)}'
+# DO NOT CHANGE! Change the original file instead.
+"""
+    rv += ast.unparse(tree)
+    return rv
+
+
+class AsyncToSync(ast.NodeTransformer):
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
+        new_node = ast.FunctionDef(
+            name=node.name,
+            args=node.args,
+            body=node.body,
+            decorator_list=node.decorator_list,
+            returns=node.returns,
+        )
+        ast.copy_location(new_node, node)
+        self.visit(new_node)
+        return new_node
+
+    def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AST:
+        new_node = ast.For(
+            target=node.target, iter=node.iter, body=node.body, orelse=node.orelse
+        )
+        ast.copy_location(new_node, node)
+        self.visit(new_node)
+        return new_node
+
+    def visit_AsyncWith(self, node: ast.AsyncWith) -> ast.AST:
+        new_node = ast.With(items=node.items, body=node.body)
+        ast.copy_location(new_node, node)
+        self.visit(new_node)
+        return new_node
+
+    def visit_Await(self, node: ast.Await) -> ast.AST:
+        new_node = node.value
+        self.visit(new_node)
+        return new_node
+
+
+class RenameAsyncToSync(ast.NodeTransformer):
+    names_map = {
+        "AsyncClientCursor": "ClientCursor",
+        "AsyncCursor": "Cursor",
+        "AsyncRawCursor": "RawCursor",
+        "aclose": "close",
+        "aclosing": "closing",
+        "aconn": "conn",
+        "alist": "list",
+        "anext": "next",
+    }
+
+    def visit_Module(self, node: ast.Module) -> ast.AST:
+        # Replace the content of the module docstring.
+        if (
+            node.body
+            and isinstance(node.body[0], ast.Expr)
+            and isinstance(node.body[0].value, ast.Constant)
+        ):
+            node.body[0].value.value = node.body[0].value.value.replace("Async", "")
+
+        self.generic_visit(node)
+        return node
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
+        node.name = self.names_map.get(node.name, node.name)
+        for arg in node.args.args:
+            arg.arg = self.names_map.get(arg.arg, arg.arg)
+        self.generic_visit(node)
+        return node
+
+    _skip_imports = {"alist", "anext"}
+
+    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
+        # Remove import of async utils eclypsing builtings
+        if node.module == "utils":
+            if {n.name for n in node.names} <= self._skip_imports:
+                return None
+
+        for n in node.names:
+            n.name = self.names_map.get(n.name, n.name)
+        return node
+
+    def visit_Name(self, node: ast.Name) -> ast.AST:
+        if node.id in self.names_map:
+            node.id = self.names_map[node.id]
+        return node
+
+    def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
+        if node.attr in self.names_map:
+            node.attr = self.names_map[node.attr]
+        self.generic_visit(node)
+        return node
+
+
+class FixSetAutocommit(ast.NodeTransformer):
+    def visit_Call(self, node: ast.Call) -> ast.AST:
+        new_node = self._fix_autocommit(node)
+        if new_node:
+            return new_node
+
+        self.generic_visit(node)
+        return node
+
+    def _fix_autocommit(self, node: ast.Call) -> ast.AST | None:
+        if not isinstance(node.func, ast.Attribute):
+            return None
+        if node.func.attr != "set_autocommit":
+            return None
+        obj = node.func.value
+        arg = node.args[0]
+        new_node = ast.Assign(
+            targets=[ast.Attribute(value=obj, attr="autocommit")],
+            value=arg,
+        )
+        ast.copy_location(new_node, node)
+        return new_node
+
+
+class BlanksInserter(ast.NodeTransformer):
+    """
+    Restore the missing spaces in the source (or something similar)
+    """
+
+    def generic_visit(self, node: ast.AST) -> ast.AST:
+        if isinstance(getattr(node, "body", None), list):
+            node.body = self._inject_blanks(node.body)
+        super().generic_visit(node)
+        return node
+
+    def _inject_blanks(self, body: list[ast.Node]) -> list[ast.AST]:
+        if not body:
+            return body
+
+        new_body = []
+        before = body[0]
+        new_body.append(before)
+        for i in range(1, len(body)):
+            after = body[i]
+            nblanks = after.lineno - before.end_lineno - 1
+            if nblanks > 0:
+                # Inserting one blank is enough.
+                blank = ast.Comment(
+                    value="",
+                    inline=False,
+                    lineno=before.end_lineno + 1,
+                    end_lineno=before.end_lineno + 1,
+                    col_offset=0,
+                    end_col_offset=0,
+                )
+                new_body.append(blank)
+            new_body.append(after)
+            before = after
+
+        return new_body
+
+
+def parse_cmdline() -> Namespace:
+    parser = ArgumentParser(description=__doc__)
+    parser.add_argument("filename", metavar="FILE", help="the file to process")
+    parser.add_argument(
+        "output", metavar="OUTPUT", nargs="?", help="file where to write (or stdout)"
+    )
+    opt = parser.parse_args()
+
+    return opt
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/tools/convert_async_to_sync.sh b/tools/convert_async_to_sync.sh
new file mode 100755 (executable)
index 0000000..93bbc65
--- /dev/null
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+# Convert all the auto-generated sync files from their async counterparts.
+
+set -euo pipefail
+
+dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "${dir}/.."
+
+python "${dir}/async_to_sync.py" tests/test_cursor_async.py > tests/test_cursor.py
+black -q tests/test_cursor.py