]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-119180: Add `annotationlib` module to support PEP 649 (#119891)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Tue, 23 Jul 2024 21:16:50 +0000 (14:16 -0700)
committerGitHub <noreply@github.com>
Tue, 23 Jul 2024 21:16:50 +0000 (21:16 +0000)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
15 files changed:
Doc/howto/descriptor.rst
Lib/annotationlib.py [new file with mode: 0644]
Lib/dataclasses.py
Lib/functools.py
Lib/inspect.py
Lib/test/test_annotationlib.py [new file with mode: 0644]
Lib/test/test_dataclasses/__init__.py
Lib/test/test_functools.py
Lib/test/test_grammar.py
Lib/test/test_inspect/test_inspect.py
Lib/test/test_type_annotations.py
Lib/test/test_typing.py
Lib/typing.py
Misc/NEWS.d/next/Library/2024-06-11-07-17-25.gh-issue-119180.iH-2zy.rst [new file with mode: 0644]
Python/stdlib_module_names.h

index b29488be39a0a3802a9f865c08787bb311986e4a..67e981f9c57abeccca08afc9a4535bd6b0b7e68b 100644 (file)
@@ -1366,11 +1366,15 @@ Using the non-data descriptor protocol, a pure Python version of
         def __call__(self, *args, **kwds):
             return self.f(*args, **kwds)
 
+        @property
+        def __annotations__(self):
+            return self.f.__annotations__
+
 The :func:`functools.update_wrapper` call adds a ``__wrapped__`` attribute
 that refers to the underlying function.  Also it carries forward
 the attributes necessary to make the wrapper look like the wrapped
-function: :attr:`~function.__name__`, :attr:`~function.__qualname__`,
-:attr:`~function.__doc__`, and :attr:`~function.__annotations__`.
+function, including :attr:`~function.__name__`, :attr:`~function.__qualname__`,
+and :attr:`~function.__doc__`.
 
 .. testcode::
     :hide:
diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
new file mode 100644 (file)
index 0000000..b4036ff
--- /dev/null
@@ -0,0 +1,655 @@
+"""Helpers for introspecting and wrapping annotations."""
+
+import ast
+import enum
+import functools
+import sys
+import types
+
+__all__ = ["Format", "ForwardRef", "call_annotate_function", "get_annotations"]
+
+
+class Format(enum.IntEnum):
+    VALUE = 1
+    FORWARDREF = 2
+    SOURCE = 3
+
+
+_Union = None
+_sentinel = object()
+
+# Slots shared by ForwardRef and _Stringifier. The __forward__ names must be
+# preserved for compatibility with the old typing.ForwardRef class. The remaining
+# names are private.
+_SLOTS = (
+    "__forward_evaluated__",
+    "__forward_value__",
+    "__forward_is_argument__",
+    "__forward_is_class__",
+    "__forward_module__",
+    "__weakref__",
+    "__arg__",
+    "__ast_node__",
+    "__code__",
+    "__globals__",
+    "__owner__",
+    "__cell__",
+)
+
+
+class ForwardRef:
+    """Wrapper that holds a forward reference."""
+
+    __slots__ = _SLOTS
+
+    def __init__(
+        self,
+        arg,
+        *,
+        module=None,
+        owner=None,
+        is_argument=True,
+        is_class=False,
+        _globals=None,
+        _cell=None,
+    ):
+        if not isinstance(arg, str):
+            raise TypeError(f"Forward reference must be a string -- got {arg!r}")
+
+        self.__arg__ = arg
+        self.__forward_evaluated__ = False
+        self.__forward_value__ = None
+        self.__forward_is_argument__ = is_argument
+        self.__forward_is_class__ = is_class
+        self.__forward_module__ = module
+        self.__code__ = None
+        self.__ast_node__ = None
+        self.__globals__ = _globals
+        self.__cell__ = _cell
+        self.__owner__ = owner
+
+    def __init_subclass__(cls, /, *args, **kwds):
+        raise TypeError("Cannot subclass ForwardRef")
+
+    def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
+        """Evaluate the forward reference and return the value.
+
+        If the forward reference is not evaluatable, raise an exception.
+        """
+        if self.__forward_evaluated__:
+            return self.__forward_value__
+        if self.__cell__ is not None:
+            try:
+                value = self.__cell__.cell_contents
+            except ValueError:
+                pass
+            else:
+                self.__forward_evaluated__ = True
+                self.__forward_value__ = value
+                return value
+        if owner is None:
+            owner = self.__owner__
+        if type_params is None and owner is None:
+            raise TypeError("Either 'type_params' or 'owner' must be provided")
+
+        if self.__forward_module__ is not None:
+            globals = getattr(
+                sys.modules.get(self.__forward_module__, None), "__dict__", globals
+            )
+        if globals is None:
+            globals = self.__globals__
+        if globals is None:
+            if isinstance(owner, type):
+                module_name = getattr(owner, "__module__", None)
+                if module_name:
+                    module = sys.modules.get(module_name, None)
+                    if module:
+                        globals = getattr(module, "__dict__", None)
+            elif isinstance(owner, types.ModuleType):
+                globals = getattr(owner, "__dict__", None)
+            elif callable(owner):
+                globals = getattr(owner, "__globals__", None)
+
+        if locals is None:
+            locals = {}
+            if isinstance(self.__owner__, type):
+                locals.update(vars(self.__owner__))
+
+        if type_params is None and self.__owner__ is not None:
+            # "Inject" type parameters into the local namespace
+            # (unless they are shadowed by assignments *in* the local namespace),
+            # as a way of emulating annotation scopes when calling `eval()`
+            type_params = getattr(self.__owner__, "__type_params__", None)
+
+        # type parameters require some special handling,
+        # as they exist in their own scope
+        # but `eval()` does not have a dedicated parameter for that scope.
+        # For classes, names in type parameter scopes should override
+        # names in the global scope (which here are called `localns`!),
+        # but should in turn be overridden by names in the class scope
+        # (which here are called `globalns`!)
+        if type_params is not None:
+            globals, locals = dict(globals), dict(locals)
+            for param in type_params:
+                param_name = param.__name__
+                if not self.__forward_is_class__ or param_name not in globals:
+                    globals[param_name] = param
+                    locals.pop(param_name, None)
+
+        code = self.__forward_code__
+        value = eval(code, globals=globals, locals=locals)
+        self.__forward_evaluated__ = True
+        self.__forward_value__ = value
+        return value
+
+    def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
+        import typing
+        import warnings
+
+        if type_params is _sentinel:
+            typing._deprecation_warning_for_no_type_params_passed(
+                "typing.ForwardRef._evaluate"
+            )
+            type_params = ()
+        warnings._deprecated(
+            "ForwardRef._evaluate",
+            "{name} is a private API and is retained for compatibility, but will be removed"
+            " in Python 3.16. Use ForwardRef.evaluate() or typing.evaluate_forward_ref() instead.",
+            remove=(3, 16),
+        )
+        return typing.evaluate_forward_ref(
+            self,
+            globals=globalns,
+            locals=localns,
+            type_params=type_params,
+            _recursive_guard=recursive_guard,
+        )
+
+    @property
+    def __forward_arg__(self):
+        if self.__arg__ is not None:
+            return self.__arg__
+        if self.__ast_node__ is not None:
+            self.__arg__ = ast.unparse(self.__ast_node__)
+            return self.__arg__
+        raise AssertionError(
+            "Attempted to access '__forward_arg__' on an uninitialized ForwardRef"
+        )
+
+    @property
+    def __forward_code__(self):
+        if self.__code__ is not None:
+            return self.__code__
+        arg = self.__forward_arg__
+        # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
+        # Unfortunately, this isn't a valid expression on its own, so we
+        # do the unpacking manually.
+        if arg.startswith("*"):
+            arg_to_compile = f"({arg},)[0]"  # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
+        else:
+            arg_to_compile = arg
+        try:
+            self.__code__ = compile(arg_to_compile, "<string>", "eval")
+        except SyntaxError:
+            raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
+        return self.__code__
+
+    def __eq__(self, other):
+        if not isinstance(other, ForwardRef):
+            return NotImplemented
+        if self.__forward_evaluated__ and other.__forward_evaluated__:
+            return (
+                self.__forward_arg__ == other.__forward_arg__
+                and self.__forward_value__ == other.__forward_value__
+            )
+        return (
+            self.__forward_arg__ == other.__forward_arg__
+            and self.__forward_module__ == other.__forward_module__
+        )
+
+    def __hash__(self):
+        return hash((self.__forward_arg__, self.__forward_module__))
+
+    def __or__(self, other):
+        global _Union
+        if _Union is None:
+            from typing import Union as _Union
+        return _Union[self, other]
+
+    def __ror__(self, other):
+        global _Union
+        if _Union is None:
+            from typing import Union as _Union
+        return _Union[other, self]
+
+    def __repr__(self):
+        if self.__forward_module__ is None:
+            module_repr = ""
+        else:
+            module_repr = f", module={self.__forward_module__!r}"
+        return f"ForwardRef({self.__forward_arg__!r}{module_repr})"
+
+
+class _Stringifier:
+    # Must match the slots on ForwardRef, so we can turn an instance of one into an
+    # instance of the other in place.
+    __slots__ = _SLOTS
+
+    def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
+        assert isinstance(node, ast.AST)
+        self.__arg__ = None
+        self.__forward_evaluated__ = False
+        self.__forward_value__ = None
+        self.__forward_is_argument__ = False
+        self.__forward_is_class__ = is_class
+        self.__forward_module__ = None
+        self.__code__ = None
+        self.__ast_node__ = node
+        self.__globals__ = globals
+        self.__cell__ = cell
+        self.__owner__ = owner
+
+    def __convert(self, other):
+        if isinstance(other, _Stringifier):
+            return other.__ast_node__
+        elif isinstance(other, slice):
+            return ast.Slice(
+                lower=self.__convert(other.start) if other.start is not None else None,
+                upper=self.__convert(other.stop) if other.stop is not None else None,
+                step=self.__convert(other.step) if other.step is not None else None,
+            )
+        else:
+            return ast.Constant(value=other)
+
+    def __make_new(self, node):
+        return _Stringifier(
+            node, self.__globals__, self.__owner__, self.__forward_is_class__
+        )
+
+    # Must implement this since we set __eq__. We hash by identity so that
+    # stringifiers in dict keys are kept separate.
+    def __hash__(self):
+        return id(self)
+
+    def __getitem__(self, other):
+        # Special case, to avoid stringifying references to class-scoped variables
+        # as '__classdict__["x"]'.
+        if (
+            isinstance(self.__ast_node__, ast.Name)
+            and self.__ast_node__.id == "__classdict__"
+        ):
+            raise KeyError
+        if isinstance(other, tuple):
+            elts = [self.__convert(elt) for elt in other]
+            other = ast.Tuple(elts)
+        else:
+            other = self.__convert(other)
+        assert isinstance(other, ast.AST), repr(other)
+        return self.__make_new(ast.Subscript(self.__ast_node__, other))
+
+    def __getattr__(self, attr):
+        return self.__make_new(ast.Attribute(self.__ast_node__, attr))
+
+    def __call__(self, *args, **kwargs):
+        return self.__make_new(
+            ast.Call(
+                self.__ast_node__,
+                [self.__convert(arg) for arg in args],
+                [
+                    ast.keyword(key, self.__convert(value))
+                    for key, value in kwargs.items()
+                ],
+            )
+        )
+
+    def __iter__(self):
+        yield self.__make_new(ast.Starred(self.__ast_node__))
+
+    def __repr__(self):
+        return ast.unparse(self.__ast_node__)
+
+    def __format__(self, format_spec):
+        raise TypeError("Cannot stringify annotation containing string formatting")
+
+    def _make_binop(op: ast.AST):
+        def binop(self, other):
+            return self.__make_new(
+                ast.BinOp(self.__ast_node__, op, self.__convert(other))
+            )
+
+        return binop
+
+    __add__ = _make_binop(ast.Add())
+    __sub__ = _make_binop(ast.Sub())
+    __mul__ = _make_binop(ast.Mult())
+    __matmul__ = _make_binop(ast.MatMult())
+    __truediv__ = _make_binop(ast.Div())
+    __mod__ = _make_binop(ast.Mod())
+    __lshift__ = _make_binop(ast.LShift())
+    __rshift__ = _make_binop(ast.RShift())
+    __or__ = _make_binop(ast.BitOr())
+    __xor__ = _make_binop(ast.BitXor())
+    __and__ = _make_binop(ast.BitAnd())
+    __floordiv__ = _make_binop(ast.FloorDiv())
+    __pow__ = _make_binop(ast.Pow())
+
+    del _make_binop
+
+    def _make_rbinop(op: ast.AST):
+        def rbinop(self, other):
+            return self.__make_new(
+                ast.BinOp(self.__convert(other), op, self.__ast_node__)
+            )
+
+        return rbinop
+
+    __radd__ = _make_rbinop(ast.Add())
+    __rsub__ = _make_rbinop(ast.Sub())
+    __rmul__ = _make_rbinop(ast.Mult())
+    __rmatmul__ = _make_rbinop(ast.MatMult())
+    __rtruediv__ = _make_rbinop(ast.Div())
+    __rmod__ = _make_rbinop(ast.Mod())
+    __rlshift__ = _make_rbinop(ast.LShift())
+    __rrshift__ = _make_rbinop(ast.RShift())
+    __ror__ = _make_rbinop(ast.BitOr())
+    __rxor__ = _make_rbinop(ast.BitXor())
+    __rand__ = _make_rbinop(ast.BitAnd())
+    __rfloordiv__ = _make_rbinop(ast.FloorDiv())
+    __rpow__ = _make_rbinop(ast.Pow())
+
+    del _make_rbinop
+
+    def _make_compare(op):
+        def compare(self, other):
+            return self.__make_new(
+                ast.Compare(
+                    left=self.__ast_node__,
+                    ops=[op],
+                    comparators=[self.__convert(other)],
+                )
+            )
+
+        return compare
+
+    __lt__ = _make_compare(ast.Lt())
+    __le__ = _make_compare(ast.LtE())
+    __eq__ = _make_compare(ast.Eq())
+    __ne__ = _make_compare(ast.NotEq())
+    __gt__ = _make_compare(ast.Gt())
+    __ge__ = _make_compare(ast.GtE())
+
+    del _make_compare
+
+    def _make_unary_op(op):
+        def unary_op(self):
+            return self.__make_new(ast.UnaryOp(op, self.__ast_node__))
+
+        return unary_op
+
+    __invert__ = _make_unary_op(ast.Invert())
+    __pos__ = _make_unary_op(ast.UAdd())
+    __neg__ = _make_unary_op(ast.USub())
+
+    del _make_unary_op
+
+
+class _StringifierDict(dict):
+    def __init__(self, namespace, globals=None, owner=None, is_class=False):
+        super().__init__(namespace)
+        self.namespace = namespace
+        self.globals = globals
+        self.owner = owner
+        self.is_class = is_class
+        self.stringifiers = []
+
+    def __missing__(self, key):
+        fwdref = _Stringifier(
+            ast.Name(id=key),
+            globals=self.globals,
+            owner=self.owner,
+            is_class=self.is_class,
+        )
+        self.stringifiers.append(fwdref)
+        return fwdref
+
+
+def call_annotate_function(annotate, format, owner=None):
+    """Call an __annotate__ function. __annotate__ functions are normally
+    generated by the compiler to defer the evaluation of annotations. They
+    can be called with any of the format arguments in the Format enum, but
+    compiler-generated __annotate__ functions only support the VALUE format.
+    This function provides additional functionality to call __annotate__
+    functions with the FORWARDREF and SOURCE formats.
+
+    *annotate* must be an __annotate__ function, which takes a single argument
+    and returns a dict of annotations.
+
+    *format* must be a member of the Format enum or one of the corresponding
+    integer values.
+
+    *owner* can be the object that owns the annotations (i.e., the module,
+    class, or function that the __annotate__ function derives from). With the
+    FORWARDREF format, it is used to provide better evaluation capabilities
+    on the generated ForwardRef objects.
+
+    """
+    try:
+        return annotate(format)
+    except NotImplementedError:
+        pass
+    if format == Format.SOURCE:
+        # SOURCE is implemented by calling the annotate function in a special
+        # environment where every name lookup results in an instance of _Stringifier.
+        # _Stringifier supports every dunder operation and returns a new _Stringifier.
+        # At the end, we get a dictionary that mostly contains _Stringifier objects (or
+        # possibly constants if the annotate function uses them directly). We then
+        # convert each of those into a string to get an approximation of the
+        # original source.
+        globals = _StringifierDict({})
+        if annotate.__closure__:
+            freevars = annotate.__code__.co_freevars
+            new_closure = []
+            for i, cell in enumerate(annotate.__closure__):
+                if i < len(freevars):
+                    name = freevars[i]
+                else:
+                    name = "__cell__"
+                fwdref = _Stringifier(ast.Name(id=name))
+                new_closure.append(types.CellType(fwdref))
+            closure = tuple(new_closure)
+        else:
+            closure = None
+        func = types.FunctionType(annotate.__code__, globals, closure=closure)
+        annos = func(Format.VALUE)
+        return {
+            key: val if isinstance(val, str) else repr(val)
+            for key, val in annos.items()
+        }
+    elif format == Format.FORWARDREF:
+        # FORWARDREF is implemented similarly to SOURCE, but there are two changes,
+        # at the beginning and the end of the process.
+        # First, while SOURCE uses an empty dictionary as the namespace, so that all
+        # name lookups result in _Stringifier objects, FORWARDREF uses the globals
+        # and builtins, so that defined names map to their real values.
+        # Second, instead of returning strings, we want to return either real values
+        # or ForwardRef objects. To do this, we keep track of all _Stringifier objects
+        # created while the annotation is being evaluated, and at the end we convert
+        # them all to ForwardRef objects by assigning to __class__. To make this
+        # technique work, we have to ensure that the _Stringifier and ForwardRef
+        # classes share the same attributes.
+        # We use this technique because while the annotations are being evaluated,
+        # we want to support all operations that the language allows, including even
+        # __getattr__ and __eq__, and return new _Stringifier objects so we can accurately
+        # reconstruct the source. But in the dictionary that we eventually return, we
+        # want to return objects with more user-friendly behavior, such as an __eq__
+        # that returns a bool and an defined set of attributes.
+        namespace = {**annotate.__builtins__, **annotate.__globals__}
+        is_class = isinstance(owner, type)
+        globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class)
+        if annotate.__closure__:
+            freevars = annotate.__code__.co_freevars
+            new_closure = []
+            for i, cell in enumerate(annotate.__closure__):
+                try:
+                    cell.cell_contents
+                except ValueError:
+                    if i < len(freevars):
+                        name = freevars[i]
+                    else:
+                        name = "__cell__"
+                    fwdref = _Stringifier(
+                        ast.Name(id=name),
+                        cell=cell,
+                        owner=owner,
+                        globals=annotate.__globals__,
+                        is_class=is_class,
+                    )
+                    globals.stringifiers.append(fwdref)
+                    new_closure.append(types.CellType(fwdref))
+                else:
+                    new_closure.append(cell)
+            closure = tuple(new_closure)
+        else:
+            closure = None
+        func = types.FunctionType(annotate.__code__, globals, closure=closure)
+        result = func(Format.VALUE)
+        for obj in globals.stringifiers:
+            obj.__class__ = ForwardRef
+        return result
+    elif format == Format.VALUE:
+        # Should be impossible because __annotate__ functions must not raise
+        # NotImplementedError for this format.
+        raise RuntimeError("annotate function does not support VALUE format")
+    else:
+        raise ValueError(f"Invalid format: {format!r}")
+
+
+def get_annotations(
+    obj, *, globals=None, locals=None, eval_str=False, format=Format.VALUE
+):
+    """Compute the annotations dict for an object.
+
+    obj may be a callable, class, or module.
+    Passing in an object of any other type raises TypeError.
+
+    Returns a dict.  get_annotations() returns a new dict every time
+    it's called; calling it twice on the same object will return two
+    different but equivalent dicts.
+
+    This function handles several details for you:
+
+      * If eval_str is true, values of type str will
+        be un-stringized using eval().  This is intended
+        for use with stringized annotations
+        ("from __future__ import annotations").
+      * If obj doesn't have an annotations dict, returns an
+        empty dict.  (Functions and methods always have an
+        annotations dict; classes, modules, and other types of
+        callables may not.)
+      * Ignores inherited annotations on classes.  If a class
+        doesn't have its own annotations dict, returns an empty dict.
+      * All accesses to object members and dict values are done
+        using getattr() and dict.get() for safety.
+      * Always, always, always returns a freshly-created dict.
+
+    eval_str controls whether or not values of type str are replaced
+    with the result of calling eval() on those values:
+
+      * If eval_str is true, eval() is called on values of type str.
+      * If eval_str is false (the default), values of type str are unchanged.
+
+    globals and locals are passed in to eval(); see the documentation
+    for eval() for more information.  If either globals or locals is
+    None, this function may replace that value with a context-specific
+    default, contingent on type(obj):
+
+      * If obj is a module, globals defaults to obj.__dict__.
+      * If obj is a class, globals defaults to
+        sys.modules[obj.__module__].__dict__ and locals
+        defaults to the obj class namespace.
+      * If obj is a callable, globals defaults to obj.__globals__,
+        although if obj is a wrapped function (using
+        functools.update_wrapper()) it is first unwrapped.
+    """
+    if eval_str and format != Format.VALUE:
+        raise ValueError("eval_str=True is only supported with format=Format.VALUE")
+
+    # For VALUE format, we look at __annotations__ directly.
+    if format != Format.VALUE:
+        annotate = getattr(obj, "__annotate__", None)
+        if annotate is not None:
+            ann = call_annotate_function(annotate, format, owner=obj)
+            if not isinstance(ann, dict):
+                raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
+            return dict(ann)
+
+    ann = getattr(obj, "__annotations__", None)
+    if ann is None:
+        return {}
+
+    if not isinstance(ann, dict):
+        raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
+
+    if not ann:
+        return {}
+
+    if not eval_str:
+        return dict(ann)
+
+    if isinstance(obj, type):
+        # class
+        obj_globals = None
+        module_name = getattr(obj, "__module__", None)
+        if module_name:
+            module = sys.modules.get(module_name, None)
+            if module:
+                obj_globals = getattr(module, "__dict__", None)
+        obj_locals = dict(vars(obj))
+        unwrap = obj
+    elif isinstance(obj, types.ModuleType):
+        # module
+        obj_globals = getattr(obj, "__dict__")
+        obj_locals = None
+        unwrap = None
+    elif callable(obj):
+        # this includes types.Function, types.BuiltinFunctionType,
+        # types.BuiltinMethodType, functools.partial, functools.singledispatch,
+        # "class funclike" from Lib/test/test_inspect... on and on it goes.
+        obj_globals = getattr(obj, "__globals__", None)
+        obj_locals = None
+        unwrap = obj
+    elif ann is not None:
+        obj_globals = obj_locals = unwrap = None
+    else:
+        raise TypeError(f"{obj!r} is not a module, class, or callable.")
+
+    if unwrap is not None:
+        while True:
+            if hasattr(unwrap, "__wrapped__"):
+                unwrap = unwrap.__wrapped__
+                continue
+            if isinstance(unwrap, functools.partial):
+                unwrap = unwrap.func
+                continue
+            break
+        if hasattr(unwrap, "__globals__"):
+            obj_globals = unwrap.__globals__
+
+    if globals is None:
+        globals = obj_globals
+    if locals is None:
+        locals = obj_locals
+
+    # "Inject" type parameters into the local namespace
+    # (unless they are shadowed by assignments *in* the local namespace),
+    # as a way of emulating annotation scopes when calling `eval()`
+    if type_params := getattr(obj, "__type_params__", ()):
+        if locals is None:
+            locals = {}
+        locals = {param.__name__: param for param in type_params} | locals
+
+    return_value = {
+        key: value if not isinstance(value, str) else eval(value, globals, locals)
+        for key, value in ann.items()
+    }
+    return return_value
index 74011b7e28b9f3af22e7b098b6d60fc0a077e063..4cba606dd8dd4d4b2e693eb4b7caaf7a02e87486 100644 (file)
@@ -5,6 +5,7 @@ import types
 import inspect
 import keyword
 import itertools
