]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-104873: Add typing.get_protocol_members and typing.is_protocol (#104878)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Wed, 14 Jun 2023 12:35:06 +0000 (05:35 -0700)
committerGitHub <noreply@github.com>
Wed, 14 Jun 2023 12:35:06 +0000 (05:35 -0700)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Doc/library/typing.rst
Doc/whatsnew/3.13.rst
Lib/test/test_typing.py
Lib/typing.py
Misc/NEWS.d/next/Library/2023-05-24-09-55-33.gh-issue-104873.BKQ54y.rst [new file with mode: 0644]

index 949b108c60c4f6526e4f1898f7eef46a11c581a2..487be8f28a788d73a608395f93514a1d114a165d 100644 (file)
@@ -3388,6 +3388,38 @@ Introspection helpers
 
    .. versionadded:: 3.8
 
+.. function:: get_protocol_members(tp)
+
+   Return the set of members defined in a :class:`Protocol`.
+
+   ::
+
+      >>> from typing import Protocol, get_protocol_members
+      >>> class P(Protocol):
+      ...     def a(self) -> str: ...
+      ...     b: int
+      >>> get_protocol_members(P)
+      frozenset({'a', 'b'})
+
+   Raise :exc:`TypeError` for arguments that are not Protocols.
+
+   .. versionadded:: 3.13
+
+.. function:: is_protocol(tp)
+
+   Determine if a type is a :class:`Protocol`.
+
+   For example::
+
+      class P(Protocol):
+          def a(self) -> str: ...
+          b: int
+
+      is_protocol(P)    # => True
+      is_protocol(int)  # => False
+
+   .. versionadded:: 3.13
+
 .. function:: is_typeddict(tp)
 
    Check if a type is a :class:`TypedDict`.
index 78d2a7b6b294d4fc9d4b5b61f5dce87f9ecdec8a..fcd10e522c8aca73ba81037eb13b3e3aeeda5826 100644 (file)
@@ -120,6 +120,14 @@ traceback
   to format the nested exceptions of a :exc:`BaseExceptionGroup` instance, recursively.
   (Contributed by Irit Katriel in :gh:`105292`.)
 
+typing
+------
+
+* Add :func:`typing.get_protocol_members` to return the set of members
+  defining a :class:`typing.Protocol`. Add :func:`typing.is_protocol` to
+  check whether a class is a :class:`typing.Protocol`. (Contributed by Jelle Zijlstra in
+  :gh:`104873`.)
+
 Optimizations
 =============
 
index 432fc88b1c072efdf7b0885cd4aef10a52896a68..a36d801c52514bf1446f3b1174fd81caf38187b8 100644 (file)
@@ -24,9 +24,9 @@ from typing import Callable
 from typing import Generic, ClassVar, Final, final, Protocol
 from typing import assert_type, cast, runtime_checkable
 from typing import get_type_hints
-from typing import get_origin, get_args
+from typing import get_origin, get_args, get_protocol_members
 from typing import override
-from typing import is_typeddict
+from typing import is_typeddict, is_protocol
 from typing import reveal_type
 from typing import dataclass_transform
 from typing import no_type_check, no_type_check_decorator
@@ -3363,6 +3363,18 @@ class ProtocolTests(BaseTestCase):
         self.assertNotIn("__callable_proto_members_only__", vars(NonP))
         self.assertNotIn("__callable_proto_members_only__", vars(NonPR))
 
+        self.assertEqual(get_protocol_members(P), {"x"})
+        self.assertEqual(get_protocol_members(PR), {"meth"})
+
+        # the returned object should be immutable,
+        # and should be a different object to the original attribute
+        # to prevent users from (accidentally or deliberately)
+        # mutating the attribute on the original class
+        self.assertIsInstance(get_protocol_members(P), frozenset)
+        self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)
+        self.assertIsInstance(get_protocol_members(PR), frozenset)
+        self.assertIsNot(get_protocol_members(PR), P.__protocol_attrs__)
+
         acceptable_extra_attrs = {
             '_is_protocol', '_is_runtime_protocol', '__parameters__',
             '__init__', '__annotations__', '__subclasshook__',
@@ -3778,6 +3790,59 @@ class ProtocolTests(BaseTestCase):
 
         Foo()  # Previously triggered RecursionError
 
+    def test_get_protocol_members(self):
+        with self.assertRaisesRegex(TypeError, "not a Protocol"):
+            get_protocol_members(object)
+        with self.assertRaisesRegex(TypeError, "not a Protocol"):
+            get_protocol_members(object())
+        with self.assertRaisesRegex(TypeError, "not a Protocol"):
+            get_protocol_members(Protocol)
+        with self.assertRaisesRegex(TypeError, "not a Protocol"):
+            get_protocol_members(Generic)
+
+        class P(Protocol):
+            a: int
+            def b(self) -> str: ...
+            @property
+            def c(self) -> int: ...
+
+        self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'})
+        self.assertIsInstance(get_protocol_members(P), frozenset)
+        self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)
+
+        class Concrete:
+            a: int
+            def b(self) -> str: return "capybara"
+            @property
+            def c(self) -> int: return 5
+
+        with self.assertRaisesRegex(TypeError, "not a Protocol"):
+            get_protocol_members(Concrete)
+        with self.assertRaisesRegex(TypeError, "not a Protocol"):
+            get_protocol_members(Concrete())
+
+        class ConcreteInherit(P):
+            a: int = 42
+            def b(self) -> str: return "capybara"
+            @property
+            def c(self) -> int: return 5
+
+        with self.assertRaisesRegex(TypeError, "not a Protocol"):
+            get_protocol_members(ConcreteInherit)
+        with self.assertRaisesRegex(TypeError, "not a Protocol"):
+            get_protocol_members(ConcreteInherit())
+
+    def test_is_protocol(self):
+        self.assertTrue(is_protocol(Proto))
+        self.assertTrue(is_protocol(Point))
+        self.assertFalse(is_protocol(Concrete))
+        self.assertFalse(is_protocol(Concrete()))
+        self.assertFalse(is_protocol(Generic))
+        self.assertFalse(is_protocol(object))
+
+        # Protocol is not itself a protocol
+        self.assertFalse(is_protocol(Protocol))
+
     def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta(self):
         # Ensure the cache is empty, or this test won't work correctly
         collections.abc.Sized._abc_registry_clear()
