From: Vasek Sraier Date: Tue, 30 Mar 2021 15:44:05 +0000 (+0200) Subject: utils: overloaded functions with optional arguments X-Git-Tag: v6.0.0a1~189 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=dddb980065ef2ba62d627cb641648f802bbd00ab;p=thirdparty%2Fknot-resolver.git utils: overloaded functions with optional arguments --- diff --git a/manager/knot_resolver_manager/utils/overload.py b/manager/knot_resolver_manager/utils/overload.py index 2b2b0390e..6ff1f2cc4 100644 --- a/manager/knot_resolver_manager/utils/overload.py +++ b/manager/knot_resolver_manager/utils/overload.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, overload +from knot_resolver_manager.utils.types import NoneType, get_optional_inner_type, is_optional +from typing import Any, Callable, Dict, Generic, List, Tuple, Type, TypeVar T = TypeVar("T") @@ -6,15 +7,32 @@ class OverloadedFunctionException(Exception): pass class overloaded(Generic[T]): def __init__(self): - self.vtable: Dict[Tuple[Any], Callable[..., T]] = {} + self._vtable: Dict[Tuple[Any], Callable[..., T]] = {} + + @staticmethod + def _create_signatures(*types: Any) -> List[Any]: + result: List[List[Any]] = [[]] + for arg_type in types: + if is_optional(arg_type): + tp = get_optional_inner_type(arg_type) + result = [p + [NoneType] for p in result] + [p + [tp] for p in result] + else: + result = [p + [arg_type] for p in result] + + # make tuples + return [tuple(x) for x in result] def add(self, *args: Type[Any], **kwargs: Type[Any]) -> Callable[[Callable[..., T]], Callable[..., T]]: if len(kwargs) != 0: raise OverloadedFunctionException("Sorry, named arguments are not supported. You can however implement them and make them functional... ;)") def wrapper(func: Callable[...,T]) -> Callable[...,T]: - signature = tuple(args) - self.vtable[signature] = func + signatures = overloaded._create_signatures(*args) + for signature in signatures: + if signature in self._vtable: + raise OverloadedFunctionException("Sorry, signature {signature} is already defined. You can't make a second definition of the same signature.") + self._vtable[signature] = func + def inner_wrapper(*args: Any, **kwargs: Any) -> T: return self(*args, **kwargs) return inner_wrapper @@ -25,6 +43,11 @@ class overloaded(Generic[T]): raise OverloadedFunctionException("Sorry, named arguments are not supported. You can however implement them and make them functional... ;)") signature = tuple(type(a) for a in args) - if signature not in self.vtable: + if signature not in self._vtable: raise OverloadedFunctionException(f"Function overload with signature {signature} is not registered and can't be called.") - return self.vtable[signature](*args) \ No newline at end of file + return self._vtable[signature](*args) + + def _print_vtable(self): + for signature in self._vtable: + print(f"{signature} registered") + print() \ No newline at end of file diff --git a/manager/tests/utils/test_overloaded.py b/manager/tests/utils/test_overloaded.py index 283f79ea7..3a740d8ec 100644 --- a/manager/tests/utils/test_overloaded.py +++ b/manager/tests/utils/test_overloaded.py @@ -1,3 +1,4 @@ +from typing import Optional from knot_resolver_manager.utils.overload import overloaded @@ -17,4 +18,26 @@ def test_simple(): f1("test") f2(5) f1("test") - f2(5) \ No newline at end of file + f2(5) + + +def test_optional(): + func: overloaded[int] = overloaded() + + @func.add(Optional[int], str) + def f1(a: Optional[int], b: str) -> int: + assert a is None or type(a) == int + assert type(b) == str + return -1 + + @func.add(Optional[str], int) + def f2(a: Optional[str], b: int) -> int: + assert a is None or type(a) == str + assert type(b) == int + return 1 + + + func(None, 5) + func("str", 5) + func(None, "str") + func(5, "str") \ No newline at end of file