+import annotationlib
 import abc
 from reprlib import recursive_repr
 
@@ -981,7 +982,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
     # actual default value.  Pseudo-fields ClassVars and InitVars are
     # included, despite the fact that they're not real fields.  That's
     # dealt with later.
-    cls_annotations = inspect.get_annotations(cls)
+    cls_annotations = annotationlib.get_annotations(
+        cls, format=annotationlib.Format.FORWARDREF)
 
     # Now find fields in our class.  While doing so, validate some
     # things, and set the default values (as class attributes) where
index a10493f0e25360c7e965ae02cf37425292cab6f5..49ea9a2f6999f5edf5123f7bb199f5437baf80ff 100644 (file)
@@ -32,7 +32,7 @@ GenericAlias = type(list[int])
 # wrapper functions that can handle naive introspection
 
 WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
-                       '__annotations__', '__type_params__')
+                       '__annotate__', '__type_params__')
 WRAPPER_UPDATES = ('__dict__',)
 def update_wrapper(wrapper,
                    wrapped,
@@ -882,8 +882,8 @@ def singledispatch(func):
                     f"Invalid first argument to `register()`. "
                     f"{cls!r} is not a class or union type."
                 )
-            ann = getattr(cls, '__annotations__', {})
-            if not ann:
+            ann = getattr(cls, '__annotate__', None)
+            if ann is None:
                 raise TypeError(
                     f"Invalid first argument to `register()`: {cls!r}. "
                     f"Use either `@register(some_class)` or plain `@register` "
@@ -893,13 +893,19 @@ def singledispatch(func):
 
             # only import typing if annotation parsing is necessary
             from typing import get_type_hints
-            argname, cls = next(iter(get_type_hints(func).items()))
+            from annotationlib import Format, ForwardRef
+            argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items()))
             if not _is_valid_dispatch_type(cls):
                 if _is_union_type(cls):
                     raise TypeError(
                         f"Invalid annotation for {argname!r}. "
                         f"{cls!r} not all arguments are classes."
                     )
