dispatch_cache[cls] = impl
return impl
+ def _is_union_type(cls):
+ from typing import get_origin, Union
+ return get_origin(cls) in {Union, types.UnionType}
+
+ def _is_valid_union_type(cls):
+ from typing import get_args
+ return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
+
def register(cls, func=None):
"""generic_func.register(cls, func) -> func
"""
nonlocal cache_token
if func is None:
- if isinstance(cls, type):
+ if isinstance(cls, type) or _is_valid_union_type(cls):
return lambda f: register(cls, f)
ann = getattr(cls, '__annotations__', {})
if not ann:
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
- if not isinstance(cls, type):
- raise TypeError(
- f"Invalid annotation for {argname!r}. "
- f"{cls!r} is not a class."
- )
- registry[cls] = func
+ if not isinstance(cls, type) and not _is_valid_union_type(cls):
+ if _is_union_type(cls):
+ raise TypeError(
+ f"Invalid annotation for {argname!r}. "
+ f"{cls!r} not all arguments are classes."
+ )
+ else:
+ raise TypeError(
+ f"Invalid annotation for {argname!r}. "
+ f"{cls!r} is not a class."
+ )
+
+ if _is_union_type(cls):
+ from typing import get_args
+
+ for arg in get_args(cls):
+ registry[arg] = func
+ else:
+ registry[cls] = func
if cache_token is None and hasattr(cls, '__abstractmethods__'):
cache_token = get_cache_token()
dispatch_cache.clear()
'typing.Iterable[str] is not a class.'
))
+ with self.assertRaises(TypeError) as exc:
+ @i.register
+ def _(arg: typing.Union[int, typing.Iterable[str]]):
+ return "Invalid Union"
+ self.assertTrue(str(exc.exception).startswith(
+ "Invalid annotation for 'arg'."
+ ))
+ self.assertTrue(str(exc.exception).endswith(
+ 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
+ ))
+
def test_invalid_positional_argument(self):
@functools.singledispatch
def f(*args):
with self.assertRaisesRegex(TypeError, msg):
f()
+ def test_union(self):
+ @functools.singledispatch
+ def f(arg):
+ return "default"
+
+ @f.register
+ def _(arg: typing.Union[str, bytes]):
+ return "typing.Union"
+
+ @f.register
+ def _(arg: int | float):
+ return "types.UnionType"
+
+ self.assertEqual(f([]), "default")
+ self.assertEqual(f(""), "typing.Union")
+ self.assertEqual(f(b""), "typing.Union")
+ self.assertEqual(f(1), "types.UnionType")
+ self.assertEqual(f(1.0), "types.UnionType")
+
class CachedCostItem:
_cost = 1