]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-119180: Avoid going through AST and eval() when possible in annotationlib (#124337)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Wed, 25 Sep 2024 21:14:03 +0000 (14:14 -0700)
committerGitHub <noreply@github.com>
Wed, 25 Sep 2024 21:14:03 +0000 (21:14 +0000)
Often, ForwardRefs represent a single simple name. In that case, we
can avoid going through the overhead of creating AST nodes and code
objects and calling eval(): we can simply look up the name directly
in the relevant namespaces.

Co-authored-by: Victor Stinner <vstinner@python.org>
Lib/annotationlib.py
Lib/test/test_annotationlib.py

index 0a67742a2b3081fb6d718fed365ee3c79c3f9556..be3bc275817f50ad66f439728c3d53d0c576e40d 100644 (file)
@@ -1,8 +1,10 @@
 """Helpers for introspecting and wrapping annotations."""
 
 import ast
+import builtins
 import enum
 import functools
+import keyword
 import sys
 import types
 
@@ -154,8 +156,19 @@ class ForwardRef:
                     globals[param_name] = param
                     locals.pop(param_name, None)
 
-        code = self.__forward_code__
-        value = eval(code, globals=globals, locals=locals)
+        arg = self.__forward_arg__
+        if arg.isidentifier() and not keyword.iskeyword(arg):
+            if arg in locals:
+                value = locals[arg]
+            elif arg in globals:
+                value = globals[arg]
+            elif hasattr(builtins, arg):
+                return getattr(builtins, arg)
+            else:
+                raise NameError(arg)
+        else:
+            code = self.__forward_code__
+            value = eval(code, globals=globals, locals=locals)
         self.__forward_evaluated__ = True
         self.__forward_value__ = value
         return value
@@ -254,7 +267,9 @@ class _Stringifier:
     __slots__ = _SLOTS
 
     def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
-        assert isinstance(node, ast.AST)
+        # Either an AST node or a simple str (for the common case where a ForwardRef
+        # represent a single name).
+        assert isinstance(node, (ast.AST, str))
         self.__arg__ = None
         self.__forward_evaluated__ = False
         self.__forward_value__ = None
@@ -267,18 +282,26 @@ class _Stringifier:
         self.__cell__ = cell
         self.__owner__ = owner
 
-    def __convert(self, other):
+    def __convert_to_ast(self, other):
         if isinstance(other, _Stringifier):
+            if isinstance(other.__ast_node__, str):
+                return ast.Name(id=other.__ast_node__)
             return other.__ast_node__
         elif isinstance(other, slice):
             return ast.Slice(
-                lower=self.__convert(other.start) if other.start is not None else None,
-                upper=self.__convert(other.stop) if other.stop is not None else None,
-                step=self.__convert(other.step) if other.step is not None else None,
+                lower=self.__convert_to_ast(other.start) if other.start is not None else None,
+                upper=self.__convert_to_ast(other.stop) if other.stop is not None else None,
+                step=self.__convert_to_ast(other.step) if other.step is not None else None,
             )
         else:
             return ast.Constant(value=other)
 
+    def __get_ast(self):
+        node = self.__ast_node__
+        if isinstance(node, str):
+            return ast.Name(id=node)
+        return node
+
     def __make_new(self, node):
         return _Stringifier(
             node, self.__globals__, self.__owner__, self.__forward_is_class__
@@ -292,38 +315,37 @@ class _Stringifier:
     def __getitem__(self, other):
         # Special case, to avoid stringifying references to class-scoped variables
         # as '__classdict__["x"]'.
-        if (
-            isinstance(self.__ast_node__, ast.Name)
-            and self.__ast_node__.id == "__classdict__"
-        ):
+        if self.__ast_node__ == "__classdict__":
             raise KeyError
         if isinstance(other, tuple):
-            elts = [self.__convert(elt) for elt in other]
+            elts = [self.__convert_to_ast(elt) for elt in other]
             other = ast.Tuple(elts)
         else:
-            other = self.__convert(other)
+            other = self.__convert_to_ast(other)
         assert isinstance(other, ast.AST), repr(other)
-        return self.__make_new(ast.Subscript(self.__ast_node__, other))
+        return self.__make_new(ast.Subscript(self.__get_ast(), other))
 
     def __getattr__(self, attr):
-        return self.__make_new(ast.Attribute(self.__ast_node__, attr))
+        return self.__make_new(ast.Attribute(self.__get_ast(), attr))
 
     def __call__(self, *args, **kwargs):
         return self.__make_new(
             ast.Call(
-                self.__ast_node__,
-                [self.__convert(arg) for arg in args],
+                self.__get_ast(),
+                [self.__convert_to_ast(arg) for arg in args],
                 [
-                    ast.keyword(key, self.__convert(value))
+                    ast.keyword(key, self.__convert_to_ast(value))
                     for key, value in kwargs.items()
                 ],
             )
         )
 
     def __iter__(self):
-        yield self.__make_new(ast.Starred(self.__ast_node__))
+        yield self.__make_new(ast.Starred(self.__get_ast()))
 
     def __repr__(self):
+        if isinstance(self.__ast_node__, str):
+            return self.__ast_node__
         return ast.unparse(self.__ast_node__)
 
     def __format__(self, format_spec):
@@ -332,7 +354,7 @@ class _Stringifier:
     def _make_binop(op: ast.AST):
         def binop(self, other):
             return self.__make_new(
-                ast.BinOp(self.__ast_node__, op, self.__convert(other))
+                ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
             )
 
         return binop
@@ -356,7 +378,7 @@ class _Stringifier:
     def _make_rbinop(op: ast.AST):
         def rbinop(self, other):
             return self.__make_new(
-                ast.BinOp(self.__convert(other), op, self.__ast_node__)
+                ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
             )
 
         return rbinop
@@ -381,9 +403,9 @@ class _Stringifier:
         def compare(self, other):
             return self.__make_new(
                 ast.Compare(
-                    left=self.__ast_node__,
+                    left=self.__get_ast(),
                     ops=[op],
-                    comparators=[self.__convert(other)],
+                    comparators=[self.__convert_to_ast(other)],
                 )
             )
 
@@ -400,7 +422,7 @@ class _Stringifier:
 
     def _make_unary_op(op):
         def unary_op(self):
-            return self.__make_new(ast.UnaryOp(op, self.__ast_node__))
+            return self.__make_new(ast.UnaryOp(op, self.__get_ast()))
 
         return unary_op
 
@@ -422,7 +444,7 @@ class _StringifierDict(dict):
 
     def __missing__(self, key):
         fwdref = _Stringifier(
-            ast.Name(id=key),
+            key,
             globals=self.globals,
             owner=self.owner,
             is_class=self.is_class,
@@ -480,7 +502,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
                     name = freevars[i]
                 else:
                     name = "__cell__"
-                fwdref = _Stringifier(ast.Name(id=name))
+                fwdref = _Stringifier(name)
                 new_closure.append(types.CellType(fwdref))
             closure = tuple(new_closure)
         else:
@@ -532,7 +554,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
                     else:
                         name = "__cell__"
                     fwdref = _Stringifier(
-                        ast.Name(id=name),
+                        name,
                         cell=cell,
                         owner=owner,
                         globals=annotate.__globals__,
@@ -555,6 +577,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
         result = func(Format.VALUE)
         for obj in globals.stringifiers:
             obj.__class__ = ForwardRef
+            if isinstance(obj.__ast_node__, str):
+                obj.__arg__ = obj.__ast_node__
+                obj.__ast_node__ = None
         return result
     elif format == Format.VALUE:
         # Should be impossible because __annotate__ functions must not raise
index dd8ceb55a411fbf87936512a45f9339e59188d1f..cc051ef3b9365817f651bcf90c352749a89c6257 100644 (file)
@@ -1,6 +1,7 @@
 """Tests for the annotations module."""
 
 import annotationlib
+import builtins
 import collections
 import functools
 import itertools
@@ -280,7 +281,14 @@ class TestForwardRefClass(unittest.TestCase):
 
     def test_fwdref_with_module(self):
         self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format)
-        self.assertIs(ForwardRef("Counter", module="collections").evaluate(), collections.Counter)
+        self.assertIs(
+            ForwardRef("Counter", module="collections").evaluate(),
+            collections.Counter
+        )
+        self.assertEqual(
+            ForwardRef("Counter[int]", module="collections").evaluate(),
+            collections.Counter[int],
+        )
 
         with self.assertRaises(NameError):
             # If globals are passed explicitly, we don't look at the module dict
@@ -305,6 +313,33 @@ class TestForwardRefClass(unittest.TestCase):
         self.assertIs(fr.evaluate(globals={"hello": str}), str)
         self.assertIs(fr.evaluate(), str)
 
+    def test_fwdref_with_owner(self):
+        self.assertEqual(
+            ForwardRef("Counter[int]", owner=collections).evaluate(),
+            collections.Counter[int],
+        )
+
+    def test_name_lookup_without_eval(self):
+        # test the codepath where we look up simple names directly in the
+        # namespaces without going through eval()
+        self.assertIs(ForwardRef("int").evaluate(), int)
+        self.assertIs(ForwardRef("int").evaluate(locals={"int": str}), str)
+        self.assertIs(ForwardRef("int").evaluate(locals={"int": float}, globals={"int": str}), float)
+        self.assertIs(ForwardRef("int").evaluate(globals={"int": str}), str)
+        with support.swap_attr(builtins, "int", dict):
+            self.assertIs(ForwardRef("int").evaluate(), dict)
+
+        with self.assertRaises(NameError):
+            ForwardRef("doesntexist").evaluate()
+
+    def test_fwdref_invalid_syntax(self):
+        fr = ForwardRef("if")
+        with self.assertRaises(SyntaxError):
+            fr.evaluate()
+        fr = ForwardRef("1+")
+        with self.assertRaises(SyntaxError):
+            fr.evaluate()
+
 
 class TestGetAnnotations(unittest.TestCase):
     def test_builtin_type(self):