+                elif isinstance(cls, ForwardRef):
+                    raise TypeError(
+                        f"Invalid annotation for {argname!r}. "
+                        f"{cls!r} is an unresolved forward reference."
+                    )
                 else:
                     raise TypeError(
                         f"Invalid annotation for {argname!r}. "
index 0e7b40eb39bce8549b3406045e26e1b680a4b846..ba3ecbb87c70268808b6cb1138421b494e5af2c1 100644 (file)
@@ -142,6 +142,7 @@ __all__ = [
 
 
 import abc
+from annotationlib import get_annotations
 import ast
 import dis
 import collections.abc
@@ -173,121 +174,6 @@ del k, v, mod_dict
 TPFLAGS_IS_ABSTRACT = 1 << 20
 
 
-def get_annotations(obj, *, globals=None, locals=None, eval_str=False):
-    """Compute the annotations dict for an object.
-
-    obj may be a callable, class, or module.
-    Passing in an object of any other type raises TypeError.
-
-    Returns a dict.  get_annotations() returns a new dict every time
-    it's called; calling it twice on the same object will return two
-    different but equivalent dicts.
-
-    This function handles several details for you:
-
-      * If eval_str is true, values of type str will
-        be un-stringized using eval().  This is intended
-        for use with stringized annotations
-        ("from __future__ import annotations").
-      * If obj doesn't have an annotations dict, returns an
-        empty dict.  (Functions and methods always have an
-        annotations dict; classes, modules, and other types of
-        callables may not.)
-      * Ignores inherited annotations on classes.  If a class
-        doesn't have its own annotations dict, returns an empty dict.
-      * All accesses to object members and dict values are done
-        using getattr() and dict.get() for safety.
-      * Always, always, always returns a freshly-created dict.
-
-    eval_str controls whether or not values of type str are replaced
-    with the result of calling eval() on those values:
-
-      * If eval_str is true, eval() is called on values of type str.
-      * If eval_str is false (the default), values of type str are unchanged.
-
-    globals and locals are passed in to eval(); see the documentation
-    for eval() for more information.  If either globals or locals is
-    None, this function may replace that value with a context-specific
-    default, contingent on type(obj):
-
-      * If obj is a module, globals defaults to obj.__dict__.
-      * If obj is a class, globals defaults to
-        sys.modules[obj.__module__].__dict__ and locals
-        defaults to the obj class namespace.
-      * If obj is a callable, globals defaults to obj.__globals__,
-        although if obj is a wrapped function (using
-        functools.update_wrapper()) it is first unwrapped.
-    """
-    if isinstance(obj, type):
-        # class
-        ann = obj.__annotations__
-
-        obj_globals = None
-        module_name = getattr(obj, '__module__', None)
-        if module_name:
-            module = sys.modules.get(module_name, None)
-            if module:
-                obj_globals = getattr(module, '__dict__', None)
-        obj_locals = dict(vars(obj))
-        unwrap = obj
-    elif isinstance(obj, types.ModuleType):
-        # module
-        ann = getattr(obj, '__annotations__', None)
-        obj_globals = getattr(obj, '__dict__')
-        obj_locals = None
-        unwrap = None
-    elif callable(obj):
-        # this includes types.Function, types.BuiltinFunctionType,
-        # types.BuiltinMethodType, functools.partial, functools.singledispatch,
-        # "class funclike" from Lib/test/test_inspect... on and on it goes.
-        ann = getattr(obj, '__annotations__', None)
-        obj_globals = getattr(obj, '__globals__', None)
-        obj_locals = None
-        unwrap = obj
-    else:
-        raise TypeError(f"{obj!r} is not a module, class, or callable.")
-
-    if ann is None:
-        return {}
-
-    if not isinstance(ann, dict):
-        raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
-
-    if not ann:
-        return {}
-
-    if not eval_str:
-        return dict(ann)
-
-    if unwrap is not None:
-        while True:
-            if hasattr(unwrap, '__wrapped__'):
-                unwrap = unwrap.__wrapped__
-                continue
-            if isinstance(unwrap, functools.partial):
-                unwrap = unwrap.func
-                continue
-            break
-        if hasattr(unwrap, "__globals__"):
-            obj_globals = unwrap.__globals__
-
-    if globals is None:
-        globals = obj_globals
-    if locals is None:
-        locals = obj_locals or {}
-
-    # "Inject" type parameters into the local namespace
-    # (unless they are shadowed by assignments *in* the local namespace),
-    # as a way of emulating annotation scopes when calling `eval()`
-    if type_params := getattr(obj, "__type_params__", ()):
-        locals = {param.__name__: param for param in type_params} | locals
-
-    return_value = {key:
-        value if not isinstance(value, str) else eval(value, globals, locals)
-        for key, value in ann.items() }
-    return return_value
-
-
 # ----------------------------------------------------------- type-checking
 def ismodule(object):
     """Return true if the object is a module."""
diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py
new file mode 100644 (file)
index 0000000..e68d63c
--- /dev/null
@@ -0,0 +1,771 @@
+"""Tests for the annotations module."""
+
+import annotationlib
+import functools
+import pickle
+import unittest
+from typing import Unpack
+
+from test.test_inspect import inspect_stock_annotations
+from test.test_inspect import inspect_stringized_annotations
+from test.test_inspect import inspect_stringized_annotations_2
+from test.test_inspect import inspect_stringized_annotations_pep695
+
+
+def times_three(fn):
+    @functools.wraps(fn)
+    def wrapper(a, b):
+        return fn(a * 3, b * 3)
+
+    return wrapper
+
+
+class TestFormat(unittest.TestCase):
+    def test_enum(self):
+        self.assertEqual(annotationlib.Format.VALUE.value, 1)
+        self.assertEqual(annotationlib.Format.VALUE, 1)
+
+        self.assertEqual(annotationlib.Format.FORWARDREF.value, 2)
+        self.assertEqual(annotationlib.Format.FORWARDREF, 2)
+
+        self.assertEqual(annotationlib.Format.SOURCE.value, 3)
+        self.assertEqual(annotationlib.Format.SOURCE, 3)
+
+
+class TestForwardRefFormat(unittest.TestCase):
+    def test_closure(self):
+        def inner(arg: x):
+            pass
+
+        anno = annotationlib.get_annotations(
+            inner, format=annotationlib.Format.FORWARDREF
+        )
+        fwdref = anno["arg"]
+        self.assertIsInstance(fwdref, annotationlib.ForwardRef)
+        self.assertEqual(fwdref.__forward_arg__, "x")
+        with self.assertRaises(NameError):
+            fwdref.evaluate()
+
+        x = 1
+        self.assertEqual(fwdref.evaluate(), x)
+
+        anno = annotationlib.get_annotations(
+            inner, format=annotationlib.Format.FORWARDREF
+        )
+        self.assertEqual(anno["arg"], x)
+
+    def test_function(self):
+        def f(x: int, y: doesntexist):
+            pass
+
+        anno = annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF)
+        self.assertIs(anno["x"], int)
+        fwdref = anno["y"]
+        self.assertIsInstance(fwdref, annotationlib.ForwardRef)
+        self.assertEqual(fwdref.__forward_arg__, "doesntexist")
+        with self.assertRaises(NameError):
+            fwdref.evaluate()
+        self.assertEqual(fwdref.evaluate(globals={"doesntexist": 1}), 1)
+
+
+class TestSourceFormat(unittest.TestCase):
+    def test_closure(self):
+        x = 0
+
+        def inner(arg: x):
+            pass
+
+        anno = annotationlib.get_annotations(inner, format=annotationlib.Format.SOURCE)
+        self.assertEqual(anno, {"arg": "x"})
+
+    def test_function(self):
+        def f(x: int, y: doesntexist):
+            pass
+
+        anno = annotationlib.get_annotations(f, format=annotationlib.Format.SOURCE)
+        self.assertEqual(anno, {"x": "int", "y": "doesntexist"})
+
+    def test_expressions(self):
+        def f(
+            add: a + b,
+            sub: a - b,
+            mul: a * b,
+            matmul: a @ b,
+            truediv: a / b,
+            mod: a % b,
+            lshift: a << b,
+            rshift: a >> b,
+            or_: a | b,
+            xor: a ^ b,
+            and_: a & b,
+            floordiv: a // b,
+            pow_: a**b,
+            lt: a < b,
+            le: a <= b,
+            eq: a == b,
+            ne: a != b,
+            gt: a > b,
+            ge: a >= b,
+            invert: ~a,
+            neg: -a,
+            pos: +a,
+            getitem: a[b],
+            getattr: a.b,
+            call: a(b, *c, d=e),  # **kwargs are not supported
+            *args: *a,
+        ):
+            pass
+
+        anno = annotationlib.get_annotations(f, format=annotationlib.Format.SOURCE)
+        self.assertEqual(
+            anno,
+            {
+                "add": "a + b",
+                "sub": "a - b",
+                "mul": "a * b",
+                "matmul": "a @ b",
+                "truediv": "a / b",
+                "mod": "a % b",
+                "lshift": "a << b",
+                "rshift": "a >> b",
+                "or_": "a | b",
+                "xor": "a ^ b",
+                "and_": "a & b",
+                "floordiv": "a // b",
+                "pow_": "a ** b",
+                "lt": "a < b",
+                "le": "a <= b",
+                "eq": "a == b",
+                "ne": "a != b",
+                "gt": "a > b",
+                "ge": "a >= b",
+                "invert": "~a",
+                "neg": "-a",
+                "pos": "+a",
+                "getitem": "a[b]",
+                "getattr": "a.b",
+                "call": "a(b, *c, d=e)",
+                "args": "*a",
+            },
+        )
+
+    def test_reverse_ops(self):
+        def f(
+            radd: 1 + a,
+            rsub: 1 - a,
+            rmul: 1 * a,
+            rmatmul: 1 @ a,
+            rtruediv: 1 / a,
+            rmod: 1 % a,
+            rlshift: 1 << a,
+            rrshift: 1 >> a,
+            ror: 1 | a,
+            rxor: 1 ^ a,
+            rand: 1 & a,
+            rfloordiv: 1 // a,
+            rpow: 1**a,
+        ):
+            pass
+
+        anno = annotationlib.get_annotations(f, format=annotationlib.Format.SOURCE)
+        self.assertEqual(
+            anno,
+            {
+                "radd": "1 + a",
+                "rsub": "1 - a",
+                "rmul": "1 * a",
+                "rmatmul": "1 @ a",
+                "rtruediv": "1 / a",
+                "rmod": "1 % a",
+                "rlshift": "1 << a",
+                "rrshift": "1 >> a",
+                "ror": "1 | a",
+                "rxor": "1 ^ a",
+                "rand": "1 & a",
+                "rfloordiv": "1 // a",
+                "rpow": "1 ** a",
+            },
+        )
+
+    def test_nested_expressions(self):
+        def f(
+            nested: list[Annotated[set[int], "set of ints", 4j]],
+            set: {a + b},  # single element because order is not guaranteed
+            dict: {a + b: c + d, "key": e + g},
+            list: [a, b, c],
+            tuple: (a, b, c),
+            slice: (a[b:c], a[b:c:d], a[:c], a[b:], a[:], a[::d], a[b::d]),
+            extended_slice: a[:, :, c:d],
+            unpack1: [*a],
+            unpack2: [*a, b, c],
+        ):
+            pass
+
+        anno = annotationlib.get_annotations(f, format=annotationlib.Format.SOURCE)
+        self.assertEqual(
+            anno,
+            {
+                "nested": "list[Annotated[set[int], 'set of ints', 4j]]",
+                "set": "{a + b}",
+                "dict": "{a + b: c + d, 'key': e + g}",
+                "list": "[a, b, c]",
+                "tuple": "(a, b, c)",
+                "slice": "(a[b:c], a[b:c:d], a[:c], a[b:], a[:], a[::d], a[b::d])",
+                "extended_slice": "a[:, :, c:d]",
+                "unpack1": "[*a]",
+                "unpack2": "[*a, b, c]",
+            },
+        )
+
+    def test_unsupported_operations(self):
+        format_msg = "Cannot stringify annotation containing string formatting"
+
+        def f(fstring: f"{a}"):
+            pass
+
+        with self.assertRaisesRegex(TypeError, format_msg):
+            annotationlib.get_annotations(f, format=annotationlib.Format.SOURCE)
+
+        def f(fstring_format: f"{a:02d}"):
+            pass
+
+        with self.assertRaisesRegex(TypeError, format_msg):
+            annotationlib.get_annotations(f, format=annotationlib.Format.SOURCE)
+
+
+class TestForwardRefClass(unittest.TestCase):
+    def test_special_attrs(self):
+        # Forward refs provide a different introspection API. __name__ and
+        # __qualname__ make little sense for forward refs as they can store
+        # complex typing expressions.
+        fr = annotationlib.ForwardRef("set[Any]")
+        self.assertFalse(hasattr(fr, "__name__"))
+        self.assertFalse(hasattr(fr, "__qualname__"))
+        self.assertEqual(fr.__module__, "annotationlib")
+        # Forward refs are currently unpicklable once they contain a code object.
+        fr.__forward_code__  # fill the cache
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            with self.assertRaises(TypeError):
+                pickle.dumps(fr, proto)
+
+
+class TestGetAnnotations(unittest.TestCase):
+    def test_builtin_type(self):
+        self.assertEqual(annotationlib.get_annotations(int), {})
+        self.assertEqual(annotationlib.get_annotations(object), {})
+
+    def test_custom_metaclass(self):
+        class Meta(type):
+            pass
+
+        class C(metaclass=Meta):
+            x: int
+
+        self.assertEqual(annotationlib.get_annotations(C), {"x": int})
+
+    def test_missing_dunder_dict(self):
+        class NoDict(type):
+            @property
+            def __dict__(cls):
+                raise AttributeError
+
+            b: str
+
+        class C1(metaclass=NoDict):
+            a: int
+
+        self.assertEqual(annotationlib.get_annotations(C1), {"a": int})
+        self.assertEqual(
+            annotationlib.get_annotations(C1, format=annotationlib.Format.FORWARDREF),
+            {"a": int},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(C1, format=annotationlib.Format.SOURCE),
+            {"a": "int"},
+        )
+        self.assertEqual(annotationlib.get_annotations(NoDict), {"b": str})
+        self.assertEqual(
+            annotationlib.get_annotations(NoDict, format=annotationlib.Format.FORWARDREF),
+            {"b": str},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(NoDict, format=annotationlib.Format.SOURCE),
+            {"b": "str"},
+        )
+
+    def test_format(self):
+        def f1(a: int):
+            pass
+
+        def f2(a: undefined):
+            pass
+
+        self.assertEqual(
+            annotationlib.get_annotations(f1, format=annotationlib.Format.VALUE),
+            {"a": int},
+        )
+        self.assertEqual(annotationlib.get_annotations(f1, format=1), {"a": int})
+
+        fwd = annotationlib.ForwardRef("undefined")
+        self.assertEqual(
+            annotationlib.get_annotations(f2, format=annotationlib.Format.FORWARDREF),
+            {"a": fwd},
+        )
+        self.assertEqual(annotationlib.get_annotations(f2, format=2), {"a": fwd})
+
+        self.assertEqual(
+            annotationlib.get_annotations(f1, format=annotationlib.Format.SOURCE),
+            {"a": "int"},
+        )
+        self.assertEqual(annotationlib.get_annotations(f1, format=3), {"a": "int"})
+
+        with self.assertRaises(ValueError):
+            annotationlib.get_annotations(f1, format=0)
+
+        with self.assertRaises(ValueError):
+            annotationlib.get_annotations(f1, format=4)
+
+    def test_custom_object_with_annotations(self):
+        class C:
+            def __init__(self):
+                self.__annotations__ = {"x": int, "y": str}
+
+        self.assertEqual(annotationlib.get_annotations(C()), {"x": int, "y": str})
+
+    def test_custom_format_eval_str(self):
+        def foo():
+            pass
+
+        with self.assertRaises(ValueError):
+            annotationlib.get_annotations(
+                foo, format=annotationlib.Format.FORWARDREF, eval_str=True
+            )
+            annotationlib.get_annotations(
+                foo, format=annotationlib.Format.SOURCE, eval_str=True
+            )
+
+    def test_stock_annotations(self):
+        def foo(a: int, b: str):
+            pass
+
+        for format in (annotationlib.Format.VALUE, annotationlib.Format.FORWARDREF):
+            with self.subTest(format=format):
+                self.assertEqual(
+                    annotationlib.get_annotations(foo, format=format),
+                    {"a": int, "b": str},
+                )
+        self.assertEqual(
+            annotationlib.get_annotations(foo, format=annotationlib.Format.SOURCE),
+            {"a": "int", "b": "str"},
+        )
+
+        foo.__annotations__ = {"a": "foo", "b": "str"}
+        for format in annotationlib.Format:
+            with self.subTest(format=format):
+                self.assertEqual(
+                    annotationlib.get_annotations(foo, format=format),
+                    {"a": "foo", "b": "str"},
+                )
+
+        self.assertEqual(
+            annotationlib.get_annotations(foo, eval_str=True, locals=locals()),
+            {"a": foo, "b": str},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(foo, eval_str=True, globals=locals()),
+            {"a": foo, "b": str},
+        )
+
+    def test_stock_annotations_in_module(self):
+        isa = inspect_stock_annotations
+
+        for kwargs in [
+            {},
+            {"eval_str": False},
+            {"format": annotationlib.Format.VALUE},
+            {"format": annotationlib.Format.FORWARDREF},
+            {"format": annotationlib.Format.VALUE, "eval_str": False},
+            {"format": annotationlib.Format.FORWARDREF, "eval_str": False},
+        ]:
+            with self.subTest(**kwargs):
+                self.assertEqual(
+                    annotationlib.get_annotations(isa, **kwargs), {"a": int, "b": str}
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.MyClass, **kwargs),
+                    {"a": int, "b": str},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function, **kwargs),
+                    {"a": int, "b": str, "return": isa.MyClass},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function2, **kwargs),
+                    {"a": int, "b": "str", "c": isa.MyClass, "return": isa.MyClass},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function3, **kwargs),
+                    {"a": "int", "b": "str", "c": "MyClass"},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(annotationlib, **kwargs), {}
+                )  # annotations module has no annotations
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.UnannotatedClass, **kwargs), {}
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.unannotated_function, **kwargs),
+                    {},
+                )
+
+        for kwargs in [
+            {"eval_str": True},
+            {"format": annotationlib.Format.VALUE, "eval_str": True},
+        ]:
+            with self.subTest(**kwargs):
+                self.assertEqual(
+                    annotationlib.get_annotations(isa, **kwargs), {"a": int, "b": str}
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.MyClass, **kwargs),
+                    {"a": int, "b": str},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function, **kwargs),
+                    {"a": int, "b": str, "return": isa.MyClass},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function2, **kwargs),
+                    {"a": int, "b": str, "c": isa.MyClass, "return": isa.MyClass},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function3, **kwargs),
+                    {"a": int, "b": str, "c": isa.MyClass},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(annotationlib, **kwargs), {}
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.UnannotatedClass, **kwargs), {}
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.unannotated_function, **kwargs),
+                    {},
+                )
+
+        self.assertEqual(
+            annotationlib.get_annotations(isa, format=annotationlib.Format.SOURCE),
+            {"a": "int", "b": "str"},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                isa.MyClass, format=annotationlib.Format.SOURCE
+            ),
+            {"a": "int", "b": "str"},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                isa.function, format=annotationlib.Format.SOURCE
+            ),
+            {"a": "int", "b": "str", "return": "MyClass"},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                isa.function2, format=annotationlib.Format.SOURCE
+            ),
+            {"a": "int", "b": "str", "c": "MyClass", "return": "MyClass"},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                isa.function3, format=annotationlib.Format.SOURCE
+            ),
+            {"a": "int", "b": "str", "c": "MyClass"},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                annotationlib, format=annotationlib.Format.SOURCE
+            ),
+            {},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                isa.UnannotatedClass, format=annotationlib.Format.SOURCE
+            ),
+            {},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                isa.unannotated_function, format=annotationlib.Format.SOURCE
+            ),
+            {},
+        )
+
+    def test_stock_annotations_on_wrapper(self):
+        isa = inspect_stock_annotations
+
+        wrapped = times_three(isa.function)
+        self.assertEqual(wrapped(1, "x"), isa.MyClass(3, "xxx"))
+        self.assertIsNot(wrapped.__globals__, isa.function.__globals__)
+        self.assertEqual(
+            annotationlib.get_annotations(wrapped),
+            {"a": int, "b": str, "return": isa.MyClass},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                wrapped, format=annotationlib.Format.FORWARDREF
+            ),
+            {"a": int, "b": str, "return": isa.MyClass},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(wrapped, format=annotationlib.Format.SOURCE),
+            {"a": "int", "b": "str", "return": "MyClass"},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(wrapped, eval_str=True),
+            {"a": int, "b": str, "return": isa.MyClass},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(wrapped, eval_str=False),
+            {"a": int, "b": str, "return": isa.MyClass},
+        )
+
+    def test_stringized_annotations_in_module(self):
+        isa = inspect_stringized_annotations
+        for kwargs in [
+            {},
+            {"eval_str": False},
+            {"format": annotationlib.Format.VALUE},
+            {"format": annotationlib.Format.FORWARDREF},
+            {"format": annotationlib.Format.SOURCE},
+            {"format": annotationlib.Format.VALUE, "eval_str": False},
+            {"format": annotationlib.Format.FORWARDREF, "eval_str": False},
+            {"format": annotationlib.Format.SOURCE, "eval_str": False},
+        ]:
+            with self.subTest(**kwargs):
+                self.assertEqual(
+                    annotationlib.get_annotations(isa, **kwargs),
+                    {"a": "int", "b": "str"},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.MyClass, **kwargs),
+                    {"a": "int", "b": "str"},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function, **kwargs),
+                    {"a": "int", "b": "str", "return": "MyClass"},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function2, **kwargs),
+                    {"a": "int", "b": "'str'", "c": "MyClass", "return": "MyClass"},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function3, **kwargs),
+                    {"a": "'int'", "b": "'str'", "c": "'MyClass'"},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.UnannotatedClass, **kwargs), {}
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.unannotated_function, **kwargs),
+                    {},
+                )
+
+        for kwargs in [
+            {"eval_str": True},
+            {"format": annotationlib.Format.VALUE, "eval_str": True},
+        ]:
+            with self.subTest(**kwargs):
+                self.assertEqual(
+                    annotationlib.get_annotations(isa, **kwargs), {"a": int, "b": str}
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.MyClass, **kwargs),
+                    {"a": int, "b": str},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function, **kwargs),
+                    {"a": int, "b": str, "return": isa.MyClass},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function2, **kwargs),
+                    {"a": int, "b": "str", "c": isa.MyClass, "return": isa.MyClass},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.function3, **kwargs),
+                    {"a": "int", "b": "str", "c": "MyClass"},
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.UnannotatedClass, **kwargs), {}
+                )
+                self.assertEqual(
+                    annotationlib.get_annotations(isa.unannotated_function, **kwargs),
+                    {},
+                )
+
+    def test_stringized_annotations_in_empty_module(self):
+        isa2 = inspect_stringized_annotations_2
+        self.assertEqual(annotationlib.get_annotations(isa2), {})
+        self.assertEqual(annotationlib.get_annotations(isa2, eval_str=True), {})
+        self.assertEqual(annotationlib.get_annotations(isa2, eval_str=False), {})
+
+    def test_stringized_annotations_on_wrapper(self):
+        isa = inspect_stringized_annotations
+        wrapped = times_three(isa.function)
+        self.assertEqual(wrapped(1, "x"), isa.MyClass(3, "xxx"))
+        self.assertIsNot(wrapped.__globals__, isa.function.__globals__)
+        self.assertEqual(
+            annotationlib.get_annotations(wrapped),
+            {"a": "int", "b": "str", "return": "MyClass"},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(wrapped, eval_str=True),
+            {"a": int, "b": str, "return": isa.MyClass},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(wrapped, eval_str=False),
+            {"a": "int", "b": "str", "return": "MyClass"},
+        )
+
+    def test_stringized_annotations_on_class(self):
+        isa = inspect_stringized_annotations
+        # test that local namespace lookups work
+        self.assertEqual(
+            annotationlib.get_annotations(isa.MyClassWithLocalAnnotations),
+            {"x": "mytype"},
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(
+                isa.MyClassWithLocalAnnotations, eval_str=True
+            ),
+            {"x": int},
+        )
+
+    def test_modify_annotations(self):
+        def f(x: int):
+            pass
+
+        self.assertEqual(annotationlib.get_annotations(f), {"x": int})
+        self.assertEqual(
+            annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF),
+            {"x": int},
+        )
+
+        f.__annotations__["x"] = str
+        # The modification is reflected in VALUE (the default)
+        self.assertEqual(annotationlib.get_annotations(f), {"x": str})
+        # ... but not in FORWARDREF, which uses __annotate__
+        self.assertEqual(
+            annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF),
+            {"x": int},
+        )
+
+    def test_pep695_generic_class_with_future_annotations(self):
+        ann_module695 = inspect_stringized_annotations_pep695
+        A_annotations = annotationlib.get_annotations(ann_module695.A, eval_str=True)
+        A_type_params = ann_module695.A.__type_params__
+        self.assertIs(A_annotations["x"], A_type_params[0])
+        self.assertEqual(A_annotations["y"].__args__[0], Unpack[A_type_params[1]])
+        self.assertIs(A_annotations["z"].__args__[0], A_type_params[2])
+
+    def test_pep695_generic_class_with_future_annotations_and_local_shadowing(self):
+        B_annotations = annotationlib.get_annotations(
+            inspect_stringized_annotations_pep695.B, eval_str=True
+        )
+        self.assertEqual(B_annotations, {"x": int, "y": str, "z": bytes})
+
+    def test_pep695_generic_class_with_future_annotations_name_clash_with_global_vars(self):
+        ann_module695 = inspect_stringized_annotations_pep695
+        C_annotations = annotationlib.get_annotations(ann_module695.C, eval_str=True)
+        self.assertEqual(
+            set(C_annotations.values()),
+            set(ann_module695.C.__type_params__)
+        )
+
+    def test_pep_695_generic_function_with_future_annotations(self):
+        ann_module695 = inspect_stringized_annotations_pep695
+        generic_func_annotations = annotationlib.get_annotations(
+            ann_module695.generic_function, eval_str=True
+        )
+        func_t_params = ann_module695.generic_function.__type_params__
+        self.assertEqual(
+            generic_func_annotations.keys(), {"x", "y", "z", "zz", "return"}
+        )
+        self.assertIs(generic_func_annotations["x"], func_t_params[0])
+        self.assertEqual(generic_func_annotations["y"], Unpack[func_t_params[1]])
+        self.assertIs(generic_func_annotations["z"].__origin__, func_t_params[2])
+        self.assertIs(generic_func_annotations["zz"].__origin__, func_t_params[2])
+
+    def test_pep_695_generic_function_with_future_annotations_name_clash_with_global_vars(self):
+        self.assertEqual(
+            set(
+                annotationlib.get_annotations(
+                    inspect_stringized_annotations_pep695.generic_function_2,
+                    eval_str=True
+                ).values()
+            ),
+            set(
+                inspect_stringized_annotations_pep695.generic_function_2.__type_params__
+            )
+        )
+
+    def test_pep_695_generic_method_with_future_annotations(self):
+        ann_module695 = inspect_stringized_annotations_pep695
+        generic_method_annotations = annotationlib.get_annotations(
+            ann_module695.D.generic_method, eval_str=True
+        )
+        params = {
+            param.__name__: param
+            for param in ann_module695.D.generic_method.__type_params__
+        }
+        self.assertEqual(
+            generic_method_annotations,
+            {"x": params["Foo"], "y": params["Bar"], "return": None}
+        )
+
+    def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_vars(self):
+        self.assertEqual(
+            set(
+                annotationlib.get_annotations(
+                    inspect_stringized_annotations_pep695.D.generic_method_2,
+                    eval_str=True
+                ).values()
+            ),
+            set(
+                inspect_stringized_annotations_pep695.D.generic_method_2.__type_params__
+            )
+        )
+
+    def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_and_local_vars(self):
+        self.assertEqual(
+            annotationlib.get_annotations(
+                inspect_stringized_annotations_pep695.E, eval_str=True
+            ),
+            {"x": str},
+        )
+
+    def test_pep_695_generics_with_future_annotations_nested_in_function(self):
+        results = inspect_stringized_annotations_pep695.nested()
+
+        self.assertEqual(
+            set(results.F_annotations.values()),
+            set(results.F.__type_params__)
+        )
+        self.assertEqual(
+            set(results.F_meth_annotations.values()),
+            set(results.F.generic_method.__type_params__)
+        )
+        self.assertNotEqual(
+            set(results.F_meth_annotations.values()),
+            set(results.F.__type_params__)
+        )
+        self.assertEqual(
+            set(results.F_meth_annotations.values()).intersection(results.F.__type_params__),
+            set()
+        )
+
+        self.assertEqual(results.G_annotations, {"x": str})
+
+        self.assertEqual(
+            set(results.generic_func_annotations.values()),
+            set(results.generic_func.__type_params__)
+        )
index ffb8bbe75c504f24016e03d64fab74f6aff448d0..b93c99d8c90bf35d4d91e28b232aff83b5e02178 100644 (file)
@@ -4807,6 +4807,16 @@ class TestKeywordArgs(unittest.TestCase):
         self.assertTrue(fields(B)[0].kw_only)
         self.assertFalse(fields(B)[1].kw_only)
 
