]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(async-to-sync): use pattern matching for more compact tests
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Apr 2024 00:47:05 +0000 (02:47 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Apr 2024 17:38:54 +0000 (19:38 +0200)
tools/async_to_sync.py

index cffda9b2cb03957eabc3afcd7fa91d4b80002e00..16b67e4da077a5869e20cca6911d53b63fdc83e9 100755 (executable)
@@ -22,7 +22,7 @@ from pathlib import Path
 from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter
 from importlib.metadata import version
 
-import ast_comments as ast
+import ast_comments as ast  # type: ignore
 
 # The version of Python officially used for the conversion.
 # Output may differ in other versions.
@@ -208,7 +208,7 @@ def tree_to_str(tree: ast.AST, filepath: Path) -> str:
     return rv
 
 
-class AsyncToSync(ast.NodeTransformer):
+class AsyncToSync(ast.NodeTransformer):  # type: ignore
     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
         new_node = ast.FunctionDef(**node.__dict__)
         ast.copy_location(new_node, node)
@@ -238,53 +238,33 @@ class AsyncToSync(ast.NodeTransformer):
         # Assume that the test guards an async object becoming sync and remove
         # the async side, because it will likely contain `await` constructs
         # illegal into a sync function.
-        if self._is_async_call(node.test):
-            for child in node.orelse:
-                self.visit(child)
-            return node.orelse
-
-        # Manage `if True:  # ASYNC`
-        # drop the unneeded branch
-        if (stmts := self._async_test_statements(node)) is not None:
-            for child in stmts:
-                self.visit(child)
-            return stmts
+        value: bool
+        comment: str
+        match node:
+            # manage `is_async()`
+            case ast.If(test=ast.Call(func=ast.Name(id="is_async"))):
+                for child in node.orelse:
+                    self.visit(child)
+                return node.orelse
+
+            # Manage `if True|False:  # ASYNC`
+            # drop the unneeded branch
+            case ast.If(
+                test=ast.Constant(value=bool(value)),
+                body=[ast.Comment(value=comment), *_],
+            ) if comment.startswith("# ASYNC"):
+                stmts: list[ast.AST]
+                # body[0] is the ASYNC comment, drop it
+                stmts = node.orelse if value else node.body[1:]
+                for child in stmts:
+                    self.visit(child)
+                return stmts
 
         self.generic_visit(node)
         return node
 
-    def _is_async_call(self, test: ast.AST) -> bool:
-        if not isinstance(test, ast.Call):
-            return False
-        if not isinstance(test.func, ast.Name):
-            return False
-        if test.func.id != "is_async":
-            return False
-        return True
-
-    def _async_test_statements(self, node: ast.If) -> list[ast.AST] | None:
-        if not (
-            isinstance(node.test, ast.Constant) and isinstance(node.test.value, bool)
-        ):
-            return None
-
-        if not (node.body and isinstance(node.body[0], ast.Comment)):
-            return None
-
-        comment = node.body[0].value
-
-        if not comment.startswith("# ASYNC"):
-            return None
 
-        stmts: list[ast.AST]
-        if node.test.value:
-            stmts = node.orelse
-        else:
-            stmts = node.body[1:]  # skip the ASYNC comment
-        return stmts
-
-
-class RenameAsyncToSync(ast.NodeTransformer):
+class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
     names_map = {
         "ACT": "CT",
         "ACondition": "Condition",
@@ -359,14 +339,12 @@ class RenameAsyncToSync(ast.NodeTransformer):
         for arg in node.args.args:
             arg.arg = self.names_map.get(arg.arg, arg.arg)
         for arg in node.args.args:
-            ann = arg.annotation
-            if not ann:
-                continue
-            if isinstance(ann, ast.Subscript):
-                # Remove the [] from the type
-                ann = ann.value
-            if isinstance(ann, ast.Attribute):
-                ann.attr = self.names_map.get(ann.attr, ann.attr)
+            attr: str
+            match arg.annotation:
+                case ast.arg(annotation=ast.Attribute(attr=attr)):
+                    arg.annotation.attr = self.names_map.get(attr, attr)
+                case ast.arg(annotation=ast.Subscript(value=ast.Attribute(attr=attr))):
+                    arg.annotation.value.attr = self.names_map.get(attr, attr)
 
         self.generic_visit(node)
         return node
@@ -379,14 +357,12 @@ class RenameAsyncToSync(ast.NodeTransformer):
         return node
 
     def _fix_docstring(self, body: list[ast.AST]) -> None:
-        if (
-            body
-            and isinstance(body[0], ast.Expr)
-            and isinstance(body[0].value, ast.Constant)
-            and isinstance(body[0].value.value, str)
-        ):
-            body[0].value.value = body[0].value.value.replace("Async", "")
-            body[0].value.value = body[0].value.value.replace("(async", "(sync")
+        doc: str
+        match body and body[0]:
+            case ast.Expr(value=ast.Constant(value=str(doc))):
+                doc = doc.replace("Async", "")
+                doc = doc.replace("(async", "(sync")
+                body[0].value.value = doc
 
     def _fix_decorator(self, decorator_list: list[ast.AST]) -> None:
         for dec in decorator_list:
@@ -400,9 +376,10 @@ class RenameAsyncToSync(ast.NodeTransformer):
                         elts[i] = self._convert_if_literal_string(elt)
 
     def _convert_if_literal_string(self, node: ast.AST) -> ast.AST:
+        value: str
         match node:
-            case ast.Constant(value=str()):
-                node.value = self._visit_type_string(node.value)
+            case ast.Constant(value=str(value)):
+                node.value = self._visit_type_string(value)
 
         return node
 
@@ -424,16 +401,14 @@ class RenameAsyncToSync(ast.NodeTransformer):
         # Handle :
         #   class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         # the base cannot be a token, even with __future__ annotation.
+        elts: list[ast.AST]
         for base in node.bases:
-            if not isinstance(base, ast.Subscript):
-                continue
-
-            if isinstance(base.slice, ast.Tuple):
-                elts = base.slice.elts
-                for i, elt in enumerate(elts):
-                    elts[i] = self._convert_if_literal_string(elt)
-            else:
-                base.slice = self._convert_if_literal_string(base.slice)
+            match base:
+                case ast.Subscript(slice=ast.Tuple(elts=elts)):
+                    for i, elt in enumerate(elts):
+                        elts[i] = self._convert_if_literal_string(elt)
+                case ast.Subscript(slice=ast.Constant()):
+                    base.slice = self._convert_if_literal_string(base.slice)
 
         return node
 
@@ -471,18 +446,17 @@ class RenameAsyncToSync(ast.NodeTransformer):
         return node
 
     def _manage_async_generator(self, node: ast.Subscript) -> ast.AST | None:
-        if not (isinstance(node.value, ast.Name) and node.value.id == "AsyncGenerator"):
-            return None
-
-        if not (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2):
-            return None
-
-        node.slice.elts.insert(1, deepcopy(node.slice.elts[1]))
-        self.generic_visit(node)
-        return node
+        match node:
+            case ast.Subscript(
+                value=ast.Name(id="AsyncGenerator"), slice=ast.Tuple(elts=[_, _])
+            ):
+                node.slice.elts.insert(1, deepcopy(node.slice.elts[1]))
+                self.generic_visit(node)
+                return node
+        return None
 
 
-class BlanksInserter(ast.NodeTransformer):
+class BlanksInserter(ast.NodeTransformer):  # type: ignore
     """
     Restore the missing spaces in the source (or something similar)
     """
@@ -561,7 +535,7 @@ def _fix_comment_on_decorators(source: str) -> str:
     return "\n".join(lines)
 
 
-class Unparser(ast._Unparser):
+class Unparser(ast._Unparser):  # type: ignore
     """
     Try to emit long strings as multiline.