]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-119180: Use equality when comparing against `annotationlib.Format` (#131755)
authorVictorien <65306057+Viicos@users.noreply.github.com>
Fri, 28 Mar 2025 04:56:09 +0000 (05:56 +0100)
committerGitHub <noreply@github.com>
Fri, 28 Mar 2025 04:56:09 +0000 (21:56 -0700)
Lib/test/test_annotationlib.py
Lib/test/test_typing.py
Lib/typing.py

index 20f74b4ed0aadb92621c3c2b623752d07d6c9407..495606b48ed2e822f92d254819dad5f6fdf6bd42 100644 (file)
@@ -517,7 +517,7 @@ class TestGetAnnotations(unittest.TestCase):
 
         foo.__annotations__ = {"a": "foo", "b": "str"}
         for format in Format:
-            if format is Format.VALUE_WITH_FAKE_GLOBALS:
+            if format == Format.VALUE_WITH_FAKE_GLOBALS:
                 continue
             with self.subTest(format=format):
                 self.assertEqual(
@@ -816,7 +816,7 @@ class TestGetAnnotations(unittest.TestCase):
 
         wa = WeirdAnnotations()
         for format in Format:
-            if format is Format.VALUE_WITH_FAKE_GLOBALS:
+            if format == Format.VALUE_WITH_FAKE_GLOBALS:
                 continue
             with (
                 self.subTest(format=format),
index 402353404cb0fbf67695b5b45aa9311ee6b182d6..2c0297313cb4ab01c6054371bb45657f5c6e8dc6 100644 (file)
@@ -7158,6 +7158,8 @@ class GetTypeHintTests(BaseTestCase):
 
         self.assertEqual(get_type_hints(C, format=annotationlib.Format.STRING),
                          {'x': 'undefined'})
+        # Make sure using an int as format also works:
+        self.assertEqual(get_type_hints(C, format=4), {'x': 'undefined'})
 
     def test_get_type_hints_format_function(self):
         def func(x: undefined) -> undefined: ...
index 96211553a21e39008cdaee2440065c9676ba0127..e36da7e9f48b71bb6b3e29ed0b60eee760f1fb26 100644 (file)
@@ -2315,7 +2315,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False,
         hints = {}
         for base in reversed(obj.__mro__):
             ann = annotationlib.get_annotations(base, format=format)
-            if format is annotationlib.Format.STRING:
+            if format == annotationlib.Format.STRING:
                 hints.update(ann)
                 continue
             if globalns is None:
@@ -2339,7 +2339,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False,
                 value = _eval_type(value, base_globals, base_locals, base.__type_params__,
                                    format=format, owner=obj)
                 hints[name] = value
-        if include_extras or format is annotationlib.Format.STRING:
+        if include_extras or format == annotationlib.Format.STRING:
             return hints
         else:
             return {k: _strip_annotations(t) for k, t in hints.items()}
@@ -2353,7 +2353,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False,
         and not hasattr(obj, '__annotate__')
     ):
         raise TypeError(f"{obj!r} is not a module, class, or callable.")
-    if format is annotationlib.Format.STRING:
+    if format == annotationlib.Format.STRING:
         return hints
 
     if globalns is None: