]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
utils: overloaded functions with optional arguments
authorVasek Sraier <git@vakabus.cz>
Tue, 30 Mar 2021 15:44:05 +0000 (17:44 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:52 +0000 (16:17 +0200)
manager/knot_resolver_manager/utils/overload.py
manager/tests/utils/test_overloaded.py

index 2b2b0390e620f43eb128dd8bc96d551a95eea2ad..6ff1f2cc4832fcea645a778947d44e6b1b5c33f9 100644 (file)
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, overload
+from knot_resolver_manager.utils.types import NoneType, get_optional_inner_type, is_optional
+from typing import Any, Callable, Dict, Generic, List, Tuple, Type, TypeVar
 
 T = TypeVar("T")
 
@@ -6,15 +7,32 @@ class OverloadedFunctionException(Exception): pass
 
 class overloaded(Generic[T]):
     def __init__(self):
-        self.vtable: Dict[Tuple[Any], Callable[..., T]] = {}
+        self._vtable: Dict[Tuple[Any], Callable[..., T]] = {}
+
+    @staticmethod
+    def _create_signatures(*types: Any) -> List[Any]:
+        result: List[List[Any]] = [[]]
+        for arg_type in types:
+            if is_optional(arg_type):
+                tp = get_optional_inner_type(arg_type)
+                result = [p + [NoneType] for p in result] + [p + [tp] for p in result]
+            else:
+                result = [p + [arg_type] for p in result]
+        
+        # make tuples
+        return [tuple(x) for x in result]
     
     def add(self, *args: Type[Any], **kwargs: Type[Any]) -> Callable[[Callable[..., T]], Callable[..., T]]:
         if len(kwargs) != 0:
             raise OverloadedFunctionException("Sorry, named arguments are not supported. You can however implement them and make them functional... ;)")
 
         def wrapper(func: Callable[...,T]) -> Callable[...,T]:
-            signature = tuple(args)
-            self.vtable[signature] = func
+            signatures = overloaded._create_signatures(*args)
+            for signature in signatures:
+                if signature in self._vtable:
+                    raise OverloadedFunctionException("Sorry, signature {signature} is already defined. You can't make a second definition of the same signature.")
+                self._vtable[signature] = func
+
             def inner_wrapper(*args: Any, **kwargs: Any) -> T:
                 return self(*args, **kwargs)
             return inner_wrapper
@@ -25,6 +43,11 @@ class overloaded(Generic[T]):
             raise OverloadedFunctionException("Sorry, named arguments are not supported. You can however implement them and make them functional... ;)")
 
         signature = tuple(type(a) for a in args)
-        if signature not in self.vtable:
+        if signature not in self._vtable:
             raise OverloadedFunctionException(f"Function overload with signature {signature} is not registered and can't be called.")
-        return self.vtable[signature](*args)
\ No newline at end of file
+        return self._vtable[signature](*args)
+    
+    def _print_vtable(self):
+        for signature in self._vtable:
+            print(f"{signature} registered")
+        print()
\ No newline at end of file
index 283f79ea7fe9f1cd7df0fb1d1c6bc494b6cf9c69..3a740d8ecbc3136546e615e4fc33f78b8ab105cf 100644 (file)
@@ -1,3 +1,4 @@
+from typing import Optional
 from knot_resolver_manager.utils.overload import overloaded
 
 
@@ -17,4 +18,26 @@ def test_simple():
     f1("test")
     f2(5)
     f1("test")
-    f2(5)
\ No newline at end of file
+    f2(5)
+
+
+def test_optional():
+    func: overloaded[int] = overloaded()
+
+    @func.add(Optional[int], str)
+    def f1(a: Optional[int], b: str) -> int:
+        assert a is None or type(a) == int
+        assert type(b) == str
+        return -1
+    
+    @func.add(Optional[str], int)
+    def f2(a: Optional[str], b: int) -> int:
+        assert a is None or type(a) == str
+        assert type(b) == int
+        return 1
+    
+
+    func(None, 5)
+    func("str", 5)
+    func(None, "str")
+    func(5, "str")
\ No newline at end of file