+    def test_deferred_annotations(self):
+        @dataclass
+        class A:
+            x: undefined
+            y: ClassVar[undefined]
+
+        fs = fields(A)
+        self.assertEqual(len(fs), 1)
+        self.assertEqual(fs[0].name, 'x')
+
 
 if __name__ == '__main__':
     unittest.main()
index 492a16a8c7ff4582bfb2a5a513d333bd78ef65f9..837f3795f0842dca6803013e0d13821324e196ba 100644 (file)
@@ -741,6 +741,26 @@ class TestUpdateWrapper(unittest.TestCase):
         self.assertEqual(wrapper.__annotations__, {})
         self.assertEqual(wrapper.__type_params__, ())
 
+    def test_update_wrapper_annotations(self):
+        def inner(x: int): pass
+        def wrapper(*args): pass
+
+        functools.update_wrapper(wrapper, inner)
+        self.assertEqual(wrapper.__annotations__, {'x': int})
+        self.assertIs(wrapper.__annotate__, inner.__annotate__)
+
+        def with_forward_ref(x: undefined): pass
+        def wrapper(*args): pass
+
+        functools.update_wrapper(wrapper, with_forward_ref)
+
+        self.assertIs(wrapper.__annotate__, with_forward_ref.__annotate__)
+        with self.assertRaises(NameError):
+            wrapper.__annotations__
+
+        undefined = str
+        self.assertEqual(wrapper.__annotations__, {'x': undefined})
+
 
 class TestWraps(TestUpdateWrapper):
 
