]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-119180: Improvements to ForwardRef.evaluate (#122210)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Sun, 11 Aug 2024 23:42:57 +0000 (16:42 -0700)
committerGitHub <noreply@github.com>
Sun, 11 Aug 2024 23:42:57 +0000 (23:42 +0000)
Noticed some issues while writing documentation for this method.

Lib/annotationlib.py
Lib/test/test_annotationlib.py
Lib/typing.py

index 141e31bbf910e3247b2a46c8091fa998db858379..8f2a93be9158325dea1a6f145f6952ad654af69e 100644 (file)
@@ -74,7 +74,7 @@ class ForwardRef:
     def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
         """Evaluate the forward reference and return the value.
 
-        If the forward reference is not evaluatable, raise an exception.
+        If the forward reference cannot be evaluated, raise an exception.
         """
         if self.__forward_evaluated__:
             return self.__forward_value__
@@ -89,12 +89,10 @@ class ForwardRef:
                 return value
         if owner is None:
             owner = self.__owner__
-        if type_params is None and owner is None:
-            raise TypeError("Either 'type_params' or 'owner' must be provided")
 
-        if self.__forward_module__ is not None:
+        if globals is None and self.__forward_module__ is not None:
             globals = getattr(
-                sys.modules.get(self.__forward_module__, None), "__dict__", globals
+                sys.modules.get(self.__forward_module__, None), "__dict__", None
             )
         if globals is None:
             globals = self.__globals__
@@ -112,14 +110,14 @@ class ForwardRef:
 
         if locals is None:
             locals = {}
-            if isinstance(self.__owner__, type):
-                locals.update(vars(self.__owner__))
+            if isinstance(owner, type):
+                locals.update(vars(owner))
 
-        if type_params is None and self.__owner__ is not None:
+        if type_params is None and owner is not None:
             # "Inject" type parameters into the local namespace
             # (unless they are shadowed by assignments *in* the local namespace),
             # as a way of emulating annotation scopes when calling `eval()`
-            type_params = getattr(self.__owner__, "__type_params__", None)
+            type_params = getattr(owner, "__type_params__", None)
 
         # type parameters require some special handling,
         # as they exist in their own scope
@@ -129,7 +127,14 @@ class ForwardRef:
         # but should in turn be overridden by names in the class scope
         # (which here are called `globalns`!)
         if type_params is not None:
-            globals, locals = dict(globals), dict(locals)
+            if globals is None:
+                globals = {}
+            else:
+                globals = dict(globals)
+            if locals is None:
+                locals = {}
+            else:
+                locals = dict(locals)
             for param in type_params:
                 param_name = param.__name__
                 if not self.__forward_is_class__ or param_name not in globals:
index e4dcdb6b58d009999a15f27cadb2dcb1face8455..db8350c27469834f7e3ed1bcc21036c183a2f9c1 100644 (file)
@@ -5,7 +5,7 @@ import functools
 import itertools
 import pickle
 import unittest
-from annotationlib import Format, get_annotations, get_annotate_function
+from annotationlib import Format, ForwardRef, get_annotations, get_annotate_function
 from typing import Unpack
 
 from test.test_inspect import inspect_stock_annotations
@@ -250,6 +250,46 @@ class TestForwardRefClass(unittest.TestCase):
             with self.assertRaises(TypeError):
                 pickle.dumps(fr, proto)
 
+    def test_evaluate_with_type_params(self):
+        class Gen[T]:
+            alias = int
+
+        with self.assertRaises(NameError):
+            ForwardRef("T").evaluate()
+        with self.assertRaises(NameError):
+            ForwardRef("T").evaluate(type_params=())
+        with self.assertRaises(NameError):
+            ForwardRef("T").evaluate(owner=int)
+
+        T, = Gen.__type_params__
+        self.assertIs(ForwardRef("T").evaluate(type_params=Gen.__type_params__), T)
+        self.assertIs(ForwardRef("T").evaluate(owner=Gen), T)
+
+        with self.assertRaises(NameError):
+            ForwardRef("alias").evaluate(type_params=Gen.__type_params__)
+        self.assertIs(ForwardRef("alias").evaluate(owner=Gen), int)
+        # If you pass custom locals, we don't look at the owner's locals
+        with self.assertRaises(NameError):
+            ForwardRef("alias").evaluate(owner=Gen, locals={})
+        # But if the name exists in the locals, it works
+        self.assertIs(
+            ForwardRef("alias").evaluate(owner=Gen, locals={"alias": str}), str
+        )
+
+    def test_fwdref_with_module(self):
+        self.assertIs(ForwardRef("Format", module=annotationlib).evaluate(), Format)
+
+        with self.assertRaises(NameError):
+            # If globals are passed explicitly, we don't look at the module dict
+            ForwardRef("Format", module=annotationlib).evaluate(globals={})
+
+    def test_fwdref_value_is_cached(self):
+        fr = ForwardRef("hello")
+        with self.assertRaises(NameError):
+            fr.evaluate()
+        self.assertIs(fr.evaluate(globals={"hello": str}), str)
+        self.assertIs(fr.evaluate(), str)
+
 
 class TestGetAnnotations(unittest.TestCase):
     def test_builtin_type(self):
index 39a14ae6f83c286c1be9f1ccad4098fec05d86a1..bcb7bec23a9aa1b5812ede60a248b84f3c4af0a2 100644 (file)
@@ -474,6 +474,10 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
         _deprecation_warning_for_no_type_params_passed("typing._eval_type")
         type_params = ()
     if isinstance(t, ForwardRef):
+        # If the forward_ref has __forward_module__ set, evaluate() infers the globals
+        # from the module, and it will probably pick better than the globals we have here.
+        if t.__forward_module__ is not None:
+            globalns = None
         return evaluate_forward_ref(t, globals=globalns, locals=localns,
                                     type_params=type_params, owner=owner,
                                     _recursive_guard=recursive_guard, format=format)