# Remove entries which are already present in the __mro__ or unrelated.
def is_related(typ):
return (typ not in bases and hasattr(typ, '__mro__')
+ and not isinstance(typ, GenericAlias)
and issubclass(cls, typ))
types = [n for n in types if is_related(n)]
# Remove entries which are strict bases of other entries (they will end up
from typing import get_origin, Union
return get_origin(cls) in {Union, types.UnionType}
- def _is_valid_union_type(cls):
+ def _is_valid_dispatch_type(cls):
+ if isinstance(cls, type) and not isinstance(cls, GenericAlias):
+ return True
from typing import get_args
- return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
+ return (_is_union_type(cls) and
+ all(isinstance(arg, type) and not isinstance(arg, GenericAlias)
+ for arg in get_args(cls)))
def register(cls, func=None):
"""generic_func.register(cls, func) -> func
"""
nonlocal cache_token
- if func is None:
- if isinstance(cls, type) or _is_valid_union_type(cls):
+ if _is_valid_dispatch_type(cls):
+ if func is None:
return lambda f: register(cls, f)
+ else:
+ if func is not None:
+ raise TypeError(
+ f"Invalid first argument to `register()`. "
+ f"{cls!r} is not a class or union type."
+ )
ann = getattr(cls, '__annotations__', {})
if not ann:
raise TypeError(
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
- if not isinstance(cls, type) and not _is_valid_union_type(cls):
+ if not _is_valid_dispatch_type(cls):
if _is_union_type(cls):
raise TypeError(
f"Invalid annotation for {argname!r}. "
self.assertEqual(f(1), "types.UnionType")
self.assertEqual(f(1.0), "types.UnionType")
+ def test_register_genericalias(self):
+ @functools.singledispatch
+ def f(arg):
+ return "default"
+
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(list[int], lambda arg: "types.GenericAlias")
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(typing.List[int], lambda arg: "typing.GenericAlias")
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(typing.Any, lambda arg: "typing.Any")
+
+ self.assertEqual(f([1]), "default")
+ self.assertEqual(f([1.0]), "default")
+ self.assertEqual(f(""), "default")
+ self.assertEqual(f(b""), "default")
+
+ def test_register_genericalias_decorator(self):
+ @functools.singledispatch
+ def f(arg):
+ return "default"
+
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(list[int])
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(typing.List[int])
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(list[int] | str)
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(typing.List[int] | str)
+ with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+ f.register(typing.Any)
+
+ def test_register_genericalias_annotation(self):
+ @functools.singledispatch
+ def f(arg):
+ return "default"
+
+ with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+ @f.register
+ def _(arg: list[int]):
+ return "types.GenericAlias"
+ with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+ @f.register
+ def _(arg: typing.List[float]):
+ return "typing.GenericAlias"
+ with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+ @f.register
+ def _(arg: list[int] | str):
+ return "types.UnionType(types.GenericAlias)"
+ with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+ @f.register
+ def _(arg: typing.List[float] | bytes):
+ return "typing.Union[typing.GenericAlias]"
+ with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+ @f.register
+ def _(arg: typing.Any):
+ return "typing.Any"
+
+ self.assertEqual(f([1]), "default")
+ self.assertEqual(f([1.0]), "default")
+ self.assertEqual(f(""), "default")
+ self.assertEqual(f(b""), "default")
+
class CachedCostItem:
_cost = 1