@@ -3059,6 +3079,27 @@ class TestSingleDispatch(unittest.TestCase):
         self.assertEqual(f(""), "default")
         self.assertEqual(f(b""), "default")
 
+    def test_forward_reference(self):
+        @functools.singledispatch
+        def f(arg, arg2=None):
+            return "default"
+
+        @f.register
+        def _(arg: str, arg2: undefined = None):
+            return "forward reference"
+
+        self.assertEqual(f(1), "default")
+        self.assertEqual(f(""), "forward reference")
+
+    def test_unresolved_forward_reference(self):
+        @functools.singledispatch
+        def f(arg):
+            return "default"
+
+        with self.assertRaisesRegex(TypeError, "is an unresolved forward reference"):
+            @f.register
+            def _(arg: undefined):
+                return "forward reference"
 
 class CachedCostItem:
     _cost = 1
index 5b7a639c025a0f3b802b3b00e4d4652fd22475fb..6a841587f49166ca908a7b25c2f6c134eb0eb5d7 100644 (file)
@@ -3,6 +3,7 @@
 
 from test.support import check_syntax_error
 from test.support import import_helper
+import annotationlib
 import inspect
 import unittest
 import sys
@@ -459,7 +460,7 @@ class GrammarTests(unittest.TestCase):
         gns = {}; lns = {}
         exec("'docstring'\n"
              "x: int = 5\n", gns, lns)
-        self.assertEqual(lns["__annotate__"](1), {'x': int})
+        self.assertEqual(lns["__annotate__"](annotationlib.Format.VALUE), {'x': int})
         with self.assertRaises(KeyError):
             gns['__annotate__']
 
index d39c3ccdc847bd095413c22e3ef277cc1c6920a8..5521528a524762c5eeff6bfa7b7de53b4c32653e 100644 (file)
@@ -45,10 +45,7 @@ from test import support
 
 from test.test_inspect import inspect_fodder as mod
 from test.test_inspect import inspect_fodder2 as mod2
-from test.test_inspect import inspect_stock_annotations
 from test.test_inspect import inspect_stringized_annotations
-from test.test_inspect import inspect_stringized_annotations_2
-from test.test_inspect import inspect_stringized_annotations_pep695
 
 
 # Functions tested in this suite:
@@ -126,7 +123,7 @@ class IsTestBase(unittest.TestCase):
             self.assertFalse(other(obj), 'not %s(%s)' % (other.__name__, exp))
 
     def test__all__(self):
-        support.check__all__(self, inspect, not_exported=("modulesbyfile",))
+        support.check__all__(self, inspect, not_exported=("modulesbyfile",), extra=("get_annotations",))
 
 def generator_function_example(self):
     for i in range(2):
@@ -1595,216 +1592,6 @@ class TestClassesAndFunctions(unittest.TestCase):
         attrs = [a[0] for a in inspect.getmembers(C)]
         self.assertNotIn('missing', attrs)
 
