]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
datamodel: added basic ip&port and path custom data types
authorVasek Sraier <git@vakabus.cz>
Sun, 5 Sep 2021 17:50:33 +0000 (19:50 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:52 +0000 (16:17 +0200)
manager/knot_resolver_manager/datamodel/types.py
manager/knot_resolver_manager/utils/data_parser_validator.py
manager/tests/datamodel/test_datamodel_types.py

index 954ccb5ece47214dc9489568158123d4a8fc9b46..2742e3f80375f1570a0680f1b7e3ba5733d1bcbf 100644 (file)
@@ -1,7 +1,13 @@
+import ipaddress
+import logging
 import re
-from typing import Any, Dict, Optional, Pattern, Union
+from pathlib import Path
+from typing import Any, Dict, Optional, Pattern, Union, cast
 
 from knot_resolver_manager.utils import CustomValueType, DataValidationException
+from knot_resolver_manager.utils.data_parser_validator import DataParser
+
+logger = logging.getLogger(__name__)
 
 
 class Unit(CustomValueType):
@@ -58,3 +64,80 @@ class SizeUnit(Unit):
 class TimeUnit(Unit):
     _re = re.compile(r"^(\d+)\s{0,1}([smhd]){0,1}$")
     _units = {None: 1, "s": 1, "m": 60, "h": 3600, "d": 24 * 3600}
+
+
+class AnyPath(CustomValueType):
+    def __init__(self, source_value: Any) -> None:
+        super().__init__(source_value)
+        if not isinstance(source_value, str):
+            raise DataValidationException(f"Expected file path in a string, got '{source_value}'")
+        self._value: Path = Path(source_value)
+
+        try:
+            self._value = self._value.resolve(strict=False)
+        except RuntimeError as e:
+            raise DataValidationException("Failed to resolve given file path. Is there a symlink loop?") from e
+
+    def __str__(self) -> str:
+        return str(self._value)
+
+    def __eq__(self, _o: object) -> bool:
+        raise RuntimeError("Path's cannot be simply compared for equality")
+
+    def __int__(self) -> int:
+        raise RuntimeError("Path cannot be converted to type <int>")
+
+    def to_path(self) -> Path:
+        return self._value
+
+
+class _IPAndPortData(DataParser):
+    ip: str
+    port: int
+
+
+class IPAndPort(CustomValueType):
+    """
+    IP and port. Supports two formats:
+      1. string in the form of 'ip@port'
+      2. object with string field 'ip' and numeric field 'port'
+    """
+
+    def __init__(self, source_value: Any) -> None:
+        super().__init__(source_value)
+
+        # parse values from object
+        if isinstance(source_value, dict):
+            obj = _IPAndPortData(cast(Dict[Any, Any], source_value))
+            ip = obj.ip
+            port = obj.port
+
+        # parse values from string
+        elif isinstance(source_value, str):
+            if "@" not in source_value:
+                raise DataValidationException("Expected ip and port in format 'ip@port'. Missing '@'")
+            ip, port_str = source_value.split(maxsplit=1, sep="@")
+            try:
+                port = int(port_str)
+            except ValueError:
+                raise DataValidationException(f"Failed to parse port number from string '{port_str}'")
+        else:
+            raise DataValidationException(
+                "Expected IP and port as an object or as a string 'ip@port'," f" got '{source_value}'"
+            )
+
+        # validate port value range
+        if not (0 <= port <= 65_535):
+            raise DataValidationException(f"Port value {port} out of range of usual 2-byte port value")
+
+        try:
+            self.ip: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = ipaddress.ip_address(ip)
+        except ValueError as e:
+            raise DataValidationException(f"Failed to parse IP address from string '{ip}'") from e
+        self.port: int = port
+
+    def __str__(self) -> str:
+        """
+        Returns value in 'ip@port' format
+        """
+        return f"{self.ip}@{self.port}"
index 72213eff81d7c4c1e512935828c58eaab970b0bd..20ee1b8328c2fed6424a3394faeec9ed49b124bf 100644 (file)
@@ -266,7 +266,7 @@ _SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$")
 
 
 class DataParser:
-    def __init__(self, obj: Optional[Dict[str, Any]] = None):
+    def __init__(self, obj: Optional[Dict[Any, Any]] = None):
         cls = self.__class__
         annot = cls.__dict__.get("__annotations__", {})
 
index e5321f4f5f967337b60735b88f6f276d86cc1834..0d1e689e38683ab3b6bac2e61f5072a42b0b753d 100644 (file)
@@ -1,6 +1,8 @@
+import ipaddress
+
 from pytest import raises
 
-from knot_resolver_manager.datamodel.types import SizeUnit, TimeUnit
+from knot_resolver_manager.datamodel.types import AnyPath, IPAndPort, SizeUnit, TimeUnit
 from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator
 
 
@@ -64,3 +66,30 @@ time: 10m
     b = TestClass.from_json(j)
     assert a.size == b.size == obj.size
     assert a.time == b.time == obj.time
+
+
+def test_ipandport():
+    class Data(DataParser):
+        o: IPAndPort
+        s: IPAndPort
+
+    val = """
+    o:
+      ip: "::"
+      port: 590
+    s: 127.0.0.1@5656
+    """
+
+    val = Data.from_yaml(val)
+
+    assert val.o.port == 590
+    assert val.o.ip == ipaddress.ip_address("::")
+    assert val.s.port == 5656
+    assert val.s.ip == ipaddress.ip_address("127.0.0.1")
+
+
+def test_anypath():
+    class Data(DataParser):
+        p: AnyPath
+
+    assert str(Data.from_yaml('p: "/tmp"').p) == "/tmp"