]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
datamodel: types: added custom types for float values
authorAleš Mrázek <ales.mrazek@nic.cz>
Thu, 9 Jan 2025 09:55:52 +0000 (10:55 +0100)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Sun, 19 Jan 2025 18:40:58 +0000 (19:40 +0100)
FloatBase: base type to work with float values
FloatNonNegative: custom type for non-negative float numbers

python/knot_resolver/datamodel/types/__init__.py
python/knot_resolver/datamodel/types/base_types.py
python/knot_resolver/datamodel/types/types.py
tests/manager/datamodel/types/test_base_types.py

index d1334b5a272e0004e8c1d70ebf931d0e00ba53cb..7e5cab417a7aedd8d25e27ad0c1fda9f792977c8 100644 (file)
@@ -5,6 +5,7 @@ from .types import (
     DomainName,
     EscapedStr,
     EscapedStr32B,
+    FloatNonNegative,
     IDPattern,
     Int0_32,
     Int0_512,
@@ -37,6 +38,7 @@ __all__ = [
     "DomainName",
     "EscapedStr",
     "EscapedStr32B",
+    "FloatNonNegative",
     "IDPattern",
     "Int0_32",
     "Int0_512",
index 2dce91a98d456025645ad702163409a7416b4046..19a2b2d668e499a0b2ebcadc7075397cb80c0c72 100644 (file)
@@ -1,7 +1,7 @@
 # ruff: noqa: SLF001
 
 import re
-from typing import Any, Dict, Type
+from typing import Any, Dict, Type, Union
 
 from knot_resolver.utils.compat.typing import Pattern
 from knot_resolver.utils.modeling import BaseValueType
@@ -46,6 +46,48 @@ class IntBase(BaseValueType):
         return {"type": "integer"}
 
 
+class FloatBase(BaseValueType):
+    """
+    Base class to work with float value.
+    """
+
+    _orig_value: Union[float, int]
+    _value: float
+
+    def __init__(self, source_value: Any, object_path: str = "/") -> None:
+        if isinstance(source_value, (float, int)) and not isinstance(source_value, bool):
+            self._orig_value = source_value
+            self._value = float(source_value)
+        else:
+            raise ValueError(
+                f"Unexpected value for '{type(self)}'."
+                f" Expected float, got '{source_value}' with type '{type(source_value)}'",
+                object_path,
+            )
+
+    def __int__(self) -> int:
+        return int(self._value)
+
+    def __float__(self) -> float:
+        return self._value
+
+    def __str__(self) -> str:
+        return str(self._value)
+
+    def __repr__(self) -> str:
+        return f'{type(self).__name__}("{self._value}")'
+
+    def __eq__(self, o: object) -> bool:
+        return isinstance(o, FloatBase) and o._value == self._value
+
+    def serialize(self) -> Any:
+        return self._orig_value
+
+    @classmethod
+    def json_schema(cls: Type["FloatBase"]) -> Dict[Any, Any]:
+        return {"type": "number"}
+
+
 class StrBase(BaseValueType):
     """
     Base class to work with string value.
@@ -151,6 +193,35 @@ class IntRangeBase(IntBase):
         return typ
 
 
+class FloatRangeBase(FloatBase):
+    """
+    Base class to work with float value in range.
+    Just inherit the class and set the values for '_min' and '_max'.
+
+    class FloatNonNegative(IntRangeBase):
+        _min: float = 0.0
+    """
+
+    _min: float
+    _max: float
+
+    def __init__(self, source_value: Any, object_path: str = "/") -> None:
+        super().__init__(source_value, object_path)
+        if hasattr(self, "_min") and (self._value < self._min):
+            raise ValueError(f"value {self._value} is lower than the minimum {self._min}.", object_path)
+        if hasattr(self, "_max") and (self._value > self._max):
+            raise ValueError(f"value {self._value} is higher than the maximum {self._max}", object_path)
+
+    @classmethod
+    def json_schema(cls: Type["FloatRangeBase"]) -> Dict[Any, Any]:
+        typ: Dict[str, Any] = {"type": "number"}
+        if hasattr(cls, "_min"):
+            typ["minimum"] = cls._min
+        if hasattr(cls, "_max"):
+            typ["maximum"] = cls._max
+        return typ
+
+
 class PatternBase(StrBase):
     """
     Base class to work with string value that match regex pattern.
index 3c9b9fe1c8f6ee184ff26ed0e6628537f5bffb15..946e2b13e35d23a7f91586e82e7500b30f0befb4 100644 (file)
@@ -2,7 +2,14 @@ import ipaddress
 import re
 from typing import Any, Dict, Optional, Type, Union
 
-from knot_resolver.datamodel.types.base_types import IntRangeBase, PatternBase, StrBase, StringLengthBase, UnitBase
+from knot_resolver.datamodel.types.base_types import (
+    FloatRangeBase,
+    IntRangeBase,
+    PatternBase,
+    StrBase,
+    StringLengthBase,
+    UnitBase,
+)
 from knot_resolver.utils.modeling import BaseValueType
 
 
@@ -46,6 +53,10 @@ class PortNumber(IntRangeBase):
             raise ValueError(f"invalid port number {port}") from e
 
 
+class FloatNonNegative(FloatRangeBase):
+    _min: float = 0.0
+
+
 class SizeUnit(UnitBase):
     _units = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3}
 
index 210604ed995e47687e30068c8d527505fb47d6d9..4bb27a958f7669eaed014b8dee1b9c8079894e05 100644 (file)
@@ -6,7 +6,7 @@ import pytest
 from pytest import raises
 
 from knot_resolver import KresBaseException
-from knot_resolver.datamodel.types.base_types import IntRangeBase, StringLengthBase
+from knot_resolver.datamodel.types.base_types import FloatRangeBase, IntRangeBase, StringLengthBase
 
 
 @pytest.mark.parametrize("min,max", [(0, None), (None, 0), (1, 65535), (-65535, -1)])
@@ -38,6 +38,35 @@ def test_int_range_base(min: Optional[int], max: Optional[int]):
             Test(inval)
 
 
+@pytest.mark.parametrize("min,max", [(0.0, None), (None, 0.0), (1.0, 65535.0), (-65535.0, -1.0)])
+def test_float_range_base(min: Optional[float], max: Optional[float]):
+    class Test(FloatRangeBase):
+        if min:
+            _min = min
+        if max:
+            _max = max
+
+    if min:
+        assert float(Test(min)) == min
+    if max:
+        assert float(Test(max)) == max
+
+    rmin = min if min else sys.float_info.min - 1.0
+    rmax = max if max else sys.float_info.max
+
+    n = 100
+    vals: List[float] = [random.uniform(rmin, rmax) for _ in range(n)]
+    assert [str(Test(val)) == f"{val}" for val in vals]
+
+    invals: List[float] = []
+    invals.extend([random.uniform(rmax + 1.0, sys.float_info.max) for _ in range(n % 2)] if max else [])
+    invals.extend([random.uniform(sys.float_info.min - 1.0, rmin - 1.0) for _ in range(n % 2)] if max else [])
+
+    for inval in invals:
+        with raises(KresBaseException):
+            Test(inval)
+
+
 @pytest.mark.parametrize("min,max", [(10, None), (None, 10), (2, 32)])
 def test_str_bytes_length_base(min: Optional[int], max: Optional[int]):
     class Test(StringLengthBase):