]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
manager: datamodel: base types refactored
authorAleš Mrázek <ales.mrazek@nic.cz>
Tue, 20 Jun 2023 09:20:15 +0000 (11:20 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Thu, 13 Jul 2023 07:50:09 +0000 (09:50 +0200)
manager/knot_resolver_manager/datamodel/cache_schema.py
manager/knot_resolver_manager/datamodel/types/base_types.py
manager/knot_resolver_manager/datamodel/types/types.py
manager/tests/unit/datamodel/types/test_custom_types.py

index 61341fc3bb41e5415f0a4abe0600b9ed55f11907..8adc41bea737ae44b4b16d1efb1fc2190c5141ae 100644 (file)
@@ -47,11 +47,11 @@ class GarbageCollectorSchema(ConfigSchema):
     interval: TimeUnit = TimeUnit("1s")
     threshold: Percent = Percent(80)
     release: Percent = Percent(10)
-    temp_keys_space: SizeUnit = SizeUnit(0)
+    temp_keys_space: SizeUnit = SizeUnit("0M")
     rw_deletes: IntNonNegative = IntNonNegative(100)
     rw_reads: IntNonNegative = IntNonNegative(200)
-    rw_duration: TimeUnit = TimeUnit(0)
-    rw_delay: TimeUnit = TimeUnit(0)
+    rw_duration: TimeUnit = TimeUnit("0us")
+    rw_delay: TimeUnit = TimeUnit("0us")
     dry_run: bool = False
 
 
index 96c0a3938ed568979bf04e1bbbd03f77d5dfbab9..b3184fc772336a4722f31306c0fb88599c61dfa6 100644 (file)
@@ -9,19 +9,35 @@ class IntBase(BaseValueType):
     Base class to work with integer value.
     """
 
+    _orig_value: int
     _value: int
 
+    def __init__(self, source_value: Any, object_path: str = "/") -> None:
+        super().__init__(source_value, object_path)
+        if isinstance(source_value, int) and not isinstance(source_value, bool):
+            self._orig_value = source_value
+            self._value = source_value
+        else:
+            raise ValueError(
+                f"Unexpected value for '{type(self)}'."
+                f" Expected integer, got '{source_value}' with type '{type(source_value)}'",
+                object_path,
+            )
+
     def __int__(self) -> int:
         return self._value
 
     def __str__(self) -> str:
         return str(self._value)
 
+    def __repr__(self) -> str:
+        return f'{type(self).__name__}("{self._value}")'
+
     def __eq__(self, o: object) -> bool:
         return isinstance(o, IntBase) and o._value == self._value
 
     def serialize(self) -> Any:
-        return self._value
+        return self._orig_value
 
     @classmethod
     def json_schema(cls: Type["IntBase"]) -> Dict[Any, Any]:
@@ -33,16 +49,29 @@ class StrBase(BaseValueType):
     Base class to work with string value.
     """
 
+    _orig_value: str
     _value: str
 
+    def __init__(self, source_value: Any, object_path: str = "/") -> None:
+        super().__init__(source_value, object_path)
+        if isinstance(source_value, (str, int)) and not isinstance(source_value, bool):
+            self._orig_value = str(source_value)
+            self._value = str(source_value)
+        else:
+            raise ValueError(
+                f"Unexpected value for '{type(self)}'."
+                f" Expected string, got '{source_value}' with type '{type(source_value)}'",
+                object_path,
+            )
+
     def __int__(self) -> int:
         raise ValueError("Can't convert string to an integer.")
 
     def __str__(self) -> str:
         return self._value
 
-    def to_std(self) -> str:
-        return self._value
+    def __repr__(self) -> str:
+        return f'{type(self).__name__}("{self._value}")'
 
     def __hash__(self) -> int:
         return hash(self._value)
@@ -51,7 +80,7 @@ class StrBase(BaseValueType):
         return isinstance(o, StrBase) and o._value == self._value
 
     def serialize(self) -> Any:
-        return self._value
+        return self._orig_value
 
     @classmethod
     def json_schema(cls: Type["StrBase"]) -> Dict[Any, Any]:
@@ -71,18 +100,11 @@ class IntRangeBase(IntBase):
     _max: int
 
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
-        super().__init__(source_value)
-        if isinstance(source_value, int) and not isinstance(source_value, bool):
-            if hasattr(self, "_min") and (source_value < self._min):
-                raise ValueError(f"value {source_value} is lower than the minimum {self._min}.")
-            if hasattr(self, "_max") and (source_value > self._max):
-                raise ValueError(f"value {source_value} is higher than the maximum {self._max}")
-            self._value = source_value
-        else:
-            raise ValueError(
-                f"expected integer, got '{type(source_value)}'",
-                object_path,
-            )
+        super().__init__(source_value, object_path)
+        if hasattr(self, "_min") and (self._value < self._min):
+            raise ValueError(f"value {self._value} is lower than the minimum {self._min}.", object_path)
+        if hasattr(self, "_max") and (self._value > self._max):
+            raise ValueError(f"value {self._value} is higher than the maximum {self._max}", object_path)
 
     @classmethod
     def json_schema(cls: Type["IntRangeBase"]) -> Dict[Any, Any]:
@@ -106,24 +128,16 @@ class PatternBase(StrBase):
     _re: Pattern[str]
 
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
-        super().__init__(source_value)
-        if isinstance(source_value, str):
-            if type(self)._re.match(source_value):
-                self._value: str = source_value
-            else:
-                raise ValueError(f"'{source_value}' does not match '{self._re.pattern}' pattern")
-        else:
-            raise ValueError(
-                f"expected string, got '{type(source_value)}'",
-                object_path,
-            )
+        super().__init__(source_value, object_path)
+        if not type(self)._re.match(self._value):
+            raise ValueError(f"'{self._value}' does not match '{self._re.pattern}' pattern", object_path)
 
     @classmethod
     def json_schema(cls: Type["PatternBase"]) -> Dict[Any, Any]:
         return {"type": "string", "pattern": rf"{cls._re.pattern}"}
 
 
-class UnitBase(IntBase):
+class UnitBase(StrBase):
     """
     Base class to work with string value that match regex pattern.
     Just inherit the class and set '_units'.
@@ -134,49 +148,37 @@ class UnitBase(IntBase):
 
     _re: Pattern[str]
     _units: Dict[str, int]
-    _value_orig: str
+    _base_value: int
 
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
-        super().__init__(source_value)
+        super().__init__(source_value, object_path)
+
         type(self)._re = re.compile(rf"^(\d+)({r'|'.join(type(self)._units.keys())})$")
-        if isinstance(source_value, str) and self._re.match(source_value):
-            self._value_orig = source_value
-            grouped = self._re.search(source_value)
-            if grouped:
-                val, unit = grouped.groups()
-                if unit is None:
-                    raise ValueError(f"Missing units. Accepted units are {list(type(self)._units.keys())}")
-                elif unit not in type(self)._units:
-                    raise ValueError(
-                        f"Used unexpected unit '{unit}' for {type(self).__name__}."
-                        f" Accepted units are {list(type(self)._units.keys())}",
-                        object_path,
-                    )
-                self._value = int(val) * type(self)._units[unit]
-            else:
-                raise ValueError(f"{type(self._value)} Failed to convert: {self}")
-        elif source_value in (0, "0"):
-            self._value_orig = source_value
-            self._value = int(source_value)
-        elif isinstance(source_value, int):
-            raise ValueError(
-                f"number without units, please convert to string and add unit  - {list(type(self)._units.keys())}",
-                object_path,
-            )
+        grouped = self._re.search(self._value)
+        if grouped:
+            val, unit = grouped.groups()
+            if unit is None:
+                raise ValueError(f"Missing units. Accepted units are {list(type(self)._units.keys())}", object_path)
+            elif unit not in type(self)._units:
+                raise ValueError(
+                    f"Used unexpected unit '{unit}' for {type(self).__name__}."
+                    f" Accepted units are {list(type(self)._units.keys())}",
+                    object_path,
+                )
+            self._base_value = int(val) * type(self)._units[unit]
         else:
             raise ValueError(
-                f"expected number with units in a string, got '{type(source_value)}'.",
+                f"Unexpected value for '{type(self)}'."
+                " Expected string that matches pattern " + rf"'{type(self)._re.pattern}'."
+                f" Positive integer and one of the units {list(type(self)._units.keys())}, got '{source_value}'.",
                 object_path,
             )
 
-    def __str__(self) -> str:
-        """
-        Used by Jinja2. Must return only a number.
-        """
-        return str(self._value_orig)
+    def __int__(self) -> int:
+        return self._base_value
 
     def __repr__(self) -> str:
-        return f"Unit[{type(self).__name__},{self._value_orig}]"
+        return f"Unit[{type(self).__name__},{self._value}]"
 
     def __eq__(self, o: object) -> bool:
         """
@@ -186,7 +188,7 @@ class UnitBase(IntBase):
         return isinstance(o, UnitBase) and o._value == self._value
 
     def serialize(self) -> Any:
-        return self._value_orig
+        return self._orig_value
 
     @classmethod
     def json_schema(cls: Type["UnitBase"]) -> Dict[Any, Any]:
index 14be0122da3b9da0f86d0157932e4cf3875f6936..2b409e7e563ec2f5cc1fff7ff8f935bf851ea865 100644 (file)
@@ -45,23 +45,23 @@ class SizeUnit(UnitBase):
     _units = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3}
 
     def bytes(self) -> int:
-        return self._value
+        return self._base_value
 
     def mbytes(self) -> int:
-        return self._value // 1024**2
+        return self._base_value // 1024**2
 
 
 class TimeUnit(UnitBase):
     _units = {"us": 1, "ms": 10**3, "s": 10**6, "m": 60 * 10**6, "h": 3600 * 10**6, "d": 24 * 3600 * 10**6}
 
     def seconds(self) -> int:
-        return self._value // 1000**2
+        return self._base_value // 1000**2
 
     def millis(self) -> int:
-        return self._value // 1000
+        return self._base_value // 1000
 
     def micros(self) -> int:
-        return self._value
+        return self._base_value
 
 
 class DomainName(StrBase):
@@ -81,28 +81,20 @@ class DomainName(StrBase):
     )
 
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
-        super().__init__(source_value)
-        if isinstance(source_value, str):
-            try:
-                punycode = source_value.encode("idna").decode("utf-8") if source_value != "." else "."
-            except ValueError:
-                raise ValueError(
-                    f"conversion of '{source_value}' to IDN punycode representation failed",
-                    object_path,
-                )
-
-            if type(self)._re.match(punycode):
-                self._value = source_value
-                self._punycode = punycode
-            else:
-                raise ValueError(
-                    f"'{source_value}' represented in punycode '{punycode}' does not match '{self._re.pattern}' pattern",
-                    object_path,
-                )
+        super().__init__(source_value, object_path)
+        try:
+            punycode = self._value.encode("idna").decode("utf-8") if self._value != "." else "."
+        except ValueError:
+            raise ValueError(
+                f"conversion of '{self._value}' to IDN punycode representation failed",
+                object_path,
+            )
+
+        if type(self)._re.match(punycode):
+            self._punycode = punycode
         else:
             raise ValueError(
-                "Unexpected value for '<domain-name>'."
-                f" Expected string, got '{source_value}' with type '{type(source_value)}'",
+                f"'{source_value}' represented in punycode '{punycode}' does not match '{self._re.pattern}' pattern",
                 object_path,
             )
 
@@ -120,6 +112,10 @@ class DomainName(StrBase):
 
 
 class InterfaceName(PatternBase):
+    """
+    Network interface name.
+    """
+
     _re = re.compile(r"^[a-zA-Z0-9]+(?:[-_][a-zA-Z0-9]+)*$")
 
 
@@ -145,27 +141,22 @@ class InterfacePort(StrBase):
     port: PortNumber
 
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
-        super().__init__(source_value)
-        if isinstance(source_value, str):
-            parts = source_value.split("@")
-            if len(parts) == 2:
+        super().__init__(source_value, object_path)
+
+        parts = self._value.split("@")
+        if len(parts) == 2:
+            try:
+                self.addr = ipaddress.ip_address(parts[0])
+            except ValueError as e1:
                 try:
-                    self.addr = ipaddress.ip_address(parts[0])
-                except ValueError as e1:
-                    try:
-                        self.if_name = InterfaceName(parts[0])
-                    except ValueError as e2:
-                        raise ValueError(f"expected IP address or interface name, got '{parts[0]}'.") from e1 and e2
-                self.port = PortNumber.from_str(parts[1], object_path)
-            else:
-                raise ValueError(f"expected '<ip-address|interface-name>@<port>', got '{source_value}'.")
-            self._value = source_value
+                    self.if_name = InterfaceName(parts[0])
+                except ValueError as e2:
+                    raise ValueError(
+                        f"expected IP address or interface name, got '{parts[0]}'.", object_path
+                    ) from e1 and e2
+            self.port = PortNumber.from_str(parts[1], object_path)
         else:
