]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
manager/kafka_client.py: use headers instead of parsing message key docs-jezek-test-jq0zac/deployments/7311
authorAleš Mrázek <ales.mrazek@nic.cz>
Wed, 30 Jul 2025 13:23:46 +0000 (15:23 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Wed, 30 Jul 2025 13:23:46 +0000 (15:23 +0200)
python/knot_resolver/manager/kafka_client.py

index e33c3d7448e20501c1d6a982d4e01aa7ab767884..0e32b614e75d633670bf6b50f4a3cb9169d76cde 100644 (file)
@@ -23,6 +23,13 @@ if KAFKA_LIB:
 
     _kafka: Optional["KresKafkaClient"] = None
 
+    class MessageHeaders:
+        def __init__(self, headers: Dict[str, bytes]) -> None:
+            self.hostname = headers["hostname"].decode("utf-8") if "hostname" in headers else None
+            self.file_name = headers["file-name"].decode("utf-8") if "file-name" in headers else None
+            self.total_chunks = int(headers["total-chunks"].decode("utf-8")) if "total-chunks" in headers else None
+            self.chunk_index = int(headers["chunk-index"].decode("utf-8")) if "chunk-index" in headers else None
+
     class KresKafkaClient:
         def __init__(self, config: KresConfig) -> None:
             self._config = config
@@ -38,11 +45,11 @@ if KAFKA_LIB:
             if self._consumer:
                 self._consumer.close()
 
-        def _consume(self) -> None:  # noqa: PLR0912, PLR0915
+        def _consume(self) -> None:  # noqa: PLR0912
             if not self._consumer:
                 return
 
-            logger.info("Consuming...")
+            logger.info("Consuming messages...")
             messages: Dict[TopicPartition, List[ConsumerRecord]] = self._consumer.poll()
 
             for _partition, records in messages.items():
@@ -50,79 +57,81 @@ if KAFKA_LIB:
                     try:
                         key: str = record.key.decode("utf-8")
                         value: str = record.value.decode("utf-8")
+                        logger.info(f"Received message with '{key}' key")
 
-                        # messages without key
-                        # config
-                        if not key:
-                            logger.info("Received configuration message")
-
-                            # validate config
-                            KresConfig(parse_json(value))
-
-                            file_path = self._config.kafka.files_dir.to_path() / "config.json"
-                            file_path_tmp = f"{file_path}.tmp"
-                            file_path_backup = f"{file_path}.backup"
-
-                            if file_path.exists():
-                                shutil.copy(file_path, file_path_backup)
-                            with open(file_path_tmp, "w") as file:
-                                file.write(value)
-
-                            os.replace(file_path_tmp, file_path)
-
-                            logger.info(f"Configuration saved to '{file_path}'")
+                        # parse headers
+                        headers = MessageHeaders(dict(record.headers))
 
-                            # trigger delayed configuration reload
-                            trigger_reload(self._config)
+                        my_hostname = str(self._config.hostname)
+                        if headers.hostname != my_hostname:
+                            logger.info(
+                                f"Dropping message intended for '{headers.hostname}' hostname, this resolver hostname is '{my_hostname}'"
+                            )
                             continue
-                        # messages with key
-                        # RPZ or other files
 
-                        logger.info(f"Received message with '{key}' key")
-                        key_split = key.split(":")
+                        if not headers.file_name:
+                            logger.error("Missing 'file-name' header")
+                            continue
 
                         # prepare files names
-                        file_name = key_split[0]
-                        file_path = Path(file_name)
+                        file_path = Path(headers.file_name)
                         if not file_path.is_absolute():
-                            file_path = self._config.kafka.files_dir.to_path() / file_name
-
+                            file_path = self._config.kafka.files_dir.to_path() / file_path
                         file_path_tmp = f"{file_path}.tmp"
                         file_path_backup = f"{file_path}.backup"
 
-                        file_part = key_split[1] if len(key_split) > 1 else None
-                        _, file_extension = os.path.splitext(file_name)
+                        _, file_extension = os.path.splitext(headers.file_name)
 
-                        # received part of data
-                        if file_part and file_part.isdigit():
-                            # rewrite only on first part, else append
-                            mode = "w" if int(file_part) == 0 else "a"
-                            with open(file_path_tmp, mode) as file:
-                                file.write(value)
-                            logger.debug(f"Saved part {file_part} of data to '{file_path_tmp}' file")
-                        # received END of data
-                        elif file_part and file_part == "END":
+                        if not headers.chunk_index:
+                            logger.error("Missing 'chunk-index' message header")
+                        elif not headers.total_chunks:
+                            logger.error("Missing 'total-chunks' message header")
+                        # received full data in one message
+                        # or last chunk of data
+                        elif headers.chunk_index == headers.total_chunks:
                             if file_path.exists():
                                 shutil.copy(file_path, file_path_backup)
                                 logger.debug(f"Created backup of '{file_path_backup}' file")
 
-                            os.replace(file_path_tmp, file_path)
-                            logger.info(f"Saved file data to '{file_path}'")
+                            with open(file_path_tmp, "w") as file:
+                                file.write(value)
+
+                            config_extensions = (".json", ".yaml", ".yml")
+                            if file_extension in config_extensions:
+                                # validate config
+                                KresConfig(parse_json(value))
 
-                            # trigger delayed configuration renew
-                            trigger_renew(self._config)
+                            os.replace(file_path_tmp, file_path)
+                            logger.info(f"Saved data to '{file_path}'")
+
+                            if file_extension in config_extensions:
+                                # trigger delayed configuration reload
+                                trigger_reload(self._config)
+                            else:
+                                # trigger delayed configuration renew
+                                trigger_renew(self._config)
+                        # received part of data
                         else:
-                            logger.error("Failed to parse message key")
+                            # rewrite only on first part, else append
+                            mode = "w" if int(headers.chunk_index) == 1 else "a"
+                            with open(file_path_tmp, mode) as file:
+                                file.write(value)
+                            logger.debug(f"Saved part {headers.chunk_index} of data to '{file_path_tmp}' file")
+
                     except Exception as e:
                         logger.error(f"Processing message failed with error: \n{e}")
                         continue
 
+            # start new timer
             self._consumer_timer = Timer(5, self._consume)
             self._consumer_timer.start()
 
         def _consumer_connect(self) -> None:
             kafka = self._config.kafka
 
+            kafka_logger = logging.getLogger("kafka")
+            kafka_logger.setLevel(logging.ERROR)
+
             brokers = []
             for server in kafka.server.to_std():
                 broker = str(server)