index a531e7d7abbef60bce09cc5177416044b422db1f..4e6dc44773538edfaa9bdaf4437483d464960ebd 100644 (file)
@@ -131,7 +131,9 @@ __all__ = [
     'get_args',
     'get_origin',
     'get_overloads',
+    'get_protocol_members',
     'get_type_hints',
+    'is_protocol',
     'is_typeddict',
     'LiteralString',
     'Never',
@@ -3337,3 +3339,43 @@ def override[F: _Func](method: F, /) -> F:
         # read-only property, TypeError if it's a builtin class.
         pass
     return method
+
+
+def is_protocol(tp: type, /) -> bool:
+    """Return True if the given type is a Protocol.
+
+    Example::
+
+        >>> from typing import Protocol, is_protocol
+        >>> class P(Protocol):
+        ...     def a(self) -> str: ...
+        ...     b: int
+        >>> is_protocol(P)
+        True
+        >>> is_protocol(int)
+        False
+    """
+    return (
+        isinstance(tp, type)
+        and getattr(tp, '_is_protocol', False)
+        and tp != Protocol
+    )
+
+
+def get_protocol_members(tp: type, /) -> frozenset[str]:
+    """Return the set of members defined in a Protocol.
+
+    Example::
+
+        >>> from typing import Protocol, get_protocol_members
+        >>> class P(Protocol):
+        ...     def a(self) -> str: ...
+        ...     b: int
+        >>> get_protocol_members(P)
+        frozenset({'a', 'b'})
+
+    Raise a TypeError for arguments that are not Protocols.
+    """
+    if not is_protocol(tp):
+        raise TypeError(f'{tp!r} is not a Protocol')
+    return frozenset(tp.__protocol_attrs__)
diff --git a/Misc/NEWS.d/next/Library/2023-05-24-09-55-33.gh-issue-104873.BKQ54y.rst b/Misc/NEWS.d/next/Library/2023-05-24-09-55-33.gh-issue-104873.BKQ54y.rst
new file mode 100644 (file)
index 0000000..c901d83
--- /dev/null
@@ -0,0 +1,3 @@
+Add :func:`typing.get_protocol_members` to return the set of members
+defining a :class:`typing.Protocol`.  Add :func:`typing.is_protocol` to
+check whether a class is a :class:`typing.Protocol`. Patch by Jelle Zijlstra.