from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter
from importlib.metadata import version
-import ast_comments as ast
+import ast_comments as ast # type: ignore
# The version of Python officially used for the conversion.
# Output may differ in other versions.
return rv
-class AsyncToSync(ast.NodeTransformer):
+class AsyncToSync(ast.NodeTransformer): # type: ignore
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
new_node = ast.FunctionDef(**node.__dict__)
ast.copy_location(new_node, node)
# Assume that the test guards an async object becoming sync and remove
# the async side, because it will likely contain `await` constructs
# illegal into a sync function.
- if self._is_async_call(node.test):
- for child in node.orelse:
- self.visit(child)
- return node.orelse
-
- # Manage `if True: # ASYNC`
- # drop the unneeded branch
- if (stmts := self._async_test_statements(node)) is not None:
- for child in stmts:
- self.visit(child)
- return stmts
+ value: bool
+ comment: str
+ match node:
+ # manage `is_async()`
+ case ast.If(test=ast.Call(func=ast.Name(id="is_async"))):
+ for child in node.orelse:
+ self.visit(child)
+ return node.orelse
+
+ # Manage `if True|False: # ASYNC`
+ # drop the unneeded branch
+ case ast.If(
+ test=ast.Constant(value=bool(value)),
+ body=[ast.Comment(value=comment), *_],
+ ) if comment.startswith("# ASYNC"):
+ stmts: list[ast.AST]
+ # body[0] is the ASYNC comment, drop it
+ stmts = node.orelse if value else node.body[1:]
+ for child in stmts:
+ self.visit(child)
+ return stmts
self.generic_visit(node)
return node
- def _is_async_call(self, test: ast.AST) -> bool:
- if not isinstance(test, ast.Call):
- return False
- if not isinstance(test.func, ast.Name):
- return False
- if test.func.id != "is_async":
- return False
- return True
-
- def _async_test_statements(self, node: ast.If) -> list[ast.AST] | None:
- if not (
- isinstance(node.test, ast.Constant) and isinstance(node.test.value, bool)
- ):
- return None
-
- if not (node.body and isinstance(node.body[0], ast.Comment)):
- return None
-
- comment = node.body[0].value
-
- if not comment.startswith("# ASYNC"):
- return None
- stmts: list[ast.AST]
- if node.test.value:
- stmts = node.orelse
- else:
- stmts = node.body[1:] # skip the ASYNC comment
- return stmts
-
-
-class RenameAsyncToSync(ast.NodeTransformer):
+class RenameAsyncToSync(ast.NodeTransformer): # type: ignore
names_map = {
"ACT": "CT",
"ACondition": "Condition",
for arg in node.args.args:
arg.arg = self.names_map.get(arg.arg, arg.arg)
for arg in node.args.args:
- ann = arg.annotation
- if not ann:
- continue
- if isinstance(ann, ast.Subscript):
- # Remove the [] from the type
- ann = ann.value
- if isinstance(ann, ast.Attribute):
- ann.attr = self.names_map.get(ann.attr, ann.attr)
+ attr: str
+ match arg.annotation:
+ case ast.arg(annotation=ast.Attribute(attr=attr)):
+ arg.annotation.attr = self.names_map.get(attr, attr)
+ case ast.arg(annotation=ast.Subscript(value=ast.Attribute(attr=attr))):
+ arg.annotation.value.attr = self.names_map.get(attr, attr)
self.generic_visit(node)
return node
return node
def _fix_docstring(self, body: list[ast.AST]) -> None:
- if (
- body
- and isinstance(body[0], ast.Expr)
- and isinstance(body[0].value, ast.Constant)
- and isinstance(body[0].value.value, str)
- ):
- body[0].value.value = body[0].value.value.replace("Async", "")
- body[0].value.value = body[0].value.value.replace("(async", "(sync")
+ doc: str
+ match body and body[0]:
+ case ast.Expr(value=ast.Constant(value=str(doc))):
+ doc = doc.replace("Async", "")
+ doc = doc.replace("(async", "(sync")
+ body[0].value.value = doc
def _fix_decorator(self, decorator_list: list[ast.AST]) -> None:
for dec in decorator_list:
elts[i] = self._convert_if_literal_string(elt)
def _convert_if_literal_string(self, node: ast.AST) -> ast.AST:
+ value: str
match node:
- case ast.Constant(value=str()):
- node.value = self._visit_type_string(node.value)
+ case ast.Constant(value=str(value)):
+ node.value = self._visit_type_string(value)
return node
# Handle :
# class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
# the base cannot be a token, even with __future__ annotation.
+ elts: list[ast.AST]
for base in node.bases:
- if not isinstance(base, ast.Subscript):
- continue
-
- if isinstance(base.slice, ast.Tuple):
- elts = base.slice.elts
- for i, elt in enumerate(elts):
- elts[i] = self._convert_if_literal_string(elt)
- else:
- base.slice = self._convert_if_literal_string(base.slice)
+ match base:
+ case ast.Subscript(slice=ast.Tuple(elts=elts)):
+ for i, elt in enumerate(elts):
+ elts[i] = self._convert_if_literal_string(elt)
+ case ast.Subscript(slice=ast.Constant()):
+ base.slice = self._convert_if_literal_string(base.slice)
return node
return node
def _manage_async_generator(self, node: ast.Subscript) -> ast.AST | None:
- if not (isinstance(node.value, ast.Name) and node.value.id == "AsyncGenerator"):
- return None
-
- if not (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2):
- return None
-
- node.slice.elts.insert(1, deepcopy(node.slice.elts[1]))
- self.generic_visit(node)
- return node
+ match node:
+ case ast.Subscript(
+ value=ast.Name(id="AsyncGenerator"), slice=ast.Tuple(elts=[_, _])
+ ):
+ node.slice.elts.insert(1, deepcopy(node.slice.elts[1]))
+ self.generic_visit(node)
+ return node
+ return None
-class BlanksInserter(ast.NodeTransformer):
+class BlanksInserter(ast.NodeTransformer): # type: ignore
"""
Restore the missing spaces in the source (or something similar)
"""
return "\n".join(lines)
-class Unparser(ast._Unparser):
+class Unparser(ast._Unparser): # type: ignore
"""
Try to emit long strings as multiline.