From: Daniele Varrazzo Date: Fri, 12 Apr 2024 00:47:05 +0000 (+0200) Subject: refactor(async-to-sync): use pattern matching for more compact tests X-Git-Tag: 3.2.0~46^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=de9d188fc9bf0097721ee6c2e478789d930ba46a;p=thirdparty%2Fpsycopg.git refactor(async-to-sync): use pattern matching for more compact tests --- diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index cffda9b2c..16b67e4da 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -22,7 +22,7 @@ from pathlib import Path from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from importlib.metadata import version -import ast_comments as ast +import ast_comments as ast # type: ignore # The version of Python officially used for the conversion. # Output may differ in other versions. @@ -208,7 +208,7 @@ def tree_to_str(tree: ast.AST, filepath: Path) -> str: return rv -class AsyncToSync(ast.NodeTransformer): +class AsyncToSync(ast.NodeTransformer): # type: ignore def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: new_node = ast.FunctionDef(**node.__dict__) ast.copy_location(new_node, node) @@ -238,53 +238,33 @@ class AsyncToSync(ast.NodeTransformer): # Assume that the test guards an async object becoming sync and remove # the async side, because it will likely contain `await` constructs # illegal into a sync function. - if self._is_async_call(node.test): - for child in node.orelse: - self.visit(child) - return node.orelse - - # Manage `if True: # ASYNC` - # drop the unneeded branch - if (stmts := self._async_test_statements(node)) is not None: - for child in stmts: - self.visit(child) - return stmts + value: bool + comment: str + match node: + # manage `is_async()` + case ast.If(test=ast.Call(func=ast.Name(id="is_async"))): + for child in node.orelse: + self.visit(child) + return node.orelse + + # Manage `if True|False: # ASYNC` + # drop the unneeded branch + case ast.If( + test=ast.Constant(value=bool(value)), + body=[ast.Comment(value=comment), *_], + ) if comment.startswith("# ASYNC"): + stmts: list[ast.AST] + # body[0] is the ASYNC comment, drop it + stmts = node.orelse if value else node.body[1:] + for child in stmts: + self.visit(child) + return stmts self.generic_visit(node) return node - def _is_async_call(self, test: ast.AST) -> bool: - if not isinstance(test, ast.Call): - return False - if not isinstance(test.func, ast.Name): - return False - if test.func.id != "is_async": - return False - return True - - def _async_test_statements(self, node: ast.If) -> list[ast.AST] | None: - if not ( - isinstance(node.test, ast.Constant) and isinstance(node.test.value, bool) - ): - return None - - if not (node.body and isinstance(node.body[0], ast.Comment)): - return None - - comment = node.body[0].value - - if not comment.startswith("# ASYNC"): - return None - stmts: list[ast.AST] - if node.test.value: - stmts = node.orelse - else: - stmts = node.body[1:] # skip the ASYNC comment - return stmts - - -class RenameAsyncToSync(ast.NodeTransformer): +class RenameAsyncToSync(ast.NodeTransformer): # type: ignore names_map = { "ACT": "CT", "ACondition": "Condition", @@ -359,14 +339,12 @@ class RenameAsyncToSync(ast.NodeTransformer): for arg in node.args.args: arg.arg = self.names_map.get(arg.arg, arg.arg) for arg in node.args.args: - ann = arg.annotation - if not ann: - continue - if isinstance(ann, ast.Subscript): - # Remove the [] from the type - ann = ann.value - if isinstance(ann, ast.Attribute): - ann.attr = self.names_map.get(ann.attr, ann.attr) + attr: str + match arg.annotation: + case ast.arg(annotation=ast.Attribute(attr=attr)): + arg.annotation.attr = self.names_map.get(attr, attr) + case ast.arg(annotation=ast.Subscript(value=ast.Attribute(attr=attr))): + arg.annotation.value.attr = self.names_map.get(attr, attr) self.generic_visit(node) return node @@ -379,14 +357,12 @@ class RenameAsyncToSync(ast.NodeTransformer): return node def _fix_docstring(self, body: list[ast.AST]) -> None: - if ( - body - and isinstance(body[0], ast.Expr) - and isinstance(body[0].value, ast.Constant) - and isinstance(body[0].value.value, str) - ): - body[0].value.value = body[0].value.value.replace("Async", "") - body[0].value.value = body[0].value.value.replace("(async", "(sync") + doc: str + match body and body[0]: + case ast.Expr(value=ast.Constant(value=str(doc))): + doc = doc.replace("Async", "") + doc = doc.replace("(async", "(sync") + body[0].value.value = doc def _fix_decorator(self, decorator_list: list[ast.AST]) -> None: for dec in decorator_list: @@ -400,9 +376,10 @@ class RenameAsyncToSync(ast.NodeTransformer): elts[i] = self._convert_if_literal_string(elt) def _convert_if_literal_string(self, node: ast.AST) -> ast.AST: + value: str match node: - case ast.Constant(value=str()): - node.value = self._visit_type_string(node.value) + case ast.Constant(value=str(value)): + node.value = self._visit_type_string(value) return node @@ -424,16 +401,14 @@ class RenameAsyncToSync(ast.NodeTransformer): # Handle : # class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): # the base cannot be a token, even with __future__ annotation. + elts: list[ast.AST] for base in node.bases: - if not isinstance(base, ast.Subscript): - continue - - if isinstance(base.slice, ast.Tuple): - elts = base.slice.elts - for i, elt in enumerate(elts): - elts[i] = self._convert_if_literal_string(elt) - else: - base.slice = self._convert_if_literal_string(base.slice) + match base: + case ast.Subscript(slice=ast.Tuple(elts=elts)): + for i, elt in enumerate(elts): + elts[i] = self._convert_if_literal_string(elt) + case ast.Subscript(slice=ast.Constant()): + base.slice = self._convert_if_literal_string(base.slice) return node @@ -471,18 +446,17 @@ class RenameAsyncToSync(ast.NodeTransformer): return node def _manage_async_generator(self, node: ast.Subscript) -> ast.AST | None: - if not (isinstance(node.value, ast.Name) and node.value.id == "AsyncGenerator"): - return None - - if not (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2): - return None - - node.slice.elts.insert(1, deepcopy(node.slice.elts[1])) - self.generic_visit(node) - return node + match node: + case ast.Subscript( + value=ast.Name(id="AsyncGenerator"), slice=ast.Tuple(elts=[_, _]) + ): + node.slice.elts.insert(1, deepcopy(node.slice.elts[1])) + self.generic_visit(node) + return node + return None -class BlanksInserter(ast.NodeTransformer): +class BlanksInserter(ast.NodeTransformer): # type: ignore """ Restore the missing spaces in the source (or something similar) """ @@ -561,7 +535,7 @@ def _fix_comment_on_decorators(source: str) -> str: return "\n".join(lines) -class Unparser(ast._Unparser): +class Unparser(ast._Unparser): # type: ignore """ Try to emit long strings as multiline.