-    def test_get_annotations_with_stock_annotations(self):
-        def foo(a:int, b:str): pass
-        self.assertEqual(inspect.get_annotations(foo), {'a': int, 'b': str})
-
-        foo.__annotations__ = {'a': 'foo', 'b':'str'}
-        self.assertEqual(inspect.get_annotations(foo), {'a': 'foo', 'b': 'str'})
-
-        self.assertEqual(inspect.get_annotations(foo, eval_str=True, locals=locals()), {'a': foo, 'b': str})
-        self.assertEqual(inspect.get_annotations(foo, eval_str=True, globals=locals()), {'a': foo, 'b': str})
-
-        isa = inspect_stock_annotations
-        self.assertEqual(inspect.get_annotations(isa), {'a': int, 'b': str})
-        self.assertEqual(inspect.get_annotations(isa.MyClass), {'a': int, 'b': str})
-        self.assertEqual(inspect.get_annotations(isa.function), {'a': int, 'b': str, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(isa.function2), {'a': int, 'b': 'str', 'c': isa.MyClass, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(isa.function3), {'a': 'int', 'b': 'str', 'c': 'MyClass'})
-        self.assertEqual(inspect.get_annotations(inspect), {}) # inspect module has no annotations
-        self.assertEqual(inspect.get_annotations(isa.UnannotatedClass), {})
-        self.assertEqual(inspect.get_annotations(isa.unannotated_function), {})
-
-        self.assertEqual(inspect.get_annotations(isa, eval_str=True), {'a': int, 'b': str})
-        self.assertEqual(inspect.get_annotations(isa.MyClass, eval_str=True), {'a': int, 'b': str})
-        self.assertEqual(inspect.get_annotations(isa.function, eval_str=True), {'a': int, 'b': str, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(isa.function2, eval_str=True), {'a': int, 'b': str, 'c': isa.MyClass, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(isa.function3, eval_str=True), {'a': int, 'b': str, 'c': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(inspect, eval_str=True), {})
-        self.assertEqual(inspect.get_annotations(isa.UnannotatedClass, eval_str=True), {})
-        self.assertEqual(inspect.get_annotations(isa.unannotated_function, eval_str=True), {})
-
-        self.assertEqual(inspect.get_annotations(isa, eval_str=False), {'a': int, 'b': str})
-        self.assertEqual(inspect.get_annotations(isa.MyClass, eval_str=False), {'a': int, 'b': str})
-        self.assertEqual(inspect.get_annotations(isa.function, eval_str=False), {'a': int, 'b': str, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(isa.function2, eval_str=False), {'a': int, 'b': 'str', 'c': isa.MyClass, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(isa.function3, eval_str=False), {'a': 'int', 'b': 'str', 'c': 'MyClass'})
-        self.assertEqual(inspect.get_annotations(inspect, eval_str=False), {})
-        self.assertEqual(inspect.get_annotations(isa.UnannotatedClass, eval_str=False), {})
-        self.assertEqual(inspect.get_annotations(isa.unannotated_function, eval_str=False), {})
-
-        def times_three(fn):
-            @functools.wraps(fn)
-            def wrapper(a, b):
-                return fn(a*3, b*3)
-            return wrapper
-
-        wrapped = times_three(isa.function)
-        self.assertEqual(wrapped(1, 'x'), isa.MyClass(3, 'xxx'))
-        self.assertIsNot(wrapped.__globals__, isa.function.__globals__)
-        self.assertEqual(inspect.get_annotations(wrapped), {'a': int, 'b': str, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(wrapped, eval_str=True), {'a': int, 'b': str, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(wrapped, eval_str=False), {'a': int, 'b': str, 'return': isa.MyClass})
-
-    def test_get_annotations_with_stringized_annotations(self):
-        isa = inspect_stringized_annotations
-        self.assertEqual(inspect.get_annotations(isa), {'a': 'int', 'b': 'str'})
-        self.assertEqual(inspect.get_annotations(isa.MyClass), {'a': 'int', 'b': 'str'})
-        self.assertEqual(inspect.get_annotations(isa.function), {'a': 'int', 'b': 'str', 'return': 'MyClass'})
-        self.assertEqual(inspect.get_annotations(isa.function2), {'a': 'int', 'b': "'str'", 'c': 'MyClass', 'return': 'MyClass'})
-        self.assertEqual(inspect.get_annotations(isa.function3), {'a': "'int'", 'b': "'str'", 'c': "'MyClass'"})
-        self.assertEqual(inspect.get_annotations(isa.UnannotatedClass), {})
-        self.assertEqual(inspect.get_annotations(isa.unannotated_function), {})
-
-        self.assertEqual(inspect.get_annotations(isa, eval_str=True), {'a': int, 'b': str})
-        self.assertEqual(inspect.get_annotations(isa.MyClass, eval_str=True), {'a': int, 'b': str})
-        self.assertEqual(inspect.get_annotations(isa.function, eval_str=True), {'a': int, 'b': str, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(isa.function2, eval_str=True), {'a': int, 'b': 'str', 'c': isa.MyClass, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(isa.function3, eval_str=True), {'a': 'int', 'b': 'str', 'c': 'MyClass'})
-        self.assertEqual(inspect.get_annotations(isa.UnannotatedClass, eval_str=True), {})
-        self.assertEqual(inspect.get_annotations(isa.unannotated_function, eval_str=True), {})
-
-        self.assertEqual(inspect.get_annotations(isa, eval_str=False), {'a': 'int', 'b': 'str'})
-        self.assertEqual(inspect.get_annotations(isa.MyClass, eval_str=False), {'a': 'int', 'b': 'str'})
-        self.assertEqual(inspect.get_annotations(isa.function, eval_str=False), {'a': 'int', 'b': 'str', 'return': 'MyClass'})
-        self.assertEqual(inspect.get_annotations(isa.function2, eval_str=False), {'a': 'int', 'b': "'str'", 'c': 'MyClass', 'return': 'MyClass'})
-        self.assertEqual(inspect.get_annotations(isa.function3, eval_str=False), {'a': "'int'", 'b': "'str'", 'c': "'MyClass'"})
-        self.assertEqual(inspect.get_annotations(isa.UnannotatedClass, eval_str=False), {})
-        self.assertEqual(inspect.get_annotations(isa.unannotated_function, eval_str=False), {})
-
-        isa2 = inspect_stringized_annotations_2
-        self.assertEqual(inspect.get_annotations(isa2), {})
-        self.assertEqual(inspect.get_annotations(isa2, eval_str=True), {})
-        self.assertEqual(inspect.get_annotations(isa2, eval_str=False), {})
-
-        def times_three(fn):
-            @functools.wraps(fn)
-            def wrapper(a, b):
-                return fn(a*3, b*3)
-            return wrapper
-
-        wrapped = times_three(isa.function)
-        self.assertEqual(wrapped(1, 'x'), isa.MyClass(3, 'xxx'))
-        self.assertIsNot(wrapped.__globals__, isa.function.__globals__)
-        self.assertEqual(inspect.get_annotations(wrapped), {'a': 'int', 'b': 'str', 'return': 'MyClass'})
-        self.assertEqual(inspect.get_annotations(wrapped, eval_str=True), {'a': int, 'b': str, 'return': isa.MyClass})
-        self.assertEqual(inspect.get_annotations(wrapped, eval_str=False), {'a': 'int', 'b': 'str', 'return': 'MyClass'})
-
-        # test that local namespace lookups work
-        self.assertEqual(inspect.get_annotations(isa.MyClassWithLocalAnnotations), {'x': 'mytype'})
-        self.assertEqual(inspect.get_annotations(isa.MyClassWithLocalAnnotations, eval_str=True), {'x': int})
-
-    def test_pep695_generic_class_with_future_annotations(self):
-        ann_module695 = inspect_stringized_annotations_pep695
-        A_annotations = inspect.get_annotations(ann_module695.A, eval_str=True)
-        A_type_params = ann_module695.A.__type_params__
-        self.assertIs(A_annotations["x"], A_type_params[0])
-        self.assertEqual(A_annotations["y"].__args__[0], Unpack[A_type_params[1]])
-        self.assertIs(A_annotations["z"].__args__[0], A_type_params[2])
-
-    def test_pep695_generic_class_with_future_annotations_and_local_shadowing(self):
-        B_annotations = inspect.get_annotations(
-            inspect_stringized_annotations_pep695.B, eval_str=True
-        )
-        self.assertEqual(B_annotations, {"x": int, "y": str, "z": bytes})
-
-    def test_pep695_generic_class_with_future_annotations_name_clash_with_global_vars(self):
-        ann_module695 = inspect_stringized_annotations_pep695
-        C_annotations = inspect.get_annotations(ann_module695.C, eval_str=True)
-        self.assertEqual(
-            set(C_annotations.values()),
-            set(ann_module695.C.__type_params__)
-        )
-
-    def test_pep_695_generic_function_with_future_annotations(self):
-        ann_module695 = inspect_stringized_annotations_pep695
-        generic_func_annotations = inspect.get_annotations(
-            ann_module695.generic_function, eval_str=True
-        )
-        func_t_params = ann_module695.generic_function.__type_params__
-        self.assertEqual(
-            generic_func_annotations.keys(), {"x", "y", "z", "zz", "return"}
-        )
-        self.assertIs(generic_func_annotations["x"], func_t_params[0])
-        self.assertEqual(generic_func_annotations["y"], Unpack[func_t_params[1]])
-        self.assertIs(generic_func_annotations["z"].__origin__, func_t_params[2])
-        self.assertIs(generic_func_annotations["zz"].__origin__, func_t_params[2])
-
-    def test_pep_695_generic_function_with_future_annotations_name_clash_with_global_vars(self):
-        self.assertEqual(
-            set(
-                inspect.get_annotations(
-                    inspect_stringized_annotations_pep695.generic_function_2,
-                    eval_str=True
-                ).values()
-            ),
-            set(
-                inspect_stringized_annotations_pep695.generic_function_2.__type_params__
-            )
-        )
-
-    def test_pep_695_generic_method_with_future_annotations(self):
-        ann_module695 = inspect_stringized_annotations_pep695
-        generic_method_annotations = inspect.get_annotations(
-            ann_module695.D.generic_method, eval_str=True
-        )
-        params = {
-            param.__name__: param
-            for param in ann_module695.D.generic_method.__type_params__
-        }
-        self.assertEqual(
-            generic_method_annotations,
-            {"x": params["Foo"], "y": params["Bar"], "return": None}
-        )
-
-    def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_vars(self):
-        self.assertEqual(
-            set(
-                inspect.get_annotations(
-                    inspect_stringized_annotations_pep695.D.generic_method_2,
-                    eval_str=True
-                ).values()
-            ),
-            set(
-                inspect_stringized_annotations_pep695.D.generic_method_2.__type_params__
-            )
-        )
-
-    def test_pep_695_generic_method_with_future_annotations_name_clash_with_global_and_local_vars(self):
-        self.assertEqual(
-            inspect.get_annotations(
-                inspect_stringized_annotations_pep695.E, eval_str=True
-            ),
-            {"x": str},
-        )
-
-    def test_pep_695_generics_with_future_annotations_nested_in_function(self):
-        results = inspect_stringized_annotations_pep695.nested()
-
-        self.assertEqual(
-            set(results.F_annotations.values()),
-            set(results.F.__type_params__)
-        )
-        self.assertEqual(
-            set(results.F_meth_annotations.values()),
-            set(results.F.generic_method.__type_params__)
-        )
-        self.assertNotEqual(
-            set(results.F_meth_annotations.values()),
-            set(results.F.__type_params__)
-        )
-        self.assertEqual(
-            set(results.F_meth_annotations.values()).intersection(results.F.__type_params__),
-            set()
-        )
-
-        self.assertEqual(results.G_annotations, {"x": str})
-
-        self.assertEqual(
-            set(results.generic_func_annotations.values()),
-            set(results.generic_func.__type_params__)
-        )
-
 
 class TestFormatAnnotation(unittest.TestCase):
     def test_typing_replacement(self):
index a9be1f5aa8468111de09b15880c9258b6293540f..91082e6b23c04ba3321277d1c07f6817b762623e 100644 (file)
@@ -1,12 +1,9 @@
+import annotationlib
 import textwrap
 import types
 import unittest
 from test.support import run_code, check_syntax_error
 
-VALUE = 1
-FORWARDREF = 2
-SOURCE = 3
-
 
 class TypeAnnotationTests(unittest.TestCase):
 
@@ -376,12 +373,12 @@ class DeferredEvaluationTests(unittest.TestCase):
                 self.assertIsInstance(annotate, types.FunctionType)
                 self.assertEqual(annotate.__name__, "__annotate__")
                 with self.assertRaises(NotImplementedError):
-                    annotate(FORWARDREF)
+                    annotate(annotationlib.Format.FORWARDREF)
                 with self.assertRaises(NotImplementedError):
-                    annotate(SOURCE)
+                    annotate(annotationlib.Format.SOURCE)
                 with self.assertRaises(NotImplementedError):
                     annotate(None)
-                self.assertEqual(annotate(VALUE), {"x": int})
+                self.assertEqual(annotate(annotationlib.Format.VALUE), {"x": int})
 
     def test_comprehension_in_annotation(self):
         # This crashed in an earlier version of the code
@@ -398,7 +395,7 @@ class DeferredEvaluationTests(unittest.TestCase):
         f = ns["f"]
         self.assertIsInstance(f.__annotate__, types.FunctionType)
         annos = {"x": "int", "return": "int"}
-        self.assertEqual(f.__annotate__(VALUE), annos)
+        self.assertEqual(f.__annotate__(annotationlib.Format.VALUE), annos)
         self.assertEqual(f.__annotations__, annos)
 
     def test_name_clash_with_format(self):
index a931da55908236c470415ab7d62d7febd1db4f76..290b3c63a762e9d237a6d284cddb1043c01e688d 100644 (file)
@@ -1,3 +1,4 @@
+import annotationlib
 import contextlib
 import collections
 import collections.abc
@@ -45,7 +46,7 @@ import typing
 import weakref
 import types
 
-from test.support import captured_stderr, cpython_only, infinite_recursion, requires_docstrings, import_helper
+from test.support import captured_stderr, cpython_only, infinite_recursion, requires_docstrings, import_helper, run_code
 from test.typinganndata import ann_module695, mod_generics_cache, _typed_dict_helper
 
 
@@ -7812,6 +7813,48 @@ class NamedTupleTests(BaseTestCase):
                 def _source(self):
                     return 'no chance for this as well'
 
+    def test_annotation_type_check(self):
+        # These are rejected by _type_check
+        with self.assertRaises(TypeError):
+            class X(NamedTuple):
+                a: Final
+        with self.assertRaises(TypeError):
+            class Y(NamedTuple):
+                a: (1, 2)
+
+        # Conversion by _type_convert
+        class Z(NamedTuple):
+            a: None
+            b: "str"
+        annos = {'a': type(None), 'b': ForwardRef("str")}
+        self.assertEqual(Z.__annotations__, annos)
+        self.assertEqual(Z.__annotate__(annotationlib.Format.VALUE), annos)
+        self.assertEqual(Z.__annotate__(annotationlib.Format.FORWARDREF), annos)
+        self.assertEqual(Z.__annotate__(annotationlib.Format.SOURCE), {"a": "None", "b": "str"})
+
+    def test_future_annotations(self):
+        code = """
+        from __future__ import annotations
+        from typing import NamedTuple
+        class X(NamedTuple):
+            a: int
+            b: None
+        """
+        ns = run_code(textwrap.dedent(code))
+        X = ns['X']
+        self.assertEqual(X.__annotations__, {'a': ForwardRef("int"), 'b': ForwardRef("None")})
+
+    def test_deferred_annotations(self):
+        class X(NamedTuple):
+            y: undefined
+
+        self.assertEqual(X._fields, ('y',))
+        with self.assertRaises(NameError):
+            X.__annotations__
+
+        undefined = int
+        self.assertEqual(X.__annotations__, {'y': int})
+
     def test_multiple_inheritance(self):
         class A:
             pass
@@ -8126,7 +8169,11 @@ class TypedDictTests(BaseTestCase):
         self.assertEqual(Emp.__name__, 'Emp')
         self.assertEqual(Emp.__module__, __name__)
         self.assertEqual(Emp.__bases__, (dict,))
-        self.assertEqual(Emp.__annotations__, {'name': str, 'id': int})
+        annos = {'name': str, 'id': int}
+        self.assertEqual(Emp.__annotations__, annos)
+        self.assertEqual(Emp.__annotate__(annotationlib.Format.VALUE), annos)
+        self.assertEqual(Emp.__annotate__(annotationlib.Format.FORWARDREF), annos)
+        self.assertEqual(Emp.__annotate__(annotationlib.Format.SOURCE), {'name': 'str', 'id': 'int'})
         self.assertEqual(Emp.__total__, True)
         self.assertEqual(Emp.__required_keys__, {'name', 'id'})
         self.assertIsInstance(Emp.__required_keys__, frozenset)
@@ -8487,6 +8534,8 @@ class TypedDictTests(BaseTestCase):
         self.assertEqual(A.__bases__, (Generic, dict))
         self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T]))
         self.assertEqual(A.__mro__, (A, Generic, dict, object))
+        self.assertEqual(A.__annotations__, {'a': T})
+        self.assertEqual(A.__annotate__(annotationlib.Format.SOURCE), {'a': 'T'})
         self.assertEqual(A.__parameters__, (T,))
         self.assertEqual(A[str].__parameters__, ())
         self.assertEqual(A[str].__args__, (str,))
@@ -8498,6 +8547,8 @@ class TypedDictTests(BaseTestCase):
         self.assertEqual(A.__bases__, (Generic, dict))
         self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T]))
         self.assertEqual(A.__mro__, (A, Generic, dict, object))
+        self.assertEqual(A.__annotations__, {'a': T})
+        self.assertEqual(A.__annotate__(annotationlib.Format.SOURCE), {'a': 'T'})
         self.assertEqual(A.__parameters__, (T,))
         self.assertEqual(A[str].__parameters__, ())
         self.assertEqual(A[str].__args__, (str,))
@@ -8508,6 +8559,8 @@ class TypedDictTests(BaseTestCase):
         self.assertEqual(A2.__bases__, (Generic, dict))
         self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict))
         self.assertEqual(A2.__mro__, (A2, Generic, dict, object))
+        self.assertEqual(A2.__annotations__, {'a': T})
+        self.assertEqual(A2.__annotate__(annotationlib.Format.SOURCE), {'a': 'T'})
         self.assertEqual(A2.__parameters__, (T,))
         self.assertEqual(A2[str].__parameters__, ())
         self.assertEqual(A2[str].__args__, (str,))
@@ -8518,6 +8571,8 @@ class TypedDictTests(BaseTestCase):
         self.assertEqual(B.__bases__, (Generic, dict))
         self.assertEqual(B.__orig_bases__, (A[KT],))
         self.assertEqual(B.__mro__, (B, Generic, dict, object))
+        self.assertEqual(B.__annotations__, {'a': T, 'b': KT})
+        self.assertEqual(B.__annotate__(annotationlib.Format.SOURCE), {'a': 'T', 'b': 'KT'})
         self.assertEqual(B.__parameters__, (KT,))
         self.assertEqual(B.__total__, False)
         self.assertEqual(B.__optional_keys__, frozenset(['b']))
@@ -8542,6 +8597,11 @@ class TypedDictTests(BaseTestCase):
             'b': KT,
             'c': int,
         })
+        self.assertEqual(C.__annotate__(annotationlib.Format.SOURCE), {
+            'a': 'T',
+            'b': 'KT',
+            'c': 'int',
+        })
         with self.assertRaises(TypeError):
             C[str]
 
@@ -8561,6 +8621,11 @@ class TypedDictTests(BaseTestCase):
             'b': T,
             'c': KT,
         })
