-from strictyaml import Map, Str, Int
-from strictyaml.parser import load
-from strictyaml.representation import YAML
-
from .datamodel import ConfData
-_CONFIG_SCHEMA = Map({"lua_config": Str(), "num_workers": Int()})
-
-
-def _get_config_schema():
- """
- Returns a schema defined using the strictyaml library, that the manager
- should accept at it's input.
-
- If this function does something, that can be cached, it should cache it by
- itself. For example, loading the schema from a file is OK, the loaded
- parsed schema object should then however be cached in memory. The function
- is on purpose non-async and it's expected to return very fast.
- """
- return _CONFIG_SCHEMA
-
-
class ConfigValidationException(Exception):
pass
-async def _validate_config(config):
- """
- Perform runtime value validation of the provided configuration object which
- is guaranteed to follow the configuration schema returned by the
- `get_config_schema` function.
-
- Throws a ConfigValidationException in case any errors are found. The error
- message should be in the error message of the exception.
- """
-
- if config["num_workers"] < 0:
- raise ConfigValidationException("Number of workers must be non-negative")
-
-
async def parse(yaml: str) -> ConfData:
conf = ConfData.from_yaml(yaml)
await conf.validate()
- return conf
\ No newline at end of file
+ return conf
from dataclasses import dataclass
+from typing import Optional
-from .utils import dataclass_strictyaml
+from .utils import StrictyamlParser
class ConfDataValidationException(Exception):
@dataclass
-@dataclass_strictyaml
-class ConfData:
+class ConfData(StrictyamlParser):
num_workers: int = 1
- lua_config: str = None
+ lua_config: Optional[str] = None
async def validate(self) -> bool:
if self.num_workers < 0:
raise ConfDataValidationException("Number of workers must be non-negative")
+
+ return True
import asyncio
from uuid import uuid4
from typing import List, Optional
-from strictyaml.representation import YAML
from . import compat
from . import systemd
async def _write_config(self, config: ConfData):
# FIXME: this code is blocking!!!
- with open("/etc/knot-resolver/kresd.conf", "w") as f:
- f.write(config.lua_config)
+ if config.lua_config is not None:
+ with open("/etc/knot-resolver/kresd.conf", "w") as f:
+ f.write(config.lua_config)
async def apply_config(self, config: ConfData):
async with self._children_lock:
-from .dataclasses_yaml import dataclass_strictyaml_schema, dataclass_strictyaml
+from .dataclasses_yaml import (
+ dataclass_strictyaml_schema,
+ dataclass_strictyaml,
+ StrictyamlParser,
+)
-__all__ = ["dataclass_strictyaml_schema", "dataclass_strictyaml"]
+__all__ = ["dataclass_strictyaml_schema", "dataclass_strictyaml", "StrictyamlParser"]
-from typing import List, Dict, Tuple, Union
+from typing import List, Dict, Tuple, Type, TypeVar, Union
from strictyaml import (
Map,
Str,
# List[T]
elif origin == List and len(args) == 1:
kwargs[name] = [
- _yamlobj_to_dataclass(args[0], val)
- for val in obj[name]
+ _yamlobj_to_dataclass(args[0], val) for val in obj[name]
]
# Tuple
return cls(**kwargs)
+def _from_yaml(cls, text: str):
+ schema = getattr(cls, _SCHEMA_FIELD_NAME)
+
+ yamlobj = load(text, schema)
+ return _yamlobj_to_dataclass(cls, yamlobj)
+
+
def dataclass_strictyaml(cls):
if not hasattr(cls, _SCHEMA_FIELD_NAME):
cls = dataclass_strictyaml_schema(cls)
- def from_yaml(text: str) -> cls:
- schema = getattr(cls, _SCHEMA_FIELD_NAME)
+ setattr(cls, "from_yaml", classmethod(_from_yaml))
+ return cls
+
- yamlobj = load(text, schema)
- return _yamlobj_to_dataclass(cls, yamlobj)
+_T = TypeVar("_T", bound="StrictyamlParser")
- setattr(cls, "from_yaml", from_yaml)
- return cls
+
+class StrictyamlParser:
+ @classmethod
+ def from_yaml(cls: Type[_T], text: str) -> _T:
+ if not hasattr(cls, _SCHEMA_FIELD_NAME):
+ dataclass_strictyaml_schema(cls)
+
+ return _from_yaml(cls, text)
--- /dev/null
+from knot_resolver_manager.utils.dataclasses_yaml import (
+ StrictyamlParser,
+ dataclass_strictyaml,
+)
+from knot_resolver_manager.utils import dataclass_strictyaml_schema
+from typing import List, Dict, Optional, Tuple
+from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq
+import strictyaml
+import pytest
+from dataclasses import dataclass
+
+
+def test_parsing_primitive():
+ @dataclass
+ @dataclass_strictyaml
+ class TestClass:
+ i: int
+ s: str
+ f: float
+
+ yaml = """i: 5
+s: "test"
+f: 3.14"""
+
+ obj = TestClass.from_yaml(yaml)
+
+ assert obj.i == 5
+ assert obj.s == "test"
+ assert obj.f == 3.14
+
+
+def test_parsing_nested():
+ @dataclass
+ @dataclass_strictyaml_schema
+ class Lower:
+ i: int
+
+ @dataclass
+ class Upper(StrictyamlParser):
+ l: Lower
+
+ yaml = """l:
+ i: 5"""
+
+ obj = Upper.from_yaml(yaml)
+ assert obj.l.i == 5
+
+
+def test_simple_compount_types():
+ @dataclass
+ class TestClass(StrictyamlParser):
+ l: List[int]
+ d: Dict[str, str]
+ t: Tuple[str, int]
+ o: Optional[int]
+
+ yaml = """l:
+ - 1
+ - 2
+ - 3
+ - 4
+ - 5
+d:
+ something: else
+ w: all
+t:
+ - test
+ - 5"""
+
+ obj = TestClass.from_yaml(yaml)
+
+ assert obj.l == [1, 2, 3, 4, 5]
+ assert obj.d == {"something": "else", "w": "all"}
+ assert obj.t == ("test", 5)
+ assert obj.o is None
+
+
+def test_nested_compount_types():
+ @dataclass
+ class TestClass(StrictyamlParser):
+ o: Optional[Dict[str, str]]
+
+ yaml = """o:
+ key: val"""
+
+ obj = TestClass.from_yaml(yaml)
+
+ assert obj.o == {"key": "val"}
+
+
+def test_nested_compount_types2():
+ @dataclass
+ class TestClass(StrictyamlParser):
+ i: int
+ o: Optional[Dict[str, str]]
+
+ yaml = "i: 5"
+
+ obj = TestClass.from_yaml(yaml)
+
+ assert obj.o is None