From: Vasek Sraier Date: Tue, 30 Mar 2021 14:51:57 +0000 (+0200) Subject: utils: simple function overloading helper X-Git-Tag: v6.0.0a1~190 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=efdb04ddfa571c7991cc8b84c25312fb4f3f51aa;p=thirdparty%2Fknot-resolver.git utils: simple function overloading helper --- diff --git a/manager/knot_resolver_manager/utils/overload.py b/manager/knot_resolver_manager/utils/overload.py new file mode 100644 index 000000000..2b2b0390e --- /dev/null +++ b/manager/knot_resolver_manager/utils/overload.py @@ -0,0 +1,30 @@ +from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, overload + +T = TypeVar("T") + +class OverloadedFunctionException(Exception): pass + +class overloaded(Generic[T]): + def __init__(self): + self.vtable: Dict[Tuple[Any], Callable[..., T]] = {} + + def add(self, *args: Type[Any], **kwargs: Type[Any]) -> Callable[[Callable[..., T]], Callable[..., T]]: + if len(kwargs) != 0: + raise OverloadedFunctionException("Sorry, named arguments are not supported. You can however implement them and make them functional... ;)") + + def wrapper(func: Callable[...,T]) -> Callable[...,T]: + signature = tuple(args) + self.vtable[signature] = func + def inner_wrapper(*args: Any, **kwargs: Any) -> T: + return self(*args, **kwargs) + return inner_wrapper + return wrapper + + def __call__(self, *args: Any, **kwargs: Any) -> T: + if len(kwargs) != 0: + raise OverloadedFunctionException("Sorry, named arguments are not supported. You can however implement them and make them functional... ;)") + + signature = tuple(type(a) for a in args) + if signature not in self.vtable: + raise OverloadedFunctionException(f"Function overload with signature {signature} is not registered and can't be called.") + return self.vtable[signature](*args) \ No newline at end of file diff --git a/manager/tests/utils/test_overloaded.py b/manager/tests/utils/test_overloaded.py new file mode 100644 index 000000000..283f79ea7 --- /dev/null +++ b/manager/tests/utils/test_overloaded.py @@ -0,0 +1,20 @@ +from knot_resolver_manager.utils.overload import overloaded + + +def test_simple(): + func: overloaded[None] = overloaded() + + @func.add(int) + def f1(a: int) -> None: + assert type(a) == int + + @func.add(str) + def f2(a: str) -> None: + assert type(a) == str + + func("test") + func(5) + f1("test") + f2(5) + f1("test") + f2(5) \ No newline at end of file