+        self.assertEqual(Point3D.__annotate__(annotationlib.Format.SOURCE), {
+            'a': 'T',
+            'b': 'T',
+            'c': 'KT',
+        })
         self.assertEqual(Point3D[int, str].__origin__, Point3D)
 
         with self.assertRaises(TypeError):
@@ -8592,6 +8657,11 @@ class TypedDictTests(BaseTestCase):
             'b': KT,
             'c': int,
         })
+        self.assertEqual(WithImplicitAny.__annotate__(annotationlib.Format.SOURCE), {
+            'a': 'T',
+            'b': 'KT',
+            'c': 'int',
+        })
         with self.assertRaises(TypeError):
             WithImplicitAny[str]
 
@@ -8748,6 +8818,54 @@ class TypedDictTests(BaseTestCase):
             },
         )
 
+    def test_annotations(self):
+        # _type_check is applied
+        with self.assertRaisesRegex(TypeError, "Plain typing.Final is not valid as type argument"):
+            class X(TypedDict):
+                a: Final
+
+        # _type_convert is applied
+        class Y(TypedDict):
+            a: None
+            b: "int"
+        fwdref = ForwardRef('int', module='test.test_typing')
+        self.assertEqual(Y.__annotations__, {'a': type(None), 'b': fwdref})
+        self.assertEqual(Y.__annotate__(annotationlib.Format.FORWARDREF), {'a': type(None), 'b': fwdref})
+
+        # _type_check is also applied later
+        class Z(TypedDict):
+            a: undefined
+
+        with self.assertRaises(NameError):
+            Z.__annotations__
+
+        undefined = Final
+        with self.assertRaisesRegex(TypeError, "Plain typing.Final is not valid as type argument"):
+            Z.__annotations__
+
+        undefined = None
+        self.assertEqual(Z.__annotations__, {'a': type(None)})
+
+    def test_deferred_evaluation(self):
+        class A(TypedDict):
+            x: NotRequired[undefined]
+            y: ReadOnly[undefined]
+            z: Required[undefined]
+
+        self.assertEqual(A.__required_keys__, frozenset({'y', 'z'}))
+        self.assertEqual(A.__optional_keys__, frozenset({'x'}))
+        self.assertEqual(A.__readonly_keys__, frozenset({'y'}))
+        self.assertEqual(A.__mutable_keys__, frozenset({'x', 'z'}))
+
+        with self.assertRaises(NameError):
+            A.__annotations__
+
+        self.assertEqual(
+            A.__annotate__(annotationlib.Format.SOURCE),
+            {'x': 'NotRequired[undefined]', 'y': 'ReadOnly[undefined]',
+             'z': 'Required[undefined]'},
+        )
+
 
 class RequiredTests(BaseTestCase):
 
@@ -10075,7 +10193,6 @@ class SpecialAttrsTests(BaseTestCase):
             typing.ClassVar: 'ClassVar',
             typing.Concatenate: 'Concatenate',
             typing.Final: 'Final',
-            typing.ForwardRef: 'ForwardRef',
             typing.Literal: 'Literal',
             typing.NewType: 'NewType',
             typing.NoReturn: 'NoReturn',
@@ -10087,7 +10204,7 @@ class SpecialAttrsTests(BaseTestCase):
             typing.TypeVar: 'TypeVar',
             typing.Union: 'Union',
             typing.Self: 'Self',
-            # Subscribed special forms
+            # Subscripted special forms
             typing.Annotated[Any, "Annotation"]: 'Annotated',
             typing.Annotated[int, 'Annotation']: 'Annotated',
             typing.ClassVar[Any]: 'ClassVar',
@@ -10102,7 +10219,6 @@ class SpecialAttrsTests(BaseTestCase):
             typing.Union[Any]: 'Any',
             typing.Union[int, float]: 'Union',
             # Incompatible special forms (tested in test_special_attrs2)
-            # - typing.ForwardRef('set[Any]')
             # - typing.NewType('TypeName', Any)
             # - typing.ParamSpec('SpecialAttrsP')
             # - typing.TypeVar('T')
@@ -10121,18 +10237,6 @@ class SpecialAttrsTests(BaseTestCase):
     TypeName = typing.NewType('SpecialAttrsTests.TypeName', Any)
 
     def test_special_attrs2(self):
-        # Forward refs provide a different introspection API. __name__ and
-        # __qualname__ make little sense for forward refs as they can store
-        # complex typing expressions.
-        fr = typing.ForwardRef('set[Any]')
-        self.assertFalse(hasattr(fr, '__name__'))
-        self.assertFalse(hasattr(fr, '__qualname__'))
-        self.assertEqual(fr.__module__, 'typing')
-        # Forward refs are currently unpicklable.
-        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
-            with self.assertRaises(TypeError):
-                pickle.dumps(fr, proto)
-
         self.assertEqual(SpecialAttrsTests.TypeName.__name__, 'TypeName')
         self.assertEqual(
             SpecialAttrsTests.TypeName.__qualname__,
index bc17d136082891bcb6cfaf511d70d3497c3d624d..626053d8166160b6faa63f5a0471dd25bdbf269f 100644 (file)
@@ -19,6 +19,8 @@ that may be changed without notice. Use at your own risk!
 """
 
 from abc import abstractmethod, ABCMeta
+import annotationlib
+from annotationlib import ForwardRef
 import collections
 from collections import defaultdict
 import collections.abc
@@ -125,6 +127,7 @@ __all__ = [
     'cast',
     'clear_overloads',
     'dataclass_transform',
+    'evaluate_forward_ref',
     'final',
     'get_args',
     'get_origin',
@@ -165,7 +168,7 @@ def _type_convert(arg, module=None, *, allow_special_forms=False):
     if arg is None:
         return type(None)
     if isinstance(arg, str):
-        return ForwardRef(arg, module=module, is_class=allow_special_forms)
+        return _make_forward_ref(arg, module=module, is_class=allow_special_forms)
     return arg
 
 
@@ -459,7 +462,8 @@ class _Sentinel:
 _sentinel = _Sentinel()
 
 
-def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=frozenset()):
+def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=frozenset(),
+               format=annotationlib.Format.VALUE, owner=None):
     """Evaluate all forward references in the given type t.
 
     For use of globalns and localns see the docstring for get_type_hints().
@@ -470,11 +474,13 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
         _deprecation_warning_for_no_type_params_passed("typing._eval_type")
         type_params = ()
     if isinstance(t, ForwardRef):
-        return t._evaluate(globalns, localns, type_params, recursive_guard=recursive_guard)
+        return evaluate_forward_ref(t, globals=globalns, locals=localns,
+                                    type_params=type_params, owner=owner,
+                                    _recursive_guard=recursive_guard, format=format)
     if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
         if isinstance(t, GenericAlias):
             args = tuple(
-                ForwardRef(arg) if isinstance(arg, str) else arg
+                _make_forward_ref(arg) if isinstance(arg, str) else arg
                 for arg in t.__args__
             )
             is_unpacked = t.__unpacked__
@@ -487,7 +493,8 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
 
         ev_args = tuple(
             _eval_type(
-                a, globalns, localns, type_params, recursive_guard=recursive_guard
+                a, globalns, localns, type_params, recursive_guard=recursive_guard,
+                format=format, owner=owner,
             )
             for a in t.__args__
         )
@@ -1011,111 +1018,77 @@ def TypeIs(self, parameters):
     return _GenericAlias(self, (item,))
 
 
-class ForwardRef(_Final, _root=True):
-    """Internal wrapper to hold a forward reference."""
+def _make_forward_ref(code, **kwargs):
+    forward_ref = ForwardRef(code, **kwargs)
+    # For compatibility, eagerly compile the forwardref's code.
+    forward_ref.__forward_code__
+    return forward_ref
 
-    __slots__ = ('__forward_arg__', '__forward_code__',
-                 '__forward_evaluated__', '__forward_value__',
-                 '__forward_is_argument__', '__forward_is_class__',
-                 '__forward_module__')
 
-    def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
-        if not isinstance(arg, str):
-            raise TypeError(f"Forward reference must be a string -- got {arg!r}")
-
-        # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
-        # Unfortunately, this isn't a valid expression on its own, so we
-        # do the unpacking manually.
-        if arg.startswith('*'):
-            arg_to_compile = f'({arg},)[0]'  # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
-        else:
-            arg_to_compile = arg
-        try:
-            code = compile(arg_to_compile, '<string>', 'eval')
-        except SyntaxError:
-            raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
-
-        self.__forward_arg__ = arg
-        self.__forward_code__ = code
-        self.__forward_evaluated__ = False
-        self.__forward_value__ = None
-        self.__forward_is_argument__ = is_argument
-        self.__forward_is_class__ = is_class
-        self.__forward_module__ = module
-
-    def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
-        if type_params is _sentinel:
-            _deprecation_warning_for_no_type_params_passed("typing.ForwardRef._evaluate")
-            type_params = ()
-        if self.__forward_arg__ in recursive_guard:
-            return self
-        if not self.__forward_evaluated__ or localns is not globalns:
-            if globalns is None and localns is None:
-                globalns = localns = {}
-            elif globalns is None:
-                globalns = localns
-            elif localns is None:
-                localns = globalns
-            if self.__forward_module__ is not None:
-                globalns = getattr(
-                    sys.modules.get(self.__forward_module__, None), '__dict__', globalns
-                )
-
-            # type parameters require some special handling,
-            # as they exist in their own scope
-            # but `eval()` does not have a dedicated parameter for that scope.
-            # For classes, names in type parameter scopes should override
-            # names in the global scope (which here are called `localns`!),
-            # but should in turn be overridden by names in the class scope
-            # (which here are called `globalns`!)
-            if type_params:
-                globalns, localns = dict(globalns), dict(localns)
-                for param in type_params:
-                    param_name = param.__name__
-                    if not self.__forward_is_class__ or param_name not in globalns:
-                        globalns[param_name] = param
-                        localns.pop(param_name, None)
-
-            type_ = _type_check(
-                eval(self.__forward_code__, globalns, localns),
-                "Forward references must evaluate to types.",
-                is_argument=self.__forward_is_argument__,
-                allow_special_forms=self.__forward_is_class__,
-            )
-            self.__forward_value__ = _eval_type(
-                type_,
-                globalns,
-                localns,
-                type_params,
-                recursive_guard=(recursive_guard | {self.__forward_arg__}),
-            )
-            self.__forward_evaluated__ = True
-        return self.__forward_value__
-
-    def __eq__(self, other):
-        if not isinstance(other, ForwardRef):
-            return NotImplemented
-        if self.__forward_evaluated__ and other.__forward_evaluated__:
-            return (self.__forward_arg__ == other.__forward_arg__ and
-                    self.__forward_value__ == other.__forward_value__)
-        return (self.__forward_arg__ == other.__forward_arg__ and
-                self.__forward_module__ == other.__forward_module__)
-
-    def __hash__(self):
-        return hash((self.__forward_arg__, self.__forward_module__))
-
-    def __or__(self, other):
-        return Union[self, other]
+def evaluate_forward_ref(
+    forward_ref,
+    *,
+    owner=None,
+    globals=None,
+    locals=None,
+    type_params=None,
+    format=annotationlib.Format.VALUE,
+    _recursive_guard=frozenset(),
+):
+    """Evaluate a forward reference as a type hint.
+
+    This is similar to calling the ForwardRef.evaluate() method,
+    but unlike that method, evaluate_forward_ref() also:
+
+    * Recursively evaluates forward references nested within the type hint.
+    * Rejects certain objects that are not valid type hints.
+    * Replaces type hints that evaluate to None with types.NoneType.
+    * Supports the *FORWARDREF* and *SOURCE* formats.
+
+    *forward_ref* must be an instance of ForwardRef. *owner*, if given,
+    should be the object that holds the annotations that the forward reference
+    derived from, such as a module, class object, or function. It is used to
+    infer the namespaces to use for looking up names. *globals* and *locals*
+    can also be explicitly given to provide the global and local namespaces.
+    *type_params* is a tuple of type parameters that are in scope when
+    evaluating the forward reference. This parameter must be provided (though
+    it may be an empty tuple) if *owner* is not given and the forward reference
+    does not already have an owner set. *format* specifies the format of the
+    annotation and is a member of the annoations.Format enum.
 
-    def __ror__(self, other):
-        return Union[other, self]
+    """
+    if type_params is _sentinel:
+        _deprecation_warning_for_no_type_params_passed("typing.evaluate_forward_ref")
+        type_params = ()
+    if format == annotationlib.Format.SOURCE:
+        return forward_ref.__forward_arg__
+    if forward_ref.__forward_arg__ in _recursive_guard:
+        return forward_ref
 
-    def __repr__(self):
-        if self.__forward_module__ is None:
-            module_repr = ''
+    try:
+        value = forward_ref.evaluate(globals=globals, locals=locals,
+                                     type_params=type_params, owner=owner)
+    except NameError:
+        if format == annotationlib.Format.FORWARDREF:
+            return forward_ref
         else:
-            module_repr = f', module={self.__forward_module__!r}'
-        return f'ForwardRef({self.__forward_arg__!r}{module_repr})'
+            raise
+
+    type_ = _type_check(
+        value,
+        "Forward references must evaluate to types.",
+        is_argument=forward_ref.__forward_is_argument__,
+        allow_special_forms=forward_ref.__forward_is_class__,
+    )
+    return _eval_type(
+        type_,
+        globals,
+        locals,
+        type_params,
+        recursive_guard=_recursive_guard | {forward_ref.__forward_arg__},
+        format=format,
+        owner=owner,
+    )
 
 
 def _is_unpacked_typevartuple(x: Any) -> bool:
@@ -2196,7 +2169,7 @@ class _AnnotatedAlias(_NotIterable, _GenericAlias, _root=True):
     """Runtime representation of an annotated type.
 
     At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't'
-    with extra annotations. The alias behaves like a normal typing alias.
+    with extra metadata. The alias behaves like a normal typing alias.
     Instantiating is the same as instantiating the underlying type; binding
     it to types is also the same.
 
@@ -2380,7 +2353,8 @@ _allowed_types = (types.FunctionType, types.BuiltinFunctionType,
                   WrapperDescriptorType, MethodWrapperType, MethodDescriptorType)
 
 
-def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
+def get_type_hints(obj, globalns=None, localns=None, include_extras=False,
+                   *, format=annotationlib.Format.VALUE):
     """Return type hints for an object.
 
     This is often the same as obj.__annotations__, but it handles
@@ -2417,13 +2391,14 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
     if isinstance(obj, type):
         hints = {}
         for base in reversed(obj.__mro__):
+            ann = annotationlib.get_annotations(base, format=format)
+            if format is annotationlib.Format.SOURCE:
+                hints.update(ann)
+                continue
             if globalns is None:
                 base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {})
             else:
                 base_globals = globalns
