]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(async-to-sync): simpler pattern to convert string literals
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 11 Apr 2024 23:32:18 +0000 (01:32 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Apr 2024 17:38:54 +0000 (19:38 +0200)
tools/async_to_sync.py

index 4c1ed2d8aa1e51b40f5fc563a25f75d295e08823..cffda9b2cb03957eabc3afcd7fa91d4b80002e00 100755 (executable)
@@ -397,28 +397,9 @@ class RenameAsyncToSync(ast.NodeTransformer):
                 ):
                     elts = dec.keywords[0].value.elts
                     for i, elt in enumerate(elts):
-                        elts[i] = self._visit_str(elt)
+                        elts[i] = self._convert_if_literal_string(elt)
 
-    def visit_Call(self, node: ast.Call) -> ast.AST:
-        if isinstance(node.func, ast.Name) and node.func.id == "TypeVar":
-            node = self._visit_Call_TypeVar(node)
-
-        self.generic_visit(node)
-        return node
-
-    def _visit_Call_TypeVar(self, node: ast.Call) -> ast.AST:
-        for kw in node.keywords:
-            if kw.arg != "bound":
-                continue
-            if not isinstance(kw.value, ast.Constant):
-                continue
-            if not isinstance(kw.value.value, str):
-                continue
-            kw.value.value = self._visit_type_string(kw.value.value)
-
-        return node
-
-    def _visit_str(self, node: ast.AST) -> ast.AST:
+    def _convert_if_literal_string(self, node: ast.AST) -> ast.AST:
         match node:
             case ast.Constant(value=str()):
                 node.value = self._visit_type_string(node.value)
@@ -447,17 +428,12 @@ class RenameAsyncToSync(ast.NodeTransformer):
             if not isinstance(base, ast.Subscript):
                 continue
 
-            if isinstance(base.slice, ast.Constant):
-                if not isinstance(base.slice.value, str):
-                    continue
-                base.slice.value = self._visit_type_string(base.slice.value)
-            elif isinstance(base.slice, ast.Tuple):
-                for elt in base.slice.elts:
-                    if not (
-                        isinstance(elt, ast.Constant) and isinstance(elt.value, str)
-                    ):
-                        continue
-                    elt.value = self._visit_type_string(elt.value)
+            if isinstance(base.slice, ast.Tuple):
+                elts = base.slice.elts
+                for i, elt in enumerate(elts):
+                    elts[i] = self._convert_if_literal_string(elt)
+            else:
+                base.slice = self._convert_if_literal_string(base.slice)
 
         return node