]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-125618: Make FORWARDREF format succeed more often (#132818)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Sun, 4 May 2025 22:21:56 +0000 (15:21 -0700)
committerGitHub <noreply@github.com>
Sun, 4 May 2025 22:21:56 +0000 (15:21 -0700)
Fixes #125618.

Doc/library/annotationlib.rst
Lib/annotationlib.py
Lib/test/test_annotationlib.py
Misc/NEWS.d/next/Library/2025-04-22-16-35-37.gh-issue-125618.PEocn3.rst [new file with mode: 0644]

index b9932a9e4cca1ffcb9935bde0e1455d11bf79b66..ff9578b6088f2827adf32122b9911f97d81e4db4 100644 (file)
@@ -132,7 +132,7 @@ Classes
 
       Values are real annotation values (as per :attr:`Format.VALUE` format)
       for defined values, and :class:`ForwardRef` proxies for undefined
-      values. Real objects may contain references to, :class:`ForwardRef`
+      values. Real objects may contain references to :class:`ForwardRef`
       proxy objects.
 
    .. attribute:: STRING
@@ -172,14 +172,21 @@ Classes
       :class:`~ForwardRef`. The string may not be exactly equivalent
       to the original source.
 
-   .. method:: evaluate(*, owner=None, globals=None, locals=None, type_params=None)
+   .. method:: evaluate(*, owner=None, globals=None, locals=None, type_params=None, format=Format.VALUE)
 
       Evaluate the forward reference, returning its value.
 
-      This may throw an exception, such as :exc:`NameError`, if the forward
+      If the *format* argument is :attr:`~Format.VALUE` (the default),
+      this method may throw an exception, such as :exc:`NameError`, if the forward
       reference refers to a name that cannot be resolved. The arguments to this
       method can be used to provide bindings for names that would otherwise
-      be undefined.
+      be undefined. If the *format* argument is :attr:`~Format.FORWARDREF`,
+      the method will never throw an exception, but may return a :class:`~ForwardRef`
+      instance. For example, if the forward reference object contains the code
+      ``list[undefined]``, where ``undefined`` is a name that is not defined,
+      evaluating it with the :attr:`~Format.FORWARDREF` format will return
+      ``list[ForwardRef('undefined')]``. If the *format* argument is
+      :attr:`~Format.STRING`, the method will return :attr:`~ForwardRef.__forward_arg__`.
 
       The *owner* parameter provides the preferred mechanism for passing scope
       information to this method. The owner of a :class:`~ForwardRef` is the
index cd24679f30abee5b0e78fdce346792ccab2aaae4..5ad0893106a72b8ff853c24df7cc4e06cabad96c 100644 (file)
@@ -92,11 +92,28 @@ class ForwardRef:
     def __init_subclass__(cls, /, *args, **kwds):
         raise TypeError("Cannot subclass ForwardRef")
 
-    def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
+    def evaluate(
+        self,
+        *,
+        globals=None,
+        locals=None,
+        type_params=None,
+        owner=None,
+        format=Format.VALUE,
+    ):
         """Evaluate the forward reference and return the value.
 
         If the forward reference cannot be evaluated, raise an exception.
         """
+        match format:
+            case Format.STRING:
+                return self.__forward_arg__
+            case Format.VALUE:
+                is_forwardref_format = False
+            case Format.FORWARDREF:
+                is_forwardref_format = True
+            case _:
+                raise NotImplementedError(format)
         if self.__cell__ is not None:
             try:
                 return self.__cell__.cell_contents
@@ -159,17 +176,36 @@ class ForwardRef:
         arg = self.__forward_arg__
         if arg.isidentifier() and not keyword.iskeyword(arg):
             if arg in locals:
-                value = locals[arg]
+                return locals[arg]
             elif arg in globals:
-                value = globals[arg]
+                return globals[arg]
             elif hasattr(builtins, arg):
                 return getattr(builtins, arg)
+            elif is_forwardref_format:
+                return self
             else:
                 raise NameError(arg)
         else:
             code = self.__forward_code__
-            value = eval(code, globals=globals, locals=locals)
-        return value
+            try:
+                return eval(code, globals=globals, locals=locals)
+            except Exception:
+                if not is_forwardref_format:
+                    raise
+            new_locals = _StringifierDict(
+                {**builtins.__dict__, **locals},
+                globals=globals,
+                owner=owner,
+                is_class=self.__forward_is_class__,
+                format=format,
+            )
+            try:
+                result = eval(code, globals=globals, locals=new_locals)
+            except Exception:
+                return self
+            else:
+                new_locals.transmogrify()
+                return result
 
     def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
         import typing
@@ -546,6 +582,14 @@ class _StringifierDict(dict):
         self.stringifiers.append(fwdref)
         return fwdref
 
+    def transmogrify(self):
+        for obj in self.stringifiers:
+            obj.__class__ = ForwardRef
+            obj.__stringifier_dict__ = None  # not needed for ForwardRef
+            if isinstance(obj.__ast_node__, str):
+                obj.__arg__ = obj.__ast_node__
+                obj.__ast_node__ = None
+
     def create_unique_name(self):
         name = f"__annotationlib_name_{self.next_id}__"
         self.next_id += 1
@@ -595,19 +639,10 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
         # convert each of those into a string to get an approximation of the
         # original source.
         globals = _StringifierDict({}, format=format)
-        if annotate.__closure__:
-            freevars = annotate.__code__.co_freevars
-            new_closure = []
-            for i, cell in enumerate(annotate.__closure__):
-                if i < len(freevars):
-                    name = freevars[i]
-                else:
-                    name = "__cell__"
-                fwdref = _Stringifier(name, stringifier_dict=globals)
-                new_closure.append(types.CellType(fwdref))
-            closure = tuple(new_closure)
-        else:
-            closure = None
+        is_class = isinstance(owner, type)
+        closure = _build_closure(
+            annotate, owner, is_class, globals, allow_evaluation=False
+        )
         func = types.FunctionType(
             annotate.__code__,
             globals,
@@ -649,32 +684,36 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
             is_class=is_class,
             format=format,
         )
-        if annotate.__closure__:
-            freevars = annotate.__code__.co_freevars
-            new_closure = []
-            for i, cell in enumerate(annotate.__closure__):
-                try:
-                    cell.cell_contents
-                except ValueError:
-                    if i < len(freevars):
-                        name = freevars[i]
-                    else:
-                        name = "__cell__"
-                    fwdref = _Stringifier(
-                        name,
-                        cell=cell,
-                        owner=owner,
-                        globals=annotate.__globals__,
-                        is_class=is_class,
-                        stringifier_dict=globals,
-                    )
-                    globals.stringifiers.append(fwdref)
-                    new_closure.append(types.CellType(fwdref))
-                else:
-                    new_closure.append(cell)
-            closure = tuple(new_closure)
+        closure = _build_closure(
+            annotate, owner, is_class, globals, allow_evaluation=True
+        )
+        func = types.FunctionType(
+            annotate.__code__,
+            globals,
+            closure=closure,
+            argdefs=annotate.__defaults__,
+            kwdefaults=annotate.__kwdefaults__,
+        )
+        try:
+            result = func(Format.VALUE_WITH_FAKE_GLOBALS)
+        except Exception:
+            pass
         else:
-            closure = None
+            globals.transmogrify()
+            return result
+
+        # Try again, but do not provide any globals. This allows us to return
+        # a value in certain cases where an exception gets raised during evaluation.
+        globals = _StringifierDict(
+            {},
+            globals=annotate.__globals__,
+            owner=owner,
+            is_class=is_class,
+            format=format,
+        )
+        closure = _build_closure(
+            annotate, owner, is_class, globals, allow_evaluation=False
+        )
         func = types.FunctionType(
             annotate.__code__,
             globals,
@@ -683,13 +722,21 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
             kwdefaults=annotate.__kwdefaults__,
         )
         result = func(Format.VALUE_WITH_FAKE_GLOBALS)
-        for obj in globals.stringifiers:
-            obj.__class__ = ForwardRef
-            obj.__stringifier_dict__ = None  # not needed for ForwardRef
-            if isinstance(obj.__ast_node__, str):
-                obj.__arg__ = obj.__ast_node__
-                obj.__ast_node__ = None
-        return result
+        globals.transmogrify()
+        if _is_evaluate:
+            if isinstance(result, ForwardRef):
+                return result.evaluate(format=Format.FORWARDREF)
+            else:
+                return result
+        else:
+            return {
+                key: (
+                    val.evaluate(format=Format.FORWARDREF)
+                    if isinstance(val, ForwardRef)
+                    else val
+                )
+                for key, val in result.items()
+            }
     elif format == Format.VALUE:
         # Should be impossible because __annotate__ functions must not raise
         # NotImplementedError for this format.
@@ -698,6 +745,39 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
         raise ValueError(f"Invalid format: {format!r}")
 
 
+def _build_closure(annotate, owner, is_class, stringifier_dict, *, allow_evaluation):
+    if not annotate.__closure__:
+        return None
+    freevars = annotate.__code__.co_freevars
+    new_closure = []
+    for i, cell in enumerate(annotate.__closure__):
+        if i < len(freevars):
+            name = freevars[i]
+        else:
+            name = "__cell__"
+        new_cell = None
+        if allow_evaluation:
+            try:
+                cell.cell_contents
+            except ValueError:
+                pass
+            else:
+                new_cell = cell
+        if new_cell is None:
+            fwdref = _Stringifier(
+                name,
+                cell=cell,
+                owner=owner,
+                globals=annotate.__globals__,
+                is_class=is_class,
+                stringifier_dict=stringifier_dict,
+            )
+            stringifier_dict.stringifiers.append(fwdref)
+            new_cell = types.CellType(fwdref)
+        new_closure.append(new_cell)
+    return tuple(new_closure)
+
+
 def _stringify_single(anno):
     if anno is ...:
         return "..."
@@ -809,7 +889,7 @@ def get_annotations(
             # But if we didn't get it, we use __annotations__ instead.
             ann = _get_dunder_annotations(obj)
             if ann is not None:
-                 return annotations_to_string(ann)
+                return annotations_to_string(ann)
         case Format.VALUE_WITH_FAKE_GLOBALS:
             raise ValueError("The VALUE_WITH_FAKE_GLOBALS format is for internal use only")
         case _:
index d9000b6392277e90866da8ca73d2a9715acad1ca..13c6a2a584bc7055d897b21960ed1762e81ed009 100644 (file)
@@ -276,10 +276,10 @@ class TestStringFormat(unittest.TestCase):
     def test_getitem(self):
         def f(x: undef1[str, undef2]):
             pass
-        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        anno = get_annotations(f, format=Format.STRING)
         self.assertEqual(anno, {"x": "undef1[str, undef2]"})
 
-        anno = annotationlib.get_annotations(f, format=Format.FORWARDREF)
+        anno = get_annotations(f, format=Format.FORWARDREF)
         fwdref = anno["x"]
         self.assertIsInstance(fwdref, ForwardRef)
         self.assertEqual(
@@ -289,18 +289,18 @@ class TestStringFormat(unittest.TestCase):
     def test_slice(self):
         def f(x: a[b:c]):
             pass
-        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        anno = get_annotations(f, format=Format.STRING)
         self.assertEqual(anno, {"x": "a[b:c]"})
 
         def f(x: a[b:c, d:e]):
             pass
-        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        anno = get_annotations(f, format=Format.STRING)
         self.assertEqual(anno, {"x": "a[b:c, d:e]"})
 
         obj = slice(1, 1, 1)
         def f(x: obj):
             pass
-        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        anno = get_annotations(f, format=Format.STRING)
         self.assertEqual(anno, {"x": "obj"})
 
     def test_literals(self):
@@ -316,7 +316,7 @@ class TestStringFormat(unittest.TestCase):
         ):
             pass
 
-        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        anno = get_annotations(f, format=Format.STRING)
         self.assertEqual(
             anno,
             {
@@ -335,7 +335,7 @@ class TestStringFormat(unittest.TestCase):
         # Simple case first
         def f(x: a[[int, str], float]):
             pass
-        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        anno = get_annotations(f, format=Format.STRING)
         self.assertEqual(anno, {"x": "a[[int, str], float]"})
 
         def g(
@@ -345,7 +345,7 @@ class TestStringFormat(unittest.TestCase):
             z: a[(int, str), 5],
         ):
             pass
-        anno = annotationlib.get_annotations(g, format=Format.STRING)
+        anno = get_annotations(g, format=Format.STRING)
         self.assertEqual(
             anno,
             {
@@ -1017,6 +1017,58 @@ class TestGetAnnotations(unittest.TestCase):
             set(results.generic_func.__type_params__),
         )
 
+    def test_partial_evaluation(self):
+        def f(
+            x: builtins.undef,
+            y: list[int],
+            z: 1 + int,
+            a: builtins.int,
+            b: [builtins.undef, builtins.int],
+        ):
+            pass
+
+        self.assertEqual(
+            get_annotations(f, format=Format.FORWARDREF),
+            {
+                "x": support.EqualToForwardRef("builtins.undef", owner=f),
+                "y": list[int],
+                "z": support.EqualToForwardRef("1 + int", owner=f),
+                "a": int,
+                "b": [
+                    support.EqualToForwardRef("builtins.undef", owner=f),
+                    # We can't resolve this because we have to evaluate the whole annotation
+                    support.EqualToForwardRef("builtins.int", owner=f),
+                ],
+            },
+        )
+
+        self.assertEqual(
+            get_annotations(f, format=Format.STRING),
+            {
+                "x": "builtins.undef",
+                "y": "list[int]",
+                "z": "1 + int",
+                "a": "builtins.int",
+                "b": "[builtins.undef, builtins.int]",
+            },
+        )
+
+    def test_partial_evaluation_cell(self):
+        obj = object()
+
+        class RaisesAttributeError:
+            attriberr: obj.missing
+
+        anno = get_annotations(RaisesAttributeError, format=Format.FORWARDREF)
+        self.assertEqual(
+            anno,
+            {
+                "attriberr": support.EqualToForwardRef(
+                    "obj.missing", is_class=True, owner=RaisesAttributeError
+                )
+            },
+        )
+
 
 class TestCallEvaluateFunction(unittest.TestCase):
     def test_evaluation(self):
@@ -1370,6 +1422,38 @@ class TestForwardRefClass(unittest.TestCase):
             with self.assertRaises(TypeError):
                 pickle.dumps(fr, proto)
 
+    def test_evaluate_string_format(self):
+        fr = ForwardRef("set[Any]")
+        self.assertEqual(fr.evaluate(format=Format.STRING), "set[Any]")
+
+    def test_evaluate_forwardref_format(self):
+        fr = ForwardRef("undef")
+        evaluated = fr.evaluate(format=Format.FORWARDREF)
+        self.assertIs(fr, evaluated)
+
+        fr = ForwardRef("set[undefined]")
+        evaluated = fr.evaluate(format=Format.FORWARDREF)
+        self.assertEqual(
+            evaluated,
+            set[support.EqualToForwardRef("undefined")],
+        )
+
+        fr = ForwardRef("a + b")
+        self.assertEqual(
+            fr.evaluate(format=Format.FORWARDREF),
+            support.EqualToForwardRef("a + b"),
+        )
+        self.assertEqual(
+            fr.evaluate(format=Format.FORWARDREF, locals={"a": 1, "b": 2}),
+            3,
+        )
+
+        fr = ForwardRef('"a" + 1')
+        self.assertEqual(
+            fr.evaluate(format=Format.FORWARDREF),
+            support.EqualToForwardRef('"a" + 1'),
+        )
+
     def test_evaluate_with_type_params(self):
         class Gen[T]:
             alias = int
diff --git a/Misc/NEWS.d/next/Library/2025-04-22-16-35-37.gh-issue-125618.PEocn3.rst b/Misc/NEWS.d/next/Library/2025-04-22-16-35-37.gh-issue-125618.PEocn3.rst
new file mode 100644 (file)
index 0000000..42ecf5c
--- /dev/null
@@ -0,0 +1,3 @@
+Add a *format* parameter to :meth:`annotationlib.ForwardRef.evaluate`.
+Evaluating annotations in the ``FORWARDREF`` format now succeeds in more
+cases that would previously have raised an exception.