):
elts = dec.keywords[0].value.elts
for i, elt in enumerate(elts):
- elts[i] = self._visit_str(elt)
+ elts[i] = self._convert_if_literal_string(elt)
- def visit_Call(self, node: ast.Call) -> ast.AST:
- if isinstance(node.func, ast.Name) and node.func.id == "TypeVar":
- node = self._visit_Call_TypeVar(node)
-
- self.generic_visit(node)
- return node
-
- def _visit_Call_TypeVar(self, node: ast.Call) -> ast.AST:
- for kw in node.keywords:
- if kw.arg != "bound":
- continue
- if not isinstance(kw.value, ast.Constant):
- continue
- if not isinstance(kw.value.value, str):
- continue
- kw.value.value = self._visit_type_string(kw.value.value)
-
- return node
-
- def _visit_str(self, node: ast.AST) -> ast.AST:
+ def _convert_if_literal_string(self, node: ast.AST) -> ast.AST:
match node:
case ast.Constant(value=str()):
node.value = self._visit_type_string(node.value)
if not isinstance(base, ast.Subscript):
continue
- if isinstance(base.slice, ast.Constant):
- if not isinstance(base.slice.value, str):
- continue
- base.slice.value = self._visit_type_string(base.slice.value)
- elif isinstance(base.slice, ast.Tuple):
- for elt in base.slice.elts:
- if not (
- isinstance(elt, ast.Constant) and isinstance(elt.value, str)
- ):
- continue
- elt.value = self._visit_type_string(elt.value)
+ if isinstance(base.slice, ast.Tuple):
+ elts = base.slice.elts
+ for i, elt in enumerate(elts):
+ elts[i] = self._convert_if_literal_string(elt)
+ else:
+ base.slice = self._convert_if_literal_string(base.slice)
return node