-            raise ValueError(
-                "Unexpected value for '<ip-address|interface-name>@<port>'."
-                f" Expected string, got '{source_value}' with type '{type(source_value)}'",
-                object_path,
-            )
+            raise ValueError(f"expected '<ip-address|interface-name>@<port>', got '{source_value}'.", object_path)
 
 
 class InterfaceOptionalPort(StrBase):
@@ -174,28 +165,23 @@ class InterfaceOptionalPort(StrBase):
     port: Optional[PortNumber] = None
 
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
-        super().__init__(source_value)
-        if isinstance(source_value, str):
-            parts = source_value.split("@")
-            if 0 < len(parts) < 3:
+        super().__init__(source_value, object_path)
+
+        parts = self._value.split("@")
+        if 0 < len(parts) < 3:
+            try:
+                self.addr = ipaddress.ip_address(parts[0])
+            except ValueError as e1:
                 try:
-                    self.addr = ipaddress.ip_address(parts[0])
-                except ValueError as e1:
-                    try:
-                        self.if_name = InterfaceName(parts[0])
-                    except ValueError as e2:
-                        raise ValueError(f"expected IP address or interface name, got '{parts[0]}'.") from e1 and e2
-                if len(parts) == 2:
-                    self.port = PortNumber.from_str(parts[1], object_path)
-            else:
-                raise ValueError(f"expected '<ip-address|interface-name>[@<port>]', got '{parts}'.")
-            self._value = source_value
+                    self.if_name = InterfaceName(parts[0])
+                except ValueError as e2:
+                    raise ValueError(
+                        f"expected IP address or interface name, got '{parts[0]}'.", object_path
+                    ) from e1 and e2
+            if len(parts) == 2:
+                self.port = PortNumber.from_str(parts[1], object_path)
         else:
-            raise ValueError(
-                "Unexpected value for '<ip-address|interface-name>[@<port>]'."
-                f" Expected string, got '{source_value}' with type '{type(source_value)}'",
-                object_path,
-            )
+            raise ValueError(f"expected '<ip-address|interface-name>[@<port>]', got '{parts}'.", object_path)
 
 
 class IPAddressPort(StrBase):
@@ -203,23 +189,17 @@ class IPAddressPort(StrBase):
     port: PortNumber
 
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
-        super().__init__(source_value)
-        if isinstance(source_value, str):
-            parts = source_value.split("@")
-            if len(parts) == 2:
-                self.port = PortNumber.from_str(parts[1], object_path)
-                try:
-                    self.addr = ipaddress.ip_address(parts[0])
-                except ValueError as e:
-                    raise ValueError(f"failed to parse IP address '{parts[0]}'.") from e
-            else:
-                raise ValueError(f"expected '<ip-address>@<port>', got '{source_value}'.")
-            self._value = source_value
+        super().__init__(source_value, object_path)
+
+        parts = self._value.split("@")
+        if len(parts) == 2:
+            self.port = PortNumber.from_str(parts[1], object_path)
+            try:
+                self.addr = ipaddress.ip_address(parts[0])
+            except ValueError as e:
+                raise ValueError(f"failed to parse IP address '{parts[0]}'.", object_path) from e
         else:
-            raise ValueError(
-                "Unexpected value for '<ip-address>@<port>'."
-                f" Expected string, got '{source_value}' with type '{type(source_value)}'"
-            )
+            raise ValueError(f"expected '<ip-address>@<port>', got '{source_value}'.", object_path)
 
 
 class IPAddressOptionalPort(StrBase):
@@ -228,24 +208,16 @@ class IPAddressOptionalPort(StrBase):
 
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
         super().__init__(source_value)
