]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.14] gh-137530: generate an __annotate__ function for dataclasses __init__ (GH...
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Mon, 10 Nov 2025 15:14:32 +0000 (07:14 -0800)
committerGitHub <noreply@github.com>
Mon, 10 Nov 2025 15:14:32 +0000 (07:14 -0800)
(cherry picked from commit 12837c63635559873a5abddf511d38456d69617b)

Co-authored-by: David Ellis <ducksual@gmail.com>
Lib/dataclasses.py
Lib/test/test_dataclasses/__init__.py
Misc/NEWS.d/next/Library/2025-10-21-15-54-13.gh-issue-137530.ZyIVUH.rst [new file with mode: 0644]

index d29f1615f276d2c2b9b2da2437c87be2fb1252dd..fb7e1701cce0a46dec11009159fbc3a8b8ac7a59 100644 (file)
@@ -441,9 +441,11 @@ class _FuncBuilder:
         self.locals = {}
         self.overwrite_errors = {}
         self.unconditional_adds = {}
+        self.method_annotations = {}
 
     def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
-               overwrite_error=False, unconditional_add=False, decorator=None):
+               overwrite_error=False, unconditional_add=False, decorator=None,
+               annotation_fields=None):
         if locals is not None:
             self.locals.update(locals)
 
@@ -464,16 +466,14 @@ class _FuncBuilder:
 
         self.names.append(name)
 
-        if return_type is not MISSING:
-            self.locals[f'__dataclass_{name}_return_type__'] = return_type
-            return_annotation = f'->__dataclass_{name}_return_type__'
-        else:
-            return_annotation = ''
+        if annotation_fields is not None:
+            self.method_annotations[name] = (annotation_fields, return_type)
+
         args = ','.join(args)
         body = '\n'.join(body)
 
         # Compute the text of the entire function, add it to the text we're generating.
-        self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
+        self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')
 
     def add_fns_to_class(self, cls):
         # The source to all of the functions we're generating.
@@ -509,6 +509,15 @@ class _FuncBuilder:
         # Now that we've generated the functions, assign them into cls.
         for name, fn in zip(self.names, fns):
             fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
+
+            try:
+                annotation_fields, return_type = self.method_annotations[name]
+            except KeyError:
+                pass
+            else:
+                annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
+                fn.__annotate__ = annotate_fn
+
             if self.unconditional_adds.get(name, False):
                 setattr(cls, name, fn)
             else:
@@ -524,6 +533,44 @@ class _FuncBuilder:
                     raise TypeError(error_msg)
 
 
+def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
+    # Create an __annotate__ function for a dataclass
+    # Try to return annotations in the same format as they would be
+    # from a regular __init__ function
+
+    def __annotate__(format, /):
+        Format = annotationlib.Format
+        match format:
+            case Format.VALUE | Format.FORWARDREF | Format.STRING:
+                cls_annotations = {}
+                for base in reversed(__class__.__mro__):
+                    cls_annotations.update(
+                        annotationlib.get_annotations(base, format=format)
+                    )
+
+                new_annotations = {}
+                for k in annotation_fields:
+                    new_annotations[k] = cls_annotations[k]
+
+                if return_type is not MISSING:
+                    if format == Format.STRING:
+                        new_annotations["return"] = annotationlib.type_repr(return_type)
+                    else:
+                        new_annotations["return"] = return_type
+
+                return new_annotations
+
+            case _:
+                raise NotImplementedError(format)
+
+    # This is a flag for _add_slots to know it needs to regenerate this method
+    # In order to remove references to the original class when it is replaced
+    __annotate__.__generated_by_dataclasses__ = True
+    __annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"
+
+    return __annotate__
+
+
 def _field_assign(frozen, name, value, self_name):
     # If we're a frozen class, then assign to our fields in __init__
     # via object.__setattr__.  Otherwise, just use a simple
@@ -612,7 +659,7 @@ def _init_param(f):
     elif f.default_factory is not MISSING:
         # There's a factory function.  Set a marker.
         default = '=__dataclass_HAS_DEFAULT_FACTORY__'
-    return f'{f.name}:__dataclass_type_{f.name}__{default}'
+    return f'{f.name}{default}'
 
 
 def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
@@ -635,11 +682,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
                 raise TypeError(f'non-default argument {f.name!r} '
                                 f'follows default argument {seen_default.name!r}')
 
-    locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
-              **{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
-                 '__dataclass_builtins_object__': object,
-                 }
-              }
+    annotation_fields = [f.name for f in fields if f.init]
+
+    locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
+              '__dataclass_builtins_object__': object}
 
     body_lines = []
     for f in fields:
@@ -670,7 +716,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
                         [self_name] + _init_params,
                         body_lines,
                         locals=locals,
-                        return_type=None)
+                        return_type=None,
+                        annotation_fields=annotation_fields)
 
 
 def _frozen_get_del_attr(cls, fields, func_builder):
@@ -1336,6 +1383,25 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
                 or _update_func_cell_for__class__(member.fdel, cls, newcls)):
                 break
 
+    # Get new annotations to remove references to the original class
+    # in forward references
+    newcls_ann = annotationlib.get_annotations(
+        newcls, format=annotationlib.Format.FORWARDREF)
+
+    # Fix references in dataclass Fields
+    for f in getattr(newcls, _FIELDS).values():
+        try:
+            ann = newcls_ann[f.name]
+        except KeyError:
+            pass
+        else:
+            f.type = ann
+
+    # Fix the class reference in the __annotate__ method
+    init_annotate = newcls.__init__.__annotate__
+    if getattr(init_annotate, "__generated_by_dataclasses__", False):
+        _update_func_cell_for__class__(init_annotate, cls, newcls)
+
     return newcls
 
 
index 6bf5e5b3e5554be954ac4b34b5c78fda27531311..513dd78c4381b42762f8d89ff0e0cb26ae4159c9 100644 (file)
@@ -2471,6 +2471,135 @@ class TestInit(unittest.TestCase):
         self.assertEqual(D(5).a, 10)
 
 
+class TestInitAnnotate(unittest.TestCase):
+    # Tests for the generated __annotate__ function for __init__
+    # See: https://github.com/python/cpython/issues/137530
+
+    def test_annotate_function(self):
+        # No forward references
+        @dataclass
+        class A:
+            a: int
+
+        value_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.VALUE)
+        forwardref_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.FORWARDREF)
+        string_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.STRING)
+
+        self.assertEqual(value_annos, {'a': int, 'return': None})
+        self.assertEqual(forwardref_annos, {'a': int, 'return': None})
+        self.assertEqual(string_annos, {'a': 'int', 'return': 'None'})
+
+        self.assertTrue(getattr(A.__init__.__annotate__, "__generated_by_dataclasses__"))
+
+    def test_annotate_function_forwardref(self):
+        # With forward references
+        @dataclass
+        class B:
+            b: undefined
+
+        # VALUE annotations should raise while unresolvable
+        with self.assertRaises(NameError):
+            _ = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
+
+        forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
+        string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
+
+        self.assertEqual(forwardref_annos, {'b': support.EqualToForwardRef('undefined', owner=B, is_class=True), 'return': None})
+        self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
+
+        # Now VALUE and FORWARDREF should resolve, STRING should be unchanged
+        undefined = int
+
+        value_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
+        forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
+        string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
+
+        self.assertEqual(value_annos, {'b': int, 'return': None})
+        self.assertEqual(forwardref_annos, {'b': int, 'return': None})
+        self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
+
+    def test_annotate_function_init_false(self):
+        # Check `init=False` attributes don't get into the annotations of the __init__ function
+        @dataclass
+        class C:
+            c: str = field(init=False)
+
+        self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})
+
+    def test_annotate_function_contains_forwardref(self):
+        # Check string annotations on objects containing a ForwardRef
+        @dataclass
+        class D:
+            d: list[undefined]
+
+        with self.assertRaises(NameError):
+            annotationlib.get_annotations(D.__init__)
+
+        self.assertEqual(
+            annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
+            {"d": list[support.EqualToForwardRef("undefined", is_class=True, owner=D)], "return": None}
+        )
+
+        self.assertEqual(
+            annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
+            {"d": "list[undefined]", "return": "None"}
+        )
+
+        # Now test when it is defined
+        undefined = str
+
+        # VALUE should now resolve
+        self.assertEqual(
+            annotationlib.get_annotations(D.__init__),
+            {"d": list[str], "return": None}
+        )
+
+        self.assertEqual(
+            annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
+            {"d": list[str], "return": None}
+        )
+
+        self.assertEqual(
+            annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
+            {"d": "list[undefined]", "return": "None"}
+        )
+
+    def test_annotate_function_not_replaced(self):
+        # Check that __annotate__ is not replaced on non-generated __init__ functions
+        @dataclass(slots=True)
+        class E:
+            x: str
+            def __init__(self, x: int) -> None:
+                self.x = x
+
+        self.assertEqual(
+            annotationlib.get_annotations(E.__init__), {"x": int, "return": None}
+        )
+
+        self.assertFalse(hasattr(E.__init__.__annotate__, "__generated_by_dataclasses__"))
+
+    def test_init_false_forwardref(self):
+        # Test forward references in fields not required for __init__ annotations.
+
+        # At the moment this raises a NameError for VALUE annotations even though the
+        # undefined annotation is not required for the __init__ annotations.
+        # Ideally this will be fixed but currently there is no good way to resolve this
+
+        @dataclass
+        class F:
+            not_in_init: list[undefined] = field(init=False, default=None)
+            in_init: int
+
+        annos = annotationlib.get_annotations(F.__init__, format=annotationlib.Format.FORWARDREF)
+        self.assertEqual(
+            annos,
+            {"in_init": int, "return": None},
+        )
+
+        with self.assertRaises(NameError):
+            annos = annotationlib.get_annotations(F.__init__)  # NameError on not_in_init
+
+
 class TestRepr(unittest.TestCase):
     def test_repr(self):
         @dataclass
@@ -3831,7 +3960,15 @@ class TestSlots(unittest.TestCase):
 
             return SlotsTest
 
-        for make in (make_simple, make_with_annotations, make_with_annotations_and_method):
+        def make_with_forwardref():
+            @dataclass(slots=True)
+            class SlotsTest:
+                x: undefined
+                y: list[undefined]
+
+            return SlotsTest
+
+        for make in (make_simple, make_with_annotations, make_with_annotations_and_method, make_with_forwardref):
             with self.subTest(make=make):
                 C = make()
                 support.gc_collect()
diff --git a/Misc/NEWS.d/next/Library/2025-10-21-15-54-13.gh-issue-137530.ZyIVUH.rst b/Misc/NEWS.d/next/Library/2025-10-21-15-54-13.gh-issue-137530.ZyIVUH.rst
new file mode 100644 (file)
index 0000000..4ff55b4
--- /dev/null
@@ -0,0 +1 @@
+:mod:`dataclasses` Fix annotations for generated ``__init__`` methods by replacing the annotations that were in-line in the generated source code with ``__annotate__`` functions attached to the methods.