from typing import Generic, ClassVar, Final, final, Protocol
from typing import assert_type, cast, runtime_checkable
from typing import get_type_hints
-from typing import get_origin, get_args
+from typing import get_origin, get_args, get_protocol_members
from typing import override
-from typing import is_typeddict
+from typing import is_typeddict, is_protocol
from typing import reveal_type
from typing import dataclass_transform
from typing import no_type_check, no_type_check_decorator
self.assertNotIn("__callable_proto_members_only__", vars(NonP))
self.assertNotIn("__callable_proto_members_only__", vars(NonPR))
+ self.assertEqual(get_protocol_members(P), {"x"})
+ self.assertEqual(get_protocol_members(PR), {"meth"})
+
+ # the returned object should be immutable,
+ # and should be a different object to the original attribute
+ # to prevent users from (accidentally or deliberately)
+ # mutating the attribute on the original class
+ self.assertIsInstance(get_protocol_members(P), frozenset)
+ self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)
+ self.assertIsInstance(get_protocol_members(PR), frozenset)
+ self.assertIsNot(get_protocol_members(PR), P.__protocol_attrs__)
+
acceptable_extra_attrs = {
'_is_protocol', '_is_runtime_protocol', '__parameters__',
'__init__', '__annotations__', '__subclasshook__',
Foo() # Previously triggered RecursionError
+ def test_get_protocol_members(self):
+ with self.assertRaisesRegex(TypeError, "not a Protocol"):
+ get_protocol_members(object)
+ with self.assertRaisesRegex(TypeError, "not a Protocol"):
+ get_protocol_members(object())
+ with self.assertRaisesRegex(TypeError, "not a Protocol"):
+ get_protocol_members(Protocol)
+ with self.assertRaisesRegex(TypeError, "not a Protocol"):
+ get_protocol_members(Generic)
+
+ class P(Protocol):
+ a: int
+ def b(self) -> str: ...
+ @property
+ def c(self) -> int: ...
+
+ self.assertEqual(get_protocol_members(P), {'a', 'b', 'c'})
+ self.assertIsInstance(get_protocol_members(P), frozenset)
+ self.assertIsNot(get_protocol_members(P), P.__protocol_attrs__)
+
+ class Concrete:
+ a: int
+ def b(self) -> str: return "capybara"
+ @property
+ def c(self) -> int: return 5
+
+ with self.assertRaisesRegex(TypeError, "not a Protocol"):
+ get_protocol_members(Concrete)
+ with self.assertRaisesRegex(TypeError, "not a Protocol"):
+ get_protocol_members(Concrete())
+
+ class ConcreteInherit(P):
+ a: int = 42
+ def b(self) -> str: return "capybara"
+ @property
+ def c(self) -> int: return 5
+
+ with self.assertRaisesRegex(TypeError, "not a Protocol"):
+ get_protocol_members(ConcreteInherit)
+ with self.assertRaisesRegex(TypeError, "not a Protocol"):
+ get_protocol_members(ConcreteInherit())
+
+ def test_is_protocol(self):
+ self.assertTrue(is_protocol(Proto))
+ self.assertTrue(is_protocol(Point))
+ self.assertFalse(is_protocol(Concrete))
+ self.assertFalse(is_protocol(Concrete()))
+ self.assertFalse(is_protocol(Generic))
+ self.assertFalse(is_protocol(object))
+
+ # Protocol is not itself a protocol
+ self.assertFalse(is_protocol(Protocol))
+
def test_interaction_with_isinstance_checks_on_superclasses_with_ABCMeta(self):
# Ensure the cache is empty, or this test won't work correctly
collections.abc.Sized._abc_registry_clear()
'get_args',
'get_origin',
'get_overloads',
+ 'get_protocol_members',
'get_type_hints',
+ 'is_protocol',
'is_typeddict',
'LiteralString',
'Never',
# read-only property, TypeError if it's a builtin class.
pass
return method
+
+
+def is_protocol(tp: type, /) -> bool:
+ """Return True if the given type is a Protocol.
+
+ Example::
+
+ >>> from typing import Protocol, is_protocol
+ >>> class P(Protocol):
+ ... def a(self) -> str: ...
+ ... b: int
+ >>> is_protocol(P)
+ True
+ >>> is_protocol(int)
+ False
+ """
+ return (
+ isinstance(tp, type)
+ and getattr(tp, '_is_protocol', False)
+ and tp != Protocol
+ )
+
+
+def get_protocol_members(tp: type, /) -> frozenset[str]:
+ """Return the set of members defined in a Protocol.
+
+ Example::
+
+ >>> from typing import Protocol, get_protocol_members
+ >>> class P(Protocol):
+ ... def a(self) -> str: ...
+ ... b: int
+ >>> get_protocol_members(P)
+ frozenset({'a', 'b'})
+
+ Raise a TypeError for arguments that are not Protocols.
+ """
+ if not is_protocol(tp):
+ raise TypeError(f'{tp!r} is not a Protocol')
+ return frozenset(tp.__protocol_attrs__)