-            ann = getattr(base, '__annotations__', {})
-            if isinstance(ann, types.GetSetDescriptorType):
-                ann = {}
             base_locals = dict(vars(base)) if localns is None else localns
             if localns is None and globalns is None:
                 # This is surprising, but required.  Before Python 3.10,
@@ -2437,10 +2412,26 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
                 if value is None:
                     value = type(None)
                 if isinstance(value, str):
-                    value = ForwardRef(value, is_argument=False, is_class=True)
-                value = _eval_type(value, base_globals, base_locals, base.__type_params__)
+                    value = _make_forward_ref(value, is_argument=False, is_class=True)
+                value = _eval_type(value, base_globals, base_locals, base.__type_params__,
+                                   format=format, owner=obj)
                 hints[name] = value
-        return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
+        if include_extras or format is annotationlib.Format.SOURCE:
+            return hints
+        else:
+            return {k: _strip_annotations(t) for k, t in hints.items()}
+
+    hints = annotationlib.get_annotations(obj, format=format)
+    if (
+        not hints
+        and not isinstance(obj, types.ModuleType)
+        and not callable(obj)
+        and not hasattr(obj, '__annotations__')
+        and not hasattr(obj, '__annotate__')
+    ):
+        raise TypeError(f"{obj!r} is not a module, class, or callable.")
+    if format is annotationlib.Format.SOURCE:
+        return hints
 
     if globalns is None:
         if isinstance(obj, types.ModuleType):
@@ -2455,15 +2446,6 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
             localns = globalns
     elif localns is None:
         localns = globalns
-    hints = getattr(obj, '__annotations__', None)
-    if hints is None:
-        # Return empty annotations for something that _could_ have them.
-        if isinstance(obj, _allowed_types):
-            return {}
-        else:
-            raise TypeError('{!r} is not a module, class, method, '
-                            'or function.'.format(obj))
-    hints = dict(hints)
     type_params = getattr(obj, "__type_params__", ())
     for name, value in hints.items():
         if value is None:
@@ -2471,12 +2453,12 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
         if isinstance(value, str):
             # class-level forward refs were handled above, this must be either
             # a module-level annotation or a function argument annotation
-            value = ForwardRef(
+            value = _make_forward_ref(
                 value,
                 is_argument=not isinstance(obj, types.ModuleType),
                 is_class=False,
             )
-        hints[name] = _eval_type(value, globalns, localns, type_params)
+        hints[name] = _eval_type(value, globalns, localns, type_params, format=format, owner=obj)
     return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
 
 
@@ -2953,22 +2935,34 @@ class SupportsRound[T](Protocol):
         pass
 
 
-def _make_nmtuple(name, types, module, defaults = ()):
-    fields = [n for n, t in types]
-    types = {n: _type_check(t, f"field {n} annotation must be a type")
-             for n, t in types}
+def _make_nmtuple(name, fields, annotate_func, module, defaults = ()):
     nm_tpl = collections.namedtuple(name, fields,
                                     defaults=defaults, module=module)
-    nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = types
+    nm_tpl.__annotate__ = nm_tpl.__new__.__annotate__ = annotate_func
     return nm_tpl
 
 
+def _make_eager_annotate(types):
+    checked_types = {key: _type_check(val, f"field {key} annotation must be a type")
+                     for key, val in types.items()}
+    def annotate(format):
+        if format in (annotationlib.Format.VALUE, annotationlib.Format.FORWARDREF):
+            return checked_types
+        else:
+            return _convert_to_source(types)
+    return annotate
+
+
+def _convert_to_source(types):
+    return {n: t if isinstance(t, str) else _type_repr(t) for n, t in types.items()}
+
+
 # attributes prohibited to set in NamedTuple class syntax
 _prohibited = frozenset({'__new__', '__init__', '__slots__', '__getnewargs__',
                          '_fields', '_field_defaults',
                          '_make', '_replace', '_asdict', '_source'})
 
-_special = frozenset({'__module__', '__name__', '__annotations__'})
+_special = frozenset({'__module__', '__name__', '__annotations__', '__annotate__'})
 
 
 class NamedTupleMeta(type):
@@ -2981,12 +2975,29 @@ class NamedTupleMeta(type):
         bases = tuple(tuple if base is _NamedTuple else base for base in bases)
         if "__annotations__" in ns:
             types = ns["__annotations__"]
+            field_names = list(types)
+            annotate = _make_eager_annotate(types)
         elif "__annotate__" in ns:
-            types = ns["__annotate__"](1)  # VALUE
+            original_annotate = ns["__annotate__"]
+            types = annotationlib.call_annotate_function(original_annotate, annotationlib.Format.FORWARDREF)
+            field_names = list(types)
+
+            # For backward compatibility, type-check all the types at creation time
+            for typ in types.values():
+                _type_check(typ, "field annotation must be a type")
+
+            def annotate(format):
+                annos = annotationlib.call_annotate_function(original_annotate, format)
+                if format != annotationlib.Format.SOURCE:
+                    return {key: _type_check(val, f"field {key} annotation must be a type")
+                            for key, val in annos.items()}
+                return annos
         else:
-            types = {}
+            # Empty NamedTuple
+            field_names = []
+            annotate = lambda format: {}
         default_names = []
-        for field_name in types:
+        for field_name in field_names:
             if field_name in ns:
                 default_names.append(field_name)
             elif default_names:
@@ -2994,7 +3005,7 @@ class NamedTupleMeta(type):
                                 f"cannot follow default field"
                                 f"{'s' if len(default_names) > 1 else ''} "
                                 f"{', '.join(default_names)}")
-        nm_tpl = _make_nmtuple(typename, types.items(),
+        nm_tpl = _make_nmtuple(typename, field_names, annotate,
                                defaults=[ns[n] for n in default_names],
                                module=ns['__module__'])
         nm_tpl.__bases__ = bases
@@ -3085,7 +3096,11 @@ def NamedTuple(typename, fields=_sentinel, /, **kwargs):
         import warnings
         warnings._deprecated(deprecated_thing, message=deprecation_msg, remove=(3, 15))
         fields = kwargs.items()
-    nt = _make_nmtuple(typename, fields, module=_caller())
+    types = {n: _type_check(t, f"field {n} annotation must be a type")
+             for n, t in fields}
+    field_names = [n for n, _ in fields]
+
+    nt = _make_nmtuple(typename, field_names, _make_eager_annotate(types), module=_caller())
     nt.__orig_bases__ = (NamedTuple,)
     return nt
 
@@ -3144,15 +3159,19 @@ class _TypedDictMeta(type):
         if not hasattr(tp_dict, '__orig_bases__'):
             tp_dict.__orig_bases__ = bases
 
-        annotations = {}
         if "__annotations__" in ns:
+            own_annotate = None
             own_annotations = ns["__annotations__"]
         elif "__annotate__" in ns:
-            own_annotations = ns["__annotate__"](1)  # VALUE
+            own_annotate = ns["__annotate__"]
+            own_annotations = annotationlib.call_annotate_function(
+                own_annotate, annotationlib.Format.FORWARDREF, owner=tp_dict
+            )
         else:
+            own_annotate = None
             own_annotations = {}
         msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
-        own_annotations = {
+        own_checked_annotations = {
             n: _type_check(tp, msg, module=tp_dict.__module__)
             for n, tp in own_annotations.items()
         }
@@ -3162,13 +3181,6 @@ class _TypedDictMeta(type):
         mutable_keys = set()
 
         for base in bases:
-            # TODO: Avoid eagerly evaluating annotations in VALUE format.
-            # Instead, evaluate in FORWARDREF format to figure out which
-            # keys have Required/NotRequired/ReadOnly qualifiers, and create
-            # a new __annotate__ function for the resulting TypedDict that
-            # combines the annotations from this class and its parents.
-            annotations.update(base.__annotations__)
-
             base_required = base.__dict__.get('__required_keys__', set())
             required_keys |= base_required
             optional_keys -= base_required
@@ -3180,8 +3192,7 @@ class _TypedDictMeta(type):
             readonly_keys.update(base.__dict__.get('__readonly_keys__', ()))
             mutable_keys.update(base.__dict__.get('__mutable_keys__', ()))
 
-        annotations.update(own_annotations)
-        for annotation_key, annotation_type in own_annotations.items():
+        for annotation_key, annotation_type in own_checked_annotations.items():
             qualifiers = set(_get_typeddict_qualifiers(annotation_type))
             if Required in qualifiers:
                 is_required = True
@@ -3212,7 +3223,32 @@ class _TypedDictMeta(type):
             f"Required keys overlap with optional keys in {name}:"
             f" {required_keys=}, {optional_keys=}"
         )
-        tp_dict.__annotations__ = annotations
+
+        def __annotate__(format):
+            annos = {}
+            for base in bases:
+                if base is Generic:
+                    continue
+                base_annotate = base.__annotate__
+                if base_annotate is None:
+                    continue
+                base_annos = annotationlib.call_annotate_function(base.__annotate__, format, owner=base)
+                annos.update(base_annos)
+            if own_annotate is not None:
+                own = annotationlib.call_annotate_function(own_annotate, format, owner=tp_dict)
+                if format != annotationlib.Format.SOURCE:
+                    own = {
+                        n: _type_check(tp, msg, module=tp_dict.__module__)
+                        for n, tp in own.items()
+                    }
+            elif format == annotationlib.Format.SOURCE:
+                own = _convert_to_source(own_annotations)
+            else:
+                own = own_checked_annotations
+            annos.update(own)
+            return annos
+
+        tp_dict.__annotate__ = __annotate__
         tp_dict.__required_keys__ = frozenset(required_keys)
         tp_dict.__optional_keys__ = frozenset(optional_keys)
         tp_dict.__readonly_keys__ = frozenset(readonly_keys)
diff --git a/Misc/NEWS.d/next/Library/2024-06-11-07-17-25.gh-issue-119180.iH-2zy.rst b/Misc/NEWS.d/next/Library/2024-06-11-07-17-25.gh-issue-119180.iH-2zy.rst
new file mode 100644 (file)
index 0000000..f24d7bd
--- /dev/null
@@ -0,0 +1,4 @@
+As part of implementing :pep:`649` and :pep:`749`, add a new module
+``annotationlib``. Add support for unresolved forward references in
+annotations to :mod:`dataclasses`, :class:`typing.TypedDict`, and
+:class:`typing.NamedTuple`.
index 9686d10563aa4d9eef878d11a76365e777cb4e82..4d595d98445a058bb134406ed838312019a8a60d 100644 (file)
@@ -99,6 +99,7 @@ static const char* _Py_stdlib_module_names[] = {
 "_winapi",
 "_zoneinfo",
 "abc",
+"annotationlib",
 "antigravity",
 "argparse",
 "array",