import logging
-import os
import shutil
from pathlib import Path
from threading import Timer
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from knot_resolver.constants import KAFKA_LIB
from knot_resolver.datamodel import KresConfig
from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
+from knot_resolver.manager.exceptions import KresKafkaClientError
from knot_resolver.manager.triggers import trigger_reload, trigger_renew
from knot_resolver.utils.functional import Result
-from knot_resolver.utils.modeling.parsing import parse_json
+from knot_resolver.utils.modeling import try_to_parse
+from knot_resolver.utils.modeling.exceptions import DataParsingError, DataValidationError
logger = logging.getLogger(__name__)
if KAFKA_LIB:
from kafka import KafkaConsumer # type: ignore[import-untyped,import-not-found]
from kafka.consumer.fetcher import ConsumerRecord # type: ignore[import-untyped,import-not-found]
- from kafka.errors import NoBrokersAvailable # type: ignore[import-untyped,import-not-found]
+ from kafka.errors import KafkaError # type: ignore[import-untyped,import-not-found]
from kafka.structs import TopicPartition # type: ignore[import-untyped,import-not-found]
+ config_file_extensions = (".json", ".yaml", ".yml")
+
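+ # module-level reference to the running Kafka client, if any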
_kafka: Optional["KresKafkaClient"] = None
- class MessageHeaders:
- def __init__(self, headers: Dict[str, bytes]) -> None:
- self.hostname = headers["hostname"].decode("utf-8") if "hostname" in headers else None
- self.file_name = headers["file-name"].decode("utf-8") if "file-name" in headers else None
- self.total_chunks = int(headers["total-chunks"].decode("utf-8")) if "total-chunks" in headers else None
- self.chunk_index = int(headers["chunk-index"].decode("utf-8")) if "chunk-index" in headers else None
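+ # replace file_dest_path with file_src_path, keeping a ".backup" copy of the previous version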
+ def backup_and_replace(file_src_path: Path, file_dest_path: Path) -> None:
+ if file_dest_path.exists():
+ file_backup_path = Path(f"{file_dest_path}.backup")
+ shutil.copy(file_dest_path, file_backup_path)
+ logger.debug(f"Created backup file '{file_backup_path}'")
+ file_src_path.replace(file_dest_path)
+ logger.info(f"Saved new data to '{file_dest_path}'")
+
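+ # helpers deriving the chunk and temporary file paths for a target file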
+ def create_file_chunk_path(file: Path, index: int) -> Path:
+ return Path(f"{file}.chunks/{index}")
+
+ def create_file_tmp_path(file: Path) -> Path:
+ return Path(f"{file}.tmp")
+
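+ # typed view of the raw Kafka message headers (a list of (key, value-bytes) tuples)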
+ class Headers:
+ def __init__(self, headers: List[Tuple[str, bytes]]) -> None:
+ # default values
+ self.hostname: Optional[str] = None
+ self.file_name: Optional[str] = None
+ self.total_chunks: Optional[int] = None
+ self.chunk_index: Optional[int] = None
+ # assign values from the message headers
+ self._assign_headers(headers)
+
+ def _assign_headers(self, headers: List[Tuple[str, bytes]]) -> None:
+ for hkey, hvalue in headers:
+ if hkey == "hostname":
+ self.hostname = hvalue.decode("utf-8")
+ elif hkey == "file-name":
+ self.file_name = hvalue.decode("utf-8")
+ elif hkey == "total-chunks":
+ self.total_chunks = int(hvalue)
+ elif hkey == "chunk-index":
+ self.chunk_index = int(hvalue)
+ else:
+ logger.warning(f"Unknown headers key '{hkey}'")
+
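+ # return True if the message is intended for this resolver's hostname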
+ def hostname_match(headers: Headers, hostname: str) -> bool:
+ if not headers.hostname:
+ KresKafkaClientError("The required 'hostname' message header is missing")
+
+ # skip processing if the hostnames don't match
+ if headers.hostname != hostname:
+ logger.info(
+ f"The resolver's hostname '{hostname}' do not match the message header hostname '{headers.hostname}':"
+ " The message is intended for a resolver with the matching hostname"
+ )
+ return False
+ return True
+
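+ # validate that the chunking headers are present and consistent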
+ def check_chunk_headers(headers: Headers) -> None:
+ index = headers.chunk_index
+ total = headers.total_chunks
+
+ if index and not total:
+ raise KresKafkaClientError("missing 'total-chunks' message header")
+ if total and not index:
+ raise KresKafkaClientError("missing 'chunk-index' message header")
+ if index and total and index > total:
+ raise KresKafkaClientError(
+ f"'chunk-index' value cannot be bigger than 'total-chunks' value '{index} > {total}'"
+ )
+
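+ # remove files from files_dir that are not referenced by the current or backup configuration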
+ def cleanup_files_dir(config_file_path: Path, files_dir: Path) -> None:
+ config_file_backup_path = Path(f"{config_file_path}.backup")
+ # store resolved paths so they match the resolved paths checked during cleanup below
+ used_files: List[Path] = [config_file_path.resolve(), config_file_backup_path.resolve()]
+
+ # current config
+ with open(config_file_path, "r") as config_file:
+ current_config = KresConfig(try_to_parse(config_file.read()))
+ if current_config.local_data.rpz:
+ for rpz in current_config.local_data.rpz:
+ used_files.append(rpz.file.to_path().resolve())
+ used_files.append(Path(f"{rpz.file.to_path()}.backup").resolve())
+
+ # keep backup config functional
+ if config_file_backup_path.exists():
+ with open(config_file_backup_path, "r") as backup_file:
+ backup_config = KresConfig(try_to_parse(backup_file.read()))
+ if backup_config.local_data.rpz:
+ for backup_rpz in backup_config.local_data.rpz:
+ used_files.append(backup_rpz.file.to_path().resolve())
+ used_files.append(Path(f"{backup_rpz.file.to_path()}.backup").resolve())
+
+ # delete unused files from current and backup config
+ for path in files_dir.iterdir():
+ if path.is_file() and path.resolve() not in used_files:
+ logger.debug(f"Cleaned up file '{path}'")
+ path.unlink()
+
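+ # process a single record: validate its headers, write the payload (complete or assembled
+ # from chunks) to a temporary file, then replace the target file and trigger a reload/renew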
+ def process_record(config: KresConfig, record: ConsumerRecord) -> None: # noqa: PLR0912, PLR0915
+ key: str = record.key.decode("utf-8")
+ value: str = record.value.decode("utf-8")
+ headers = Headers(record.headers)
+
+ logger.info(f"Received message with '{key}' key")
+
+ # check hostname
+ if not hostname_match(headers, str(config.hostname)):
+ return
+
+ # check chunks
+ check_chunk_headers(headers)
+
+ # check file name
+ if not headers.file_name:
+ raise KresKafkaClientError("missing 'file-name' message header")
+
+ # prepare file path and extension
+ file_path = Path(headers.file_name)
+ file_extension = file_path.suffix
+ if not file_path.is_absolute():
+ file_path = config.kafka.files_dir.to_path() / file_path
+ file_tmp_path = create_file_tmp_path(file_path)
+
+ index = headers.chunk_index
+ total = headers.total_chunks
+ file_is_ready = False
+
+ # received complete data in one message
+ if (not index and not total) or (index == 1 and total == 1):
+ with open(file_tmp_path, "w") as file:
+ file.write(value)
+ logger.debug(f"Saved complete data to '{file_tmp_path}' file")
+ file_is_ready = True
+
+ # received chunk of data
+ elif index and total:
+ file_chunk_path = create_file_chunk_path(file_path, index)
+ # create chunks dir if not exists
+ file_chunk_path.parent.mkdir(exist_ok=True)
+ with open(file_chunk_path, "w") as file:
+ file.write(value)
+ logger.debug(f"Saved chunk {index} of data to '{file_chunk_path}' file")
+
+ missing: List[int] = []
+ file_chunks_paths: List[Path] = []
+ for i in range(1, total + 1):
+ path = create_file_chunk_path(file_path, i)
+ if path.exists():
+ file_chunks_paths.append(path)
+ else:
+ missing.append(i)
+
+ if len(file_chunks_paths) == total:
+ with open(file_tmp_path, "wb") as tmp_file:
+ for path in file_chunks_paths:
+ with open(path, "rb") as chunk_file:
+ tmp_file.write(chunk_file.read())
+ logger.debug(f"Saved complete data from all chunks to '{file_tmp_path}' file")
+ file_is_ready = True
+
+ # remove chunks dir
+ chunks_dir = f"{file_path}.chunks"
+ shutil.rmtree(chunks_dir)
+ logger.debug(f"Removed chunks directory '{chunks_dir}'")
+ else:
+ logger.debug(f"The file '{headers.file_name}' cannot be assembled yet: missing chunks {missing}")
+
+ # complete tmp file is ready
+ if file_tmp_path.exists() and file_is_ready:
+ # configuration files (.json, .yaml, .yml)
+ if file_extension in config_file_extensions:
+ # validate configuration
+ KresConfig(try_to_parse(value))
+
+ # backup and replace file with new data
+ backup_and_replace(file_tmp_path, file_path)
+
+ # cleanup old files
+ cleanup_files_dir(file_path, config.kafka.files_dir.to_path())
+
+ # trigger reload
+ trigger_reload(config)
+
+ # other files (.rpz, ...)
+ else:
+ # backup and replace file with new data
+ backup_and_replace(file_tmp_path, file_path)
+ # trigger renew
+ trigger_renew(config)
+
+ logger.info("Successfully processed message")
class KresKafkaClient:
def __init__(self, config: KresConfig) -> None:
self._config = config
- self._consumer: Optional[KafkaConsumer] = None
- self._consumer_connect()
- self._consumer_timer = Timer(5, self._consume)
- self._consumer_timer.start()
-
- def deinit(self) -> None:
- if self._consumer_timer:
- self._consumer_timer.cancel()
- if self._consumer:
- self._consumer.close()
-
- def _consume(self) -> None: # noqa: PLR0912, PLR0915
- if not self._consumer:
- return
-
- logger.info("Consuming messages...")
- messages: Dict[TopicPartition, List[ConsumerRecord]] = self._consumer.poll()
-
- for _partition, records in messages.items():
- for record in records:
- try:
- key: str = record.key.decode("utf-8")
- value: str = record.value.decode("utf-8")
- logger.info(f"Received message with '{key}' key")
-
- # parse headers
- headers = MessageHeaders(dict(record.headers))
-
- my_hostname = str(self._config.hostname)
- if headers.hostname != my_hostname:
- logger.info(
- f"Dropping message intended for '{headers.hostname}' hostname, this resolver hostname is '{my_hostname}'"
- )
- continue
-
- # prepare files names
- file_name = headers.file_name if headers.file_name else key
- file_path = Path(file_name)
- if not file_path.is_absolute():
- file_path = self._config.kafka.files_dir.to_path() / file_path
- file_path_tmp = f"{file_path}.tmp"
- file_path_backup = f"{file_path}.backup"
-
- _, file_extension = os.path.splitext(file_name)
-
- # received full data in one message
- # or last chunk of data
- if headers.chunk_index == headers.total_chunks:
- if file_path.exists():
- shutil.copy(file_path, file_path_backup)
- logger.debug(f"Created backup of '{file_path_backup}' file")
-
- # rewrite only on first part, else append
- mode = (
- "w"
- if (headers.chunk_index and int(headers.chunk_index)) or not headers.total_chunks == 1
- else "a"
- )
- with open(file_path_tmp, mode) as file:
- file.write(value)
-
- config_extensions = (".json", ".yaml", ".yml")
- if file_extension in config_extensions:
- # validate config
- KresConfig(parse_json(value))
-
- os.replace(file_path_tmp, file_path)
- logger.info(f"Saved data to '{file_path}'")
-
- # config files must be reloaded
- if file_extension in config_extensions:
- # trigger delayed configuration reload
- trigger_reload(self._config)
- else:
- # trigger delayed configuration renew
- trigger_renew(self._config)
- # received part of data
- else:
- # rewrite only on first part, else append
- mode = "w" if headers.chunk_index and int(headers.chunk_index) == 1 else "a"
- with open(file_path_tmp, mode) as file:
- file.write(value)
- logger.debug(f"Saved part {headers.chunk_index} of data to '{file_path_tmp}' file")
- except Exception as e:
- logger.error(f"Processing message failed with error: \n{e}")
- continue
-
- # start new timer
- self._consumer_timer = Timer(5, self._consume)
- self._consumer_timer.start()
-
- def _consumer_connect(self) -> None:
- kafka = self._config.kafka
-
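+ # connect to the configured Kafka broker(s) and start the periodic message consumption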
+ # reduce the verbosity of the kafka module logger
kafka_logger = logging.getLogger("kafka")
kafka_logger.setLevel(logging.ERROR)
brokers = []
- for server in kafka.server.to_std():
+ kafka_conf = config.kafka
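+ # normalize the configured servers to "host:port" broker addresses (default port 9092)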
+ for server in kafka_conf.server.to_std():
broker = str(server)
brokers.append(broker.replace("@", ":") if server.port else f"{broker}:9092")
- logger.info("Connecting to Kafka brokers...")
+ logger.info("Connecting to Kafka broker(s)...")
try:
consumer = KafkaConsumer(
- str(kafka.topic),
+ str(kafka_conf.topic),
bootstrap_servers=brokers,
client_id=str(self._config.hostname),
- security_protocol=str(kafka.security_protocol).upper(),
- ssl_cafile=str(kafka.ca_file) if kafka.ca_file else None,
- ssl_certfile=str(kafka.cert_file) if kafka.cert_file else None,
- ssl_keyfile=str(kafka.key_file) if kafka.key_file else None,
+ security_protocol=str(kafka_conf.security_protocol).upper(),
+ ssl_cafile=str(kafka_conf.ca_file) if kafka_conf.ca_file else None,
+ ssl_certfile=str(kafka_conf.cert_file) if kafka_conf.cert_file else None,
+ ssl_keyfile=str(kafka_conf.key_file) if kafka_conf.key_file else None,
)
self._consumer = consumer
+
+ self._consumer_timer = Timer(5, self._consume_messages)
+ self._consumer_timer.start()
logger.info("Successfully connected to Kafka broker")
- except NoBrokersAvailable:
- logger.error(f"Connecting to Kafka broker '{kafka.server}' has failed: no broker available")
- self._consumer = None
+ except KafkaError as e:
+ raise KresKafkaClientError(f"Connecting to Kafka broker(s) '{brokers}' has failed") from e
+
+ def deinit(self) -> None:
+ self._consumer_timer.cancel()
+ self._consumer.close()
+ self._consumer = None
+
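+ # poll the consumer, process any received records, and schedule the next poll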
+ def _consume_messages(self) -> None:
+ if not self._consumer:
+ return
+
+ logger.info("Started consuming messages...")
+ messages: Dict[TopicPartition, List[ConsumerRecord]] = self._consumer.poll(timeout_ms=100)
+
+ for _partition, records in messages.items():
+ for record in records:
+ error_msg_prefix = "Processing message failed with"
+ try:
+ process_record(self._config, record)
+ except KresKafkaClientError as e:
+ logger.error(f"{error_msg_prefix} Kafka client error:\n{e}")
+ except DataParsingError as e:
+ logger.error(f"{error_msg_prefix} data parsing error:\n{e}")
+ except DataValidationError as e:
+ logger.error(f"{error_msg_prefix} data validation error:\n{e}")
+ except Exception as e:
+ logger.error(f"{error_msg_prefix} unknown error:\n{e}")
+
+ # keep consuming if any messages were received
+ if len(messages) > 0:
+ self._consume_messages()
+ else:
+ # otherwise schedule the next poll
+ self._consumer_timer = Timer(5, self._consume_messages)
+ self._consumer_timer.start()
@only_on_real_changes_update(kafka_config)