]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
manager: controller: new SubprocessType for loading policy added
authorAleš Mrázek <ales.mrazek@nic.cz>
Mon, 11 Mar 2024 13:24:22 +0000 (14:24 +0100)
committerAleš Mrázek <ales.mrazek@nic.cz>
Tue, 2 Jul 2024 12:07:48 +0000 (14:07 +0200)
manager/knot_resolver_manager/kresd_controller/interface.py

index 1dc99505fb3344dd6248a2e7c43e64469334ecb1..bc7708721f208d71edd1810fd6c1ca4633f32ac5 100644 (file)
@@ -6,6 +6,7 @@ import struct
 import sys
 from abc import ABC, abstractmethod  # pylint: disable=no-name-in-module
 from enum import Enum, auto
+from pathlib import Path
 from typing import Dict, Iterable, Optional, Type, TypeVar
 from weakref import WeakValueDictionary
 
@@ -20,6 +21,7 @@ logger = logging.getLogger(__name__)
 
 class SubprocessType(Enum):
     KRESD = auto()
+    POLICY_LOADER = auto()
     GC = auto()
 
 
@@ -105,24 +107,43 @@ class Subprocess(ABC):
         self._registered_worker: bool = False
 
     async def start(self) -> None:
-        # create config file
-        lua_config = self._config.render_lua()
-        await writefile(kresd_config_file(self._config, self.id), lua_config)
+
+        config_file: Optional[Path] = None
+        if self.type is SubprocessType.KRESD:
+            config_lua = self._config.render_lua()
+            config_file = kresd_config_file(self._config, self.id)
+            await writefile(config_file, config_lua)
+        elif self.type is SubprocessType.POLICY_LOADER:
+            config_lua = self._config.render_lua_policy()
+            config_file = Path("policy-loader.conf")
+            await writefile(config_file, config_lua)
+
         try:
             await self._start()
             if self.type is SubprocessType.KRESD:
                 register_worker(self)
                 self._registered_worker = True
         except SubprocessControllerException as e:
-            kresd_config_file(self._config, self.id).unlink()
+            if config_file:
+                config_file.unlink()
             raise e
 
     async def apply_new_config(self, new_config: KresConfig) -> None:
         self._config = new_config
+
         # update config file
         logger.debug(f"Writing config file for {self.id}")
-        lua_config = new_config.render_lua()
-        await writefile(kresd_config_file(new_config, self.id), lua_config)
+
+        config_file: Optional[Path] = None
+        if self.type is SubprocessType.KRESD:
+            config_lua = self._config.render_lua()
+            config_file = kresd_config_file(self._config, self.id)
+            await writefile(config_file, config_lua)
+        elif self.type is SubprocessType.POLICY_LOADER:
+            config_lua = self._config.render_lua_policy()
+            config_file = Path("policy-loader.conf")
+            await writefile(config_file, config_lua)
+
         # update runtime status
         logger.debug(f"Restarting {self.id}")
         await self._restart()
@@ -138,7 +159,13 @@ class Subprocess(ABC):
         Remove temporary files and all traces of this instance running. It is NOT SAFE to call this while
         the kresd is running, because it will break automatic restarts (at the very least).
         """
-        kresd_config_file(self._config, self.id).unlink()
+
+        if self.type is SubprocessType.KRESD:
+            config_file = kresd_config_file(self._config, self.id)
+            config_file.unlink()
+        elif self.type is SubprocessType.POLICY_LOADER:
+            config_file = Path("policy-loader.conf")
+            config_file.unlink()
 
     def __eq__(self, o: object) -> bool:
         return isinstance(o, type(self)) and o.type == self.type and o.id == self.id
@@ -167,8 +194,12 @@ class Subprocess(ABC):
         return self._id
 
     async def command(self, cmd: str) -> object:
+        if not self._registered_worker:
+            raise RuntimeError("the command cannot be sent to a process other than the kresd worker")
+
         reader: asyncio.StreamReader
         writer: Optional[asyncio.StreamWriter] = None
+
         try:
             reader, writer = await asyncio.open_unix_connection(f"./control/{int(self.id)}")