--- /dev/null
+[flake8]
+max-line-length = 120
\ No newline at end of file
from aiohttp import web
+from knot_resolver_manager.kresd_manager import KresdManager
+
+from . import confmodel
+from . import compat
async def hello(_request: web.Request) -> web.Response:
return web.Response(text="Hello, world")
+async def apply_config(request: web.Request) -> web.Response:
+ config = await confmodel.parse(await request.text())
+ manager: KresdManager = request.app["kresd_manager"]
+ await manager.apply_config(config)
+ return web.Response(text="OK")
+
+
def main():
app = web.Application()
- app.add_routes([web.get("/", hello)])
+ # initialize KresdManager
+ manager = KresdManager()
+ compat.asyncio_run(manager.load_system_state())
+ app["kresd_manager"] = manager
+
+ # configure routing
+ app.add_routes([web.get("/", hello), web.post("/config", apply_config)])
+
+ # run forever
web.run_app(app, path="./manager.sock")
--- /dev/null
+# pylint: disable=E1101
+
+from asyncio.futures import Future
+import sys
+import asyncio
+import functools
+from typing import Awaitable, Coroutine
+
+
+def asyncio_to_thread(func, *args, **kwargs) -> Awaitable:
+ # version 3.9 and higher, call directly
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 9:
+ return asyncio.to_thread(func, *args, **kwargs)
+
+ # earlier versions, run with default executor
+ else:
+ loop = asyncio.get_event_loop()
+ pfunc = functools.partial(func, *args, **kwargs)
+ return loop.run_in_executor(None, pfunc)
+
+
+def asyncio_create_task(coro: Coroutine, name=None) -> Future:
+ # version 3.8 and higher, call directly
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 8:
+ return asyncio.create_task(coro, name=name)
+
+ # version 3.7 and higher, call directly without the name argument
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 8:
+ return asyncio.create_task(coro)
+
+ # earlier versions, use older function
+ else:
+ return asyncio.ensure_future(coro)
+
+
+def asyncio_run(coro: Coroutine, debug=None) -> Awaitable:
+ # ideally copy-paste of this:
+ # https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py#L8
+
+ # version 3.7 and higher, call directly
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 7:
+ return asyncio.run(coro, debug=debug)
+
+ # earlier versions, run with default executor
+ else:
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(coro)
--- /dev/null
+from strictyaml import Map, Str, Int
+from strictyaml.parser import load
+from strictyaml.representation import YAML
+
+
+_CONFIG_SCHEMA = Map({"lua_config": Str(), "num_workers": Int()})
+
+
+def _get_config_schema():
+ """
+ Returns a schema defined using the strictyaml library, that the manager
+ should accept at it's input.
+
+ If this function does something, that can be cached, it should cache it by
+ itself. For example, loading the schema from a file is OK, the loaded
+ parsed schema object should then however be cached in memory. The function
+ is on purpose non-async and it's expected to return very fast.
+ """
+ return _CONFIG_SCHEMA
+
+
+class ConfigValidationException(Exception):
+ pass
+
+
+async def _validate_config(config):
+ """
+ Perform runtime value validation of the provided configuration object which
+ is guaranteed to follow the configuration schema returned by the
+ `get_config_schema` function.
+
+ Throws a ConfigValidationException in case any errors are found. The error
+ message should be in the error message of the exception.
+ """
+
+ if config["num_workers"] < 0:
+ raise ConfigValidationException("Number of workers must be non-negative")
+
+
+async def parse(textual_config: str) -> YAML:
+ schema = _get_config_schema()
+ conf = load(textual_config, schema)
+ await _validate_config(conf)
+ return conf
--- /dev/null
+import asyncio
+from uuid import uuid4
+from typing import List, Optional
+from strictyaml.representation import YAML
+
+
+class Kresd:
+ def __init__(self, kresd_id: Optional[str] = None):
+ self._lock = asyncio.Lock()
+ self._id: str = kresd_id or str(uuid4())
+
+ # if we got existing id, mark for restart
+ self._needs_restart: bool = id is not None
+
+ async def is_running(self) -> bool:
+ raise NotImplementedError()
+
+ async def start(self):
+ raise NotImplementedError()
+
+ async def stop(self):
+ raise NotImplementedError()
+
+ async def restart(self):
+ raise NotImplementedError()
+
+ def mark_for_restart(self):
+ self._needs_restart = True
+
+
+class KresdManager:
+ def __init__(self):
+ self._children: List[Kresd] = []
+ self._children_lock = asyncio.Lock()
+
+ async def load_system_state(self):
+ async with self._children_lock:
+ await self._collect_already_running_children()
+
+ async def _spawn_new_child(self):
+ kresd = Kresd()
+ await kresd.start()
+ self._children.append(kresd)
+
+ async def _stop_a_child(self):
+ if len(self._children) == 0:
+ raise IndexError("Can't stop a kresd when there are no running")
+
+ kresd = self._children.pop()
+ await kresd.stop()
+
+ async def _collect_already_running_children(self):
+ raise NotImplementedError()
+
+ async def _rolling_restart(self):
+ for kresd in self._children:
+ await kresd.restart()
+ await asyncio.sleep(1)
+
+ async def _ensure_number_of_children(self, n: int):
+ # kill children that are not needed
+ while len(self._children) > n:
+ await self._stop_a_child()
+
+ # spawn new children if needed
+ while len(self._children) < n:
+ await self._spawn_new_child()
+
+ async def _write_config(self, config: YAML):
+ raise NotImplementedError()
+
+ async def apply_config(self, config: YAML):
+ async with self._children_lock:
+ await self._write_config(config)
+ await self._ensure_number_of_children(config["num_workers"])
+ await self._rolling_restart()
--- /dev/null
+from typing import List, Union
+import dbus
+from typing_extensions import Literal
+
+
+def _create_manager_interface():
+ bus = dbus.SystemBus()
+ systemd = bus.get_object("org.freedesktop.systemd1", "/org/freedesktop/systemd1")
+
+ manager = dbus.Interface(systemd, "org.freedesktop.systemd1.Manager")
+
+ return manager
+
+
+def get_unit_file_state(
+ unit_name: str,
+) -> Union[Literal["disabled"], Literal["enabled"]]:
+ res = str(_create_manager_interface().GetUnitFileState(unit_name))
+ assert res == "disabled" or res == "enabled"
+ return res
+
+
+def list_units() -> List[str]:
+ return [str(u[0]) for u in _create_manager_interface().ListUnits()]
+
+
+def list_jobs():
+ return _create_manager_interface().ListJobs()
+
+
+def restart_unit(unit_name: str):
+ return _create_manager_interface().RestartUnit(unit_name, "fail")
optional = false
python-versions = ">=3.6, <3.7"
+[[package]]
+name = "dbus-python"
+version = "1.2.16"
+description = "Python bindings for libdbus"
+category = "main"
+optional = false
+python-versions = "*"
+
[[package]]
name = "distlib"
version = "0.3.1"
[metadata]
lock-version = "1.1"
python-versions = "^3.6.12"
-content-hash = "90a3b2334875dcde45ebbb46bff45b04e689ef28d37bdc560056e4fe365fde0c"
+content-hash = "103e16cdbcee85cc8aa19e806f3a595b7d82bc565f6de0fa7970745c6ef6c1ef"
[metadata.files]
aiohttp = [
{file = "dataclasses-0.8-py3-none-any.whl", hash = "sha256:0201d89fa866f68c8ebd9d08ee6ff50c0b255f8ec63a71c16fda7af82bb887bf"},
{file = "dataclasses-0.8.tar.gz", hash = "sha256:8479067f342acf957dc82ec415d355ab5edb7e7646b90dc6e2fd1d96ad084c97"},
]
+dbus-python = [
+ {file = "dbus-python-1.2.16.tar.gz", hash = "sha256:11238f1d86c995d8aed2e22f04a1e3779f0d70e587caffeab4857f3c662ed5a4"},
+]
distlib = [
{file = "distlib-0.3.1-py2.py3-none-any.whl", hash = "sha256:8c09de2c67b3e7deef7184574fc060ab8a793e7adbb183d942c389c8b13c52fb"},
{file = "distlib-0.3.1.zip", hash = "sha256:edf6116872c863e1aa9d5bb7cb5e05a022c519a4594dc703843343a9ddd9bff1"},
python = "^3.6.12"
aiohttp = "^3.6.12"
strictyaml = "^1.3.2"
+dbus-python = "^1.2.16"
[tool.poetry.dev-dependencies]
pytest = "^5.2"
"no-self-use",
"raise-missing-from",
"too-few-public-methods",
- "unused-import", # checked by flake8
+ "unused-import", # checked by flake8,
+ "bad-continuation", # conflicts with black
+ "consider-using-in", # pyright can't see through in expressions
]
[tool.pylint.SIMILARITIES]