git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
manager/kafka_client.py: separate file for each chunk
author Aleš Mrázek <ales.mrazek@nic.cz>
Mon, 4 Aug 2025 22:15:21 +0000 (00:15 +0200)
committer Vladimír Čunát <vladimir.cunat@nic.cz>
Thu, 9 Oct 2025 08:58:45 +0000 (10:58 +0200)
Each chunk of the file is stored separately. If all the chunks are available, the final file is assembled.
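For illustration, here is a minimal producer-side sketch (not part of this commit) of how a file could be split into the chunks this consumer reassembles. It assumes the kafka-python KafkaProducer on the sending side; the broker address, topic name, target hostname, chunk size and file are placeholders, while the header names ('hostname', 'file-name', 'total-chunks', 'chunk-index') and the 1-based chunk indexing mirror the consumer code in this patch.

from pathlib import Path

from kafka import KafkaProducer  # kafka-python, the same library the manager consumes with

CHUNK_SIZE = 512 * 1024  # assumed maximum chunk size (characters)


def send_file_in_chunks(producer: KafkaProducer, topic: str, hostname: str, file: Path) -> None:
    data = file.read_text()
    chunks = [data[i : i + CHUNK_SIZE] for i in range(0, len(data), CHUNK_SIZE)] or [""]
    total = len(chunks)
    for index, chunk in enumerate(chunks, start=1):  # 'chunk-index' is 1-based
        producer.send(
            topic,
            key=file.name.encode("utf-8"),
            value=chunk.encode("utf-8"),
            headers=[  # header values are bytes, as the consumer expects
                ("hostname", hostname.encode("utf-8")),
                ("file-name", file.name.encode("utf-8")),
                ("total-chunks", str(total).encode("utf-8")),
                ("chunk-index", str(index).encode("utf-8")),
            ],
        )
    producer.flush()


# hypothetical usage: broker, topic, hostname and file are placeholders
producer = KafkaProducer(bootstrap_servers=["localhost:9092"])
send_file_in_chunks(producer, "resolver-config", "resolver1.example.com", Path("config.yaml"))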

python/knot_resolver/manager/kafka_client.py

index d69cfb30274532506e051808546736681c4cd6f0..e5772a177d270cfa0ef2c57a4c06960865bb216c 100644
@@ -1,16 +1,17 @@
 import logging
-import os
 import shutil
 from pathlib import Path
 from threading import Timer
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from knot_resolver.constants import KAFKA_LIB
 from knot_resolver.datamodel import KresConfig
 from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
+from knot_resolver.manager.exceptions import KresKafkaClientError
 from knot_resolver.manager.triggers import trigger_reload, trigger_renew
 from knot_resolver.utils.functional import Result
-from knot_resolver.utils.modeling.parsing import parse_json
+from knot_resolver.utils.modeling import try_to_parse
+from knot_resolver.utils.modeling.exceptions import DataParsingError, DataValidationError
 
 logger = logging.getLogger(__name__)
 
@@ -25,140 +26,263 @@ def kafka_config(config: KresConfig) -> List[Any]:
 if KAFKA_LIB:
     from kafka import KafkaConsumer  # type: ignore[import-untyped,import-not-found]
     from kafka.consumer.fetcher import ConsumerRecord  # type: ignore[import-untyped,import-not-found]
-    from kafka.errors import NoBrokersAvailable  # type: ignore[import-untyped,import-not-found]
+    from kafka.errors import KafkaError  # type: ignore[import-untyped,import-not-found]
     from kafka.structs import TopicPartition  # type: ignore[import-untyped,import-not-found]
 
+    config_file_extensions = (".json", ".yaml", ".yml")
+
     _kafka: Optional["KresKafkaClient"] = None
 
-    class MessageHeaders:
-        def __init__(self, headers: Dict[str, bytes]) -> None:
-            self.hostname = headers["hostname"].decode("utf-8") if "hostname" in headers else None
-            self.file_name = headers["file-name"].decode("utf-8") if "file-name" in headers else None
-            self.total_chunks = int(headers["total-chunks"].decode("utf-8")) if "total-chunks" in headers else None
-            self.chunk_index = int(headers["chunk-index"].decode("utf-8")) if "chunk-index" in headers else None
+    def backup_and_replace(file_src_path: Path, file_dest_path: Path) -> None:
+        if file_dest_path.exists():
+            file_backup_path = Path(f"{file_dest_path}.backup")
+            shutil.copy(file_dest_path, file_backup_path)
+            logger.debug(f"Created backup file '{file_backup_path}'")
+        file_src_path.replace(file_dest_path)
+        logger.info(f"Saved new data to '{file_dest_path}'")
+
+    def create_file_chunk_path(file: Path, index: int) -> Path:
+        return Path(f"{file}.chunks/{index}")
+
+    def create_file_tmp_path(file: Path) -> Path:
+        return Path(f"{file}.tmp")
+
+    class Headers:
+        def __init__(self, headers: List[Tuple[str, bytes]]) -> None:
+            # default values
+            self.hostname: Optional[str] = None
+            self.file_name: Optional[str] = None
+            self.total_chunks: Optional[int] = None
+            self.chunk_index: Optional[int] = None
+            # assign values from the message headers
+            self._assign_headers(headers)
+
+        def _assign_headers(self, headers: List[Tuple[str, bytes]]) -> None:
+            for hkey, hvalue in headers:
+                if hkey == "hostname":
+                    self.hostname = hvalue.decode("utf-8")
+                elif hkey == "file-name":
+                    self.file_name = hvalue.decode("utf-8")
+                elif hkey == "total-chunks":
+                    self.total_chunks = int(hvalue)
+                elif hkey == "chunk-index":
+                    self.chunk_index = int(hvalue)
+                else:
+                    logger.warning(f"Unknown headers key '{hkey}'")
+
+    def hostname_match(headers: Headers, hostname: str) -> bool:
+        if not headers.hostname:
+            KresKafkaClientError("The required 'hostname' message header is missing")
+
+        # skip processing if hostname don't match
+        if headers.hostname != hostname:
+            logger.info(
+                f"The resolver's hostname '{hostname}' do not match the message header hostname '{headers.hostname}':"
+                " The message is intended for a resolver with the matching hostname"
+            )
+            return False
+        return True
+
+    def check_chunk_headers(headers: Headers) -> None:
+        index = headers.chunk_index
+        total = headers.total_chunks
+
+        if index and not total:
+            raise KresKafkaClientError("missing 'total-chunks' message header")
+        if total and not index:
+            raise KresKafkaClientError("missing 'chunk-index' message header")
+        if index and total and index > total:
+            raise KresKafkaClientError(
+                f"'chunk-index' value cannot be bigger than 'total-chunks' value '{index} > {total}'"
+            )
+
+    def cleanup_files_dir(config_file_path: Path, files_dir: Path) -> None:
+        config_file_backup_path = Path(f"{config_file_path}.backup")
+        used_files: List[Path] = [config_file_path, config_file_backup_path]
+
+        # current config
+        with open(config_file_path, "r") as config_file:
+            current_config = KresConfig(try_to_parse(config_file.read()))
+        if current_config.local_data.rpz:
+            for rpz in current_config.local_data.rpz:
+                used_files.append(rpz.file.to_path().resolve())
+                used_files.append(Path(f"{rpz.file.to_path()}.backup").resolve())
+
+        # keep backup config functional
+        if config_file_backup_path.exists():
+            with open(config_file_backup_path, "r") as backup_file:
+                backup_config = KresConfig(try_to_parse(backup_file.read()))
+            if backup_config.local_data.rpz:
+                for backup_rpz in backup_config.local_data.rpz:
+                    used_files.append(backup_rpz.file.to_path().resolve())
+                    used_files.append(Path(f"{backup_rpz.file.to_path()}.backup").resolve())
+
+        # delete unused files from current and backup config
+        for path in files_dir.iterdir():
+            if path.is_file() and path.resolve() not in used_files:
+                logger.debug(f"Cleaned up file '{path}'")
+                path.unlink()
+
+    def process_record(config: KresConfig, record: ConsumerRecord) -> None:  # noqa: PLR0912, PLR0915
+        key: str = record.key.decode("utf-8")
+        value: str = record.value.decode("utf-8")
+        headers = Headers(record.headers)
+
+        logger.info(f"Received message with '{key}' key")
+
+        # check hostname
+        if not hostname_match(headers, str(config.hostname)):
+            return
+
+        # check chunks
+        check_chunk_headers(headers)
+
+        # check file name
+        if not headers.file_name:
+            raise KresKafkaClientError("missing 'file-name' message header")
+
+        # prepare file path and extension
+        file_path = Path(headers.file_name)
+        file_extension = file_path.suffix
+        if not file_path.is_absolute():
+            file_path = config.kafka.files_dir.to_path() / file_path
+        file_tmp_path = create_file_tmp_path(file_path)
+
+        index = headers.chunk_index
+        total = headers.total_chunks
+        file_is_ready = False
+
+        # received complete data in one message
+        if (not index and not total) or (index == 1 and total == 1):
+            with open(file_tmp_path, "w") as file:
+                file.write(value)
+            logger.debug(f"Saved complete data to '{file_tmp_path}' file")
+            file_is_ready = True
+
+        # received chunk of data
+        elif index and total:
+            file_chunk_path = create_file_chunk_path(file_path, index)
+            # create chunks dir if not exists
+            file_chunk_path.parent.mkdir(exist_ok=True)
+            with open(file_chunk_path, "w") as file:
+                file.write(value)
+            logger.debug(f"Saved chunk {index} of data to '{file_chunk_path}' file")
+
+            missing: List[int] = []
+            file_chunks_paths: List[Path] = []
+            for i in range(1, total + 1):
+                path = create_file_chunk_path(file_path, i)
+                if path.exists():
+                    file_chunks_paths.append(path)
+                else:
+                    missing.append(i)
+
+            if len(file_chunks_paths) == total:
+                with open(file_tmp_path, "wb") as tmp_file:
+                    for path in file_chunks_paths:
+                        with open(path, "rb") as chunk_file:
+                            tmp_file.write(chunk_file.read())
+                logger.debug(f"Saved complete data from all chunks to '{file_tmp_path}' file")
+                file_is_ready = True
+
+                # remove chunks dir
+                chunks_dir = f"{file_path}.chunks"
+                shutil.rmtree(chunks_dir)
+                logger.debug(f"Removed chunks directory '{chunks_dir}'")
+            else:
+                logger.debug(f"The file '{headers.file_name}' cannot be assembled yet: missing chunks {missing}")
+
+        # complete tmp file is ready
+        if file_tmp_path.exists() and file_is_ready:
+            # configuration files (.json, .yaml, .yml)
+            if file_extension in config_file_extensions:
+                # validate the assembled configuration
+                with open(file_tmp_path, "r") as tmp_file:
+                    KresConfig(try_to_parse(tmp_file.read()))
+
+                # backup and replace file with new data
+                backup_and_replace(file_tmp_path, file_path)
+
+                # cleanup old files
+                cleanup_files_dir(file_path, config.kafka.files_dir.to_path())
+
+                # trigger reload
+                trigger_reload(config)
+
+            # other files (.rpz, ...)
+            else:
+                # backup and replace file with new data
+                backup_and_replace(file_tmp_path, file_path)
+                # trigger renew
+                trigger_renew(config)
+
+        logger.info("Successfully processed message")
 
     class KresKafkaClient:
         def __init__(self, config: KresConfig) -> None:
             self._config = config
 
-            self._consumer: Optional[KafkaConsumer] = None
-            self._consumer_connect()
-            self._consumer_timer = Timer(5, self._consume)
-            self._consumer_timer.start()
-
-        def deinit(self) -> None:
-            if self._consumer_timer:
-                self._consumer_timer.cancel()
-            if self._consumer:
-                self._consumer.close()
-
-        def _consume(self) -> None:  # noqa: PLR0912, PLR0915
-            if not self._consumer:
-                return
-
-            logger.info("Consuming messages...")
-            messages: Dict[TopicPartition, List[ConsumerRecord]] = self._consumer.poll()
-
-            for _partition, records in messages.items():
-                for record in records:
-                    try:
-                        key: str = record.key.decode("utf-8")
-                        value: str = record.value.decode("utf-8")
-                        logger.info(f"Received message with '{key}' key")
-
-                        # parse headers
-                        headers = MessageHeaders(dict(record.headers))
-
-                        my_hostname = str(self._config.hostname)
-                        if headers.hostname != my_hostname:
-                            logger.info(
-                                f"Dropping message intended for '{headers.hostname}' hostname, this resolver hostname is '{my_hostname}'"
-                            )
-                            continue
-
-                        # prepare files names
-                        file_name = headers.file_name if headers.file_name else key
-                        file_path = Path(file_name)
-                        if not file_path.is_absolute():
-                            file_path = self._config.kafka.files_dir.to_path() / file_path
-                        file_path_tmp = f"{file_path}.tmp"
-                        file_path_backup = f"{file_path}.backup"
-
-                        _, file_extension = os.path.splitext(file_name)
-
-                        # received full data in one message
-                        # or last chunk of data
-                        if headers.chunk_index == headers.total_chunks:
-                            if file_path.exists():
-                                shutil.copy(file_path, file_path_backup)
-                                logger.debug(f"Created backup of '{file_path_backup}' file")
-
-                            # rewrite only on first part, else append
-                            mode = (
-                                "w"
-                                if (headers.chunk_index and int(headers.chunk_index)) or not headers.total_chunks == 1
-                                else "a"
-                            )
-                            with open(file_path_tmp, mode) as file:
-                                file.write(value)
-
-                            config_extensions = (".json", ".yaml", ".yml")
-                            if file_extension in config_extensions:
-                                # validate config
-                                KresConfig(parse_json(value))
-
-                            os.replace(file_path_tmp, file_path)
-                            logger.info(f"Saved data to '{file_path}'")
-
-                            # config files must be reloaded
-                            if file_extension in config_extensions:
-                                # trigger delayed configuration reload
-                                trigger_reload(self._config)
-                            else:
-                                # trigger delayed configuration renew
-                                trigger_renew(self._config)
-                        # received part of data
-                        else:
-                            # rewrite only on first part, else append
-                            mode = "w" if headers.chunk_index and int(headers.chunk_index) == 1 else "a"
-                            with open(file_path_tmp, mode) as file:
-                                file.write(value)
-                            logger.debug(f"Saved part {headers.chunk_index} of data to '{file_path_tmp}' file")
-                    except Exception as e:
-                        logger.error(f"Processing message failed with error: \n{e}")
-                        continue
-
-            # start new timer
-            self._consumer_timer = Timer(5, self._consume)
-            self._consumer_timer.start()
-
-        def _consumer_connect(self) -> None:
-            kafka = self._config.kafka
-
+            # reduce the verbosity of kafka module logger
             kafka_logger = logging.getLogger("kafka")
             kafka_logger.setLevel(logging.ERROR)
 
             brokers = []
-            for server in kafka.server.to_std():
+            kafka_conf = config.kafka
+            for server in kafka_conf.server.to_std():
                 broker = str(server)
                 brokers.append(broker.replace("@", ":") if server.port else f"{broker}:9092")
 
-            logger.info("Connecting to Kafka brokers...")
+            logger.info("Connecting to Kafka broker(s)...")
             try:
                 consumer = KafkaConsumer(
-                    str(kafka.topic),
+                    str(kafka_conf.topic),
                     bootstrap_servers=brokers,
                     client_id=str(self._config.hostname),
-                    security_protocol=str(kafka.security_protocol).upper(),
-                    ssl_cafile=str(kafka.ca_file) if kafka.ca_file else None,
-                    ssl_certfile=str(kafka.cert_file) if kafka.cert_file else None,
-                    ssl_keyfile=str(kafka.key_file) if kafka.key_file else None,
+                    security_protocol=str(kafka_conf.security_protocol).upper(),
+                    ssl_cafile=str(kafka_conf.ca_file) if kafka_conf.ca_file else None,
+                    ssl_certfile=str(kafka_conf.cert_file) if kafka_conf.cert_file else None,
+                    ssl_keyfile=str(kafka_conf.key_file) if kafka_conf.key_file else None,
                 )
                 self._consumer = consumer
+
+                self._consumer_timer = Timer(5, self._consume_messages)
+                self._consumer_timer.start()
                 logger.info("Successfully connected to Kafka broker")
-            except NoBrokersAvailable:
-                logger.error(f"Connecting to Kafka broker '{kafka.server}' has failed: no broker available")
-                self._consumer = None
+            except KafkaError as e:
+                raise KresKafkaClientError(f"Connecting to Kafka broker(s) '{brokers}' has failed") from e
+
+        def deinit(self) -> None:
+            self._consumer_timer.cancel()
+            self._consumer.close()
+            self._consumer = None
+
+        def _consume_messages(self) -> None:
+            if not self._consumer:
+                return
+
+            logger.info("Started consuming messages...")
+            messages: Dict[TopicPartition, List[ConsumerRecord]] = self._consumer.poll(timeout_ms=100)
+
+            for _partition, records in messages.items():
+                for record in records:
+                    error_msg_prefix = "Processing message failed with"
+                    try:
+                        process_record(self._config, record)
+                    except KresKafkaClientError as e:
+                        logger.error(f"{error_msg_prefix} Kafka client error:\n{e}")
+                    except DataParsingError as e:
+                        logger.error(f"{error_msg_prefix} data parsing error:\n{e}")
+                    except DataValidationError as e:
+                        logger.error(f"{error_msg_prefix} data validation error:\n{e}")
+                    except Exception as e:
+                        logger.error(f"{error_msg_prefix} unknown error:\n{e}")
+
+            # keep consuming as long as messages are arriving
+            if len(messages) > 0:
+                self._consume_messages()
+            else:
+                # otherwise schedule the next poll
+                self._consumer_timer = Timer(5, self._consume_messages)
+                self._consumer_timer.start()
 
 
 @only_on_real_changes_update(kafka_config)