]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-119180: Make FORWARDREF format look at __annotations__ first (#124479)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Wed, 25 Sep 2024 22:32:45 +0000 (15:32 -0700)
committerGitHub <noreply@github.com>
Wed, 25 Sep 2024 22:32:45 +0000 (15:32 -0700)
From discussion with Larry Hastings and Carl Meyer, this is the desired
behavior.

Lib/annotationlib.py
Lib/test/test_annotationlib.py

index be3bc275817f50ad66f439728c3d53d0c576e40d..20c9542efac2d8bef923fc061db5b4bf9249ba42 100644 (file)
@@ -664,28 +664,38 @@ def get_annotations(
     if eval_str and format != Format.VALUE:
         raise ValueError("eval_str=True is only supported with format=Format.VALUE")
 
-    # For VALUE format, we look at __annotations__ directly.
-    if format != Format.VALUE:
-        annotate = get_annotate_function(obj)
-        if annotate is not None:
-            ann = call_annotate_function(annotate, format, owner=obj)
-            if not isinstance(ann, dict):
-                raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
-            return dict(ann)
-
-    if isinstance(obj, type):
-        try:
-            ann = _BASE_GET_ANNOTATIONS(obj)
-        except AttributeError:
-            # For static types, the descriptor raises AttributeError.
-            return {}
-    else:
-        ann = getattr(obj, "__annotations__", None)
-        if ann is None:
-            return {}
-
-    if not isinstance(ann, dict):
-        raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
+    match format:
+        case Format.VALUE:
+            # For VALUE, we only look at __annotations__
+            ann = _get_dunder_annotations(obj)
+        case Format.FORWARDREF:
+            # For FORWARDREF, we use __annotations__ if it exists
+            try:
+                ann = _get_dunder_annotations(obj)
+            except NameError:
+                pass
+            else:
+                return dict(ann)
+
+            # But if __annotations__ threw a NameError, we try calling __annotate__
+            ann = _get_and_call_annotate(obj, format)
+            if ann is not None:
+                return ann
+
+            # If that didn't work either, we have a very weird object: evaluating
+            # __annotations__ threw NameError and there is no __annotate__. In that case,
+            # we fall back to trying __annotations__ again.
+            return dict(_get_dunder_annotations(obj))
+        case Format.SOURCE:
+            # For SOURCE, we try to call __annotate__
+            ann = _get_and_call_annotate(obj, format)
+            if ann is not None:
+                return ann
+            # But if we didn't get it, we use __annotations__ instead.
+            ann = _get_dunder_annotations(obj)
+            return ann
+        case _:
+            raise ValueError(f"Unsupported format {format!r}")
 
     if not ann:
         return {}
@@ -750,3 +760,30 @@ def get_annotations(
         for key, value in ann.items()
     }
     return return_value
+
+
+def _get_and_call_annotate(obj, format):
+    annotate = get_annotate_function(obj)
+    if annotate is not None:
+        ann = call_annotate_function(annotate, format, owner=obj)
+        if not isinstance(ann, dict):
+            raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
+        return dict(ann)
+    return None
+
+
+def _get_dunder_annotations(obj):
+    if isinstance(obj, type):
+        try:
+            ann = _BASE_GET_ANNOTATIONS(obj)
+        except AttributeError:
+            # For static types, the descriptor raises AttributeError.
+            return {}
+    else:
+        ann = getattr(obj, "__annotations__", None)
+        if ann is None:
+            return {}
+
+    if not isinstance(ann, dict):
+        raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
+    return dict(ann)
index cc051ef3b9365817f651bcf90c352749a89c6257..5b052dab5007d64e96c2d4ff6df27fd197db8c37 100644 (file)
@@ -740,17 +740,97 @@ class TestGetAnnotations(unittest.TestCase):
 
         self.assertEqual(annotationlib.get_annotations(f), {"x": int})
         self.assertEqual(
-            annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF),
+            annotationlib.get_annotations(f, format=Format.FORWARDREF),
             {"x": int},
         )
 
         f.__annotations__["x"] = str
         # The modification is reflected in VALUE (the default)
         self.assertEqual(annotationlib.get_annotations(f), {"x": str})
-        # ... but not in FORWARDREF, which uses __annotate__
+        # ... and also in FORWARDREF, which tries __annotations__ if available
         self.assertEqual(
-            annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF),
-            {"x": int},
+            annotationlib.get_annotations(f, format=Format.FORWARDREF),
+            {"x": str},
+        )
+        # ... but not in SOURCE which always uses __annotate__
+        self.assertEqual(
+            annotationlib.get_annotations(f, format=Format.SOURCE),
+            {"x": "int"},
+        )
+
+    def test_non_dict_annotations(self):
+        class WeirdAnnotations:
+            @property
+            def __annotations__(self):
+                return "not a dict"
+
+        wa = WeirdAnnotations()
+        for format in Format:
+            with (
+                self.subTest(format=format),
+                self.assertRaisesRegex(
+                    ValueError, r".*__annotations__ is neither a dict nor None"
+                ),
+            ):
+                annotationlib.get_annotations(wa, format=format)
+
+    def test_annotations_on_custom_object(self):
+        class HasAnnotations:
+            @property
+            def __annotations__(self):
+                return {"x": int}
+
+        ha = HasAnnotations()
+        self.assertEqual(
+            annotationlib.get_annotations(ha, format=Format.VALUE), {"x": int}
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(ha, format=Format.FORWARDREF), {"x": int}
+        )
+
+        # TODO(gh-124412): This should return {'x': 'int'} instead.
+        self.assertEqual(
+            annotationlib.get_annotations(ha, format=Format.SOURCE), {"x": int}
+        )
+
+    def test_raising_annotations_on_custom_object(self):
+        class HasRaisingAnnotations:
+            @property
+            def __annotations__(self):
+                return {"x": undefined}
+
+        hra = HasRaisingAnnotations()
+
+        with self.assertRaises(NameError):
+            annotationlib.get_annotations(hra, format=Format.VALUE)
+
+        with self.assertRaises(NameError):
+            annotationlib.get_annotations(hra, format=Format.FORWARDREF)
+
+        undefined = float
+        self.assertEqual(
+            annotationlib.get_annotations(hra, format=Format.VALUE), {"x": float}
+        )
+
+    def test_forwardref_prefers_annotations(self):
+        class HasBoth:
+            @property
+            def __annotations__(self):
+                return {"x": int}
+
+            @property
+            def __annotate__(self):
+                return lambda format: {"x": str}
+
+        hb = HasBoth()
+        self.assertEqual(
+            annotationlib.get_annotations(hb, format=Format.VALUE), {"x": int}
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(hb, format=Format.FORWARDREF), {"x": int}
+        )
+        self.assertEqual(
+            annotationlib.get_annotations(hb, format=Format.SOURCE), {"x": str}
         )
 
     def test_pep695_generic_class_with_future_annotations(self):