-        if isinstance(source_value, str):
-            parts = source_value.split("@")
-            if 0 < len(parts) < 3:
-                try:
-                    self.addr = ipaddress.ip_address(parts[0])
-                except ValueError as e:
-                    raise ValueError(f"failed to parse IP address '{parts[0]}'.") from e
-                if len(parts) == 2:
-                    self.port = PortNumber.from_str(parts[1], object_path)
-            else:
-                raise ValueError(f"expected '<ip-address>[@<port>]', got '{parts}'.")
-            self._value = source_value
+        parts = source_value.split("@")
+        if 0 < len(parts) < 3:
+            try:
+                self.addr = ipaddress.ip_address(parts[0])
+            except ValueError as e:
+                raise ValueError(f"failed to parse IP address '{parts[0]}'.", object_path) from e
+            if len(parts) == 2:
+                self.port = PortNumber.from_str(parts[1], object_path)
         else:
-            raise ValueError(
-                "Unexpected value for a '<ip-address>[@<port>]'."
-                f" Expected string, got '{source_value}' with type '{type(source_value)}'",
-                object_path,
-            )
+            raise ValueError(f"expected '<ip-address>[@<port>]', got '{parts}'.", object_path)
 
 
 class IPv4Address(BaseValueType):
@@ -274,6 +246,9 @@ class IPv4Address(BaseValueType):
     def __int__(self) -> int:
         raise ValueError("Can't convert IPv4 address to an integer")
 
+    def __repr__(self) -> str:
+        return f'{type(self).__name__}("{self._value}")'
+
     def __eq__(self, o: object) -> bool:
         """
         Two instances of IPv4Address are equal when they represent same IPv4 address as string.
@@ -316,6 +291,9 @@ class IPv6Address(BaseValueType):
     def __int__(self) -> int:
         raise ValueError("Can't convert IPv6 address to an integer")
 
+    def __repr__(self) -> str:
+        return f'{type(self).__name__}("{self._value}")'
+
     def __eq__(self, o: object) -> bool:
         """
         Two instances of IPv6Address are equal when they represent same IPv6 address as string.
@@ -351,14 +329,17 @@ class IPNetwork(BaseValueType):
                 f" Expected string, got '{source_value}' with type '{type(source_value)}'"
             )
 
-    def to_std(self) -> Union[ipaddress.IPv4Network, ipaddress.IPv6Network]:
-        return self._value
+    def __int__(self) -> int:
+        raise ValueError("Can't convert network prefix to an integer")
 
     def __str__(self) -> str:
         return self._value.with_prefixlen
 
-    def __int__(self) -> int:
-        raise ValueError("Can't convert network prefix to an integer")
+    def __repr__(self) -> str:
+        return f'{type(self).__name__}("{self._value}")'
+
+    def to_std(self) -> Union[ipaddress.IPv4Network, ipaddress.IPv6Network]:
+        return self._value
 
     def serialize(self) -> Any:
         return self._value.with_prefixlen
@@ -395,6 +376,9 @@ class IPv6Network(BaseValueType):
     def __int__(self) -> int:
         raise ValueError("Can't convert network prefix to an integer")
 
+    def __repr__(self) -> str:
+        return f'{type(self).__name__}("{self._value}")'
+
     def __eq__(self, o: object) -> bool:
         return isinstance(o, IPv6Network) and o._value == self._value
 
index 31d9cd2fff1bf0b87e8b0f25dbc1ebc6049cf452..3e1f5c61d2e010bc7d31eb137219aec4eec9d758 100644 (file)
@@ -57,7 +57,7 @@ def test_size_unit_invalid(val: Any):
         SizeUnit(val)
 
 
-@pytest.mark.parametrize("val", ["1d", "24h", "1440m", "86400s", "86400000ms", "86400000000us"])
+@pytest.mark.parametrize("val", ["1d", "24h", "1440m", "86400s", "86400000ms"])
 def test_time_unit_valid(val: str):
     o = TimeUnit(val)
     assert int(o) == 86400000000
@@ -79,8 +79,8 @@ def test_parsing_units():
         time: TimeUnit
 
     o = TestSchema({"size": "3K", "time": "10m"})
-    assert o.size == SizeUnit("3072B")
-    assert o.time == TimeUnit("600s")
+    assert int(o.size) == int(SizeUnit("3072B"))
+    assert int(o.time) == int(TimeUnit("600s"))
     assert o.size.bytes() == 3072
     assert o.time.seconds() == 10 * 60