-from typing import Any, Callable, Optional, Type, TypeVar
+from __future__ import annotations
-T = TypeVar("T")
+from functools import wraps
+from typing import TYPE_CHECKING, Callable, TypeVar
+
+if TYPE_CHECKING:
+ from typing_extensions import ParamSpec
+
+ P = ParamSpec("P")
+ R = TypeVar("R")
+ T = TypeVar("T")
def ignore_exceptions_optional(
- _tp: Type[T], default: Optional[T], *exceptions: Type[BaseException]
-) -> Callable[[Callable[..., Optional[T]]], Callable[..., Optional[T]]]:
+ exceptions: type[BaseException] | tuple[type[BaseException]], default: T | None
+) -> Callable[[Callable[P, R]], Callable[P, R | T | None]]:
"""
- Wrap function preventing it from raising exceptions and instead returning the configured default value.
+ Prevent exception(s) from being raised and return the configured default value instead..
+
+ Args:
+ exceptions: Exception(s) to catch.
+ default: The default value to return.
+
+ Returns:
+ The value of the decorated function or the default value if an exception is caught.
- :param type[T] _tp: Return type of the function. Essentialy only a template argument for type-checking
- :param T default: The value to return as a default
- :param list[Type[BaseException]] exceptions: The list of exceptions to catch
- :return: value of the decorated function, or default if exception raised
- :rtype: T
"""
- def decorator(func: Callable[..., Optional[T]]) -> Callable[..., Optional[T]]:
- def f(*nargs: Any, **nkwargs: Any) -> Optional[T]:
+ def decorator(func: Callable[P, R]) -> Callable[P, T | (R | None)]:
+ @wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | (R | None):
try:
- return func(*nargs, **nkwargs)
- except BaseException as e:
- if isinstance(e, exceptions):
- return default
- raise
+ return func(*args, **kwargs)
+ except exceptions:
+ return default
- return f
+ return wrapper
return decorator
-
-
-def ignore_exceptions(
- default: T, *exceptions: Type[BaseException]
-) -> Callable[[Callable[..., Optional[T]]], Callable[..., Optional[T]]]:
- return ignore_exceptions_optional(type(default), default, *exceptions)
-
-
-def phantom_use(var: Any) -> None: # pylint: disable=unused-argument
- """
- Consumes argument doing absolutely nothing with it.
-
- Useful for convincing pylint, that we need the variable even when its unused.
- """
import os
import pkgutil
import signal
-import sys
import time
from asyncio import create_subprocess_exec, create_subprocess_shell
-from pathlib import PurePath
-from threading import Thread
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
from knot_resolver.utils.compat.asyncio import to_thread
-def unblock_signals() -> None:
- if sys.version_info >= (3, 8):
- signal.pthread_sigmask(signal.SIG_UNBLOCK, signal.valid_signals())
- else:
- # the list of signals is not exhaustive, but it should cover all signals we might ever want to block
- signal.pthread_sigmask(
- signal.SIG_UNBLOCK,
- {
- signal.SIGHUP,
- signal.SIGINT,
- signal.SIGTERM,
- signal.SIGUSR1,
- signal.SIGUSR2,
- },
- )
-
-
async def call(
cmd: Union[str, bytes, List[str], List[bytes]], shell: bool = False, discard_output: bool = False
) -> int:
"""Async alternative to subprocess.call()."""
- kwargs: Dict[str, Any] = {
- "preexec_fn": unblock_signals,
- }
+ kwargs: Dict[str, Any] = {"preexec_fn": signal.pthread_sigmask(signal.SIG_UNBLOCK, signal.valid_signals())}
if discard_output:
kwargs["stdout"] = asyncio.subprocess.DEVNULL
kwargs["stderr"] = asyncio.subprocess.DEVNULL
if shell:
if isinstance(cmd, list):
- raise RuntimeError("can't use list of arguments with shell=True")
+ msg = "can't use list of arguments with shell=True"
+ raise RuntimeError(msg)
proc = await create_subprocess_shell(cmd, **kwargs)
else:
if not isinstance(cmd, list):
- raise RuntimeError(
- "Please use list of arguments, not a single string. It will prevent ambiguity when parsing"
- )
+ msg = "Please use list of arguments, not a single string. It will prevent ambiguity when parsing"
+ raise RuntimeError(msg)
proc = await create_subprocess_exec(*cmd, **kwargs)
return await proc.wait()
-async def readfile(path: Union[str, PurePath]) -> str:
- """Asynchronously read whole file and return its content."""
+async def readfile(path: Path) -> str:
+ """Asynchronously read file on a path and return its content."""
- def readfile_sync(path: Union[str, PurePath]) -> str:
- with open(path, "r", encoding="utf8") as f:
- return f.read()
+ def readfile_sync(path: Path) -> str:
+ with path.open("r", encoding="utf8") as file:
+ return file.read()
return await to_thread(readfile_sync, path)
-async def writefile(path: Union[str, PurePath], content: str) -> None:
- """Asynchronously set content of a file to a given string `content`."""
+async def writefile(path: Path, content: str) -> None:
+ """Asynchronously set content of a file on path to a given string content."""
- def writefile_sync(path: Union[str, PurePath], content: str) -> int:
- with open(path, "w", encoding="utf8") as f:
- return f.write(content)
+ def writefile_sync(path: Path, content: str) -> int:
+ with path.open("w", encoding="utf8") as file:
+ return file.write(content)
await to_thread(writefile_sync, path, content)
async def wait_for_process_termination(pid: int, sleep_sec: float = 0) -> None:
"""
- Wait for the process termination.
+ Wait for any process (does not have to be a child process) given by its PID to terminate.
Will wait for any process (does not have to be a child process)
given by its PID to terminate sleep_sec configures the granularity,
"""
def wait_sync(pid: int, sleep_sec: float) -> None:
- while True:
- try:
+ try:
+ while True:
os.kill(pid, 0)
if sleep_sec == 0:
os.sched_yield()
else:
time.sleep(sleep_sec)
- except ProcessLookupError:
- break
+ except ProcessLookupError:
+ pass
await to_thread(wait_sync, pid, sleep_sec)
async def read_resource(package: str, filename: str) -> Optional[bytes]:
return await to_thread(pkgutil.get_data, package, filename)
-
-
-T = TypeVar("T")
-
-
-class BlockingEventDispatcher(Thread, Generic[T]):
- def __init__(self, name: str = "blocking_event_dispatcher") -> None:
- super().__init__(name=name, daemon=True)
- # warning: the asyncio queue is not thread safe
- self._removed_unit_names: "asyncio.Queue[T]" = asyncio.Queue()
- self._main_event_loop = asyncio.get_event_loop()
-
- def dispatch_event(self, event: T) -> None:
- """Dispatch events from the blocking thread."""
-
- async def add_to_queue() -> None:
- await self._removed_unit_names.put(event)
-
- self._main_event_loop.call_soon_threadsafe(add_to_queue)
-
- async def next_event(self) -> T:
- return await self._removed_unit_names.get()
try:
parsed_url = urlparse(url)
host = unquote(parsed_url.hostname or "(Unknown)")
- except Exception as e:
+ except ValueError as e:
host = f"(Invalid URL: {e})"
msg = f"""
{error_desc}
while path.startswith("/"):
path = path[1:]
url = f"{socket_desc.uri}/{path}"
+
req = Request(
url,
method=method,
data=body.encode("utf8") if body is not None else None,
headers={"Content-Type": content_type},
)
- # req.add_header("Authorization", _authorization_header)
timeout_m = 5 # minutes
try:
"""
Create an HTTP connection to a unix domain socket.
- :param unix_socket_url: A URL with a scheme of 'http+unix' and the
- netloc is a percent-encoded path to a unix domain socket. E.g.:
- 'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid'
+ Args:
+ unix_socket_url (str): A URL with a scheme of 'http+unix' and the netloc is a percent-encoded path
+ to a unix domain socket. E.g.: 'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid'
+ timeout (float): Connection timeout.
+
"""
super().__init__("localhost", timeout=timeout)
self.unix_socket_path = unix_socket_url
-import enum
import logging
import os
import socket
logger = logging.getLogger(__name__)
-class _Status(enum.Enum):
- NOT_INITIALIZED = 1
- FUNCTIONAL = 2
- FAILED = 3
-
-
-_status = _Status.NOT_INITIALIZED
-_socket = None
+def systemd_notify(**values: str) -> None:
+ """
+ Send systemd notify message to notify socket.
+
+ Notify socket location (unix socket) should be saved in $NOTIFY_SOCKET environment variable.
+ It is typically set by the processes supervisor (supervisord).
+ If $NOTIFY_SOCKET is not configured, it is not possible to send a notification and the operation will fail.
+ """
+ socket_addr = os.getenv("NOTIFY_SOCKET")
+ if socket_addr is None:
+ logger.warning("Failed to get $NOTIFY_SOCKET environment variable")
+ return
+ if socket_addr.startswith("@"):
+ socket_addr = socket_addr.replace("@", "\0", 1)
-def systemd_notify(**values: str) -> None:
- global _status
- global _socket
-
- if _status is _Status.NOT_INITIALIZED:
- socket_addr = os.getenv("NOTIFY_SOCKET")
- os.unsetenv("NOTIFY_SOCKET")
- if socket_addr is None:
- _status = _Status.FAILED
- return
- if socket_addr.startswith("@"):
- socket_addr = socket_addr.replace("@", "\0", 1)
-
- try:
- _socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
- _socket.connect(socket_addr)
- _status = _Status.FUNCTIONAL
- except Exception:
- _socket = None
- _status = _Status.FAILED
- logger.warning(f"Failed to connect to $NOTIFY_SOCKET at '{socket_addr}'", exc_info=True)
- return
-
- elif _status is _Status.FAILED:
+ try:
+ notify_socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+ notify_socket.connect(socket_addr)
+ except OSError:
+ logger.exception("Failed to connect to $NOTIFY_SOCKET at '%s'", socket_addr)
return
- if _status is _Status.FUNCTIONAL:
- assert _socket is not None
- payload = "\n".join((f"{key}={value}" for key, value in values.items()))
- try:
- _socket.send(payload.encode("utf8"))
- except Exception:
- logger.warning("Failed to send notification to systemd", exc_info=True)
- _status = _Status.FAILED
- _socket.close()
- _socket = None
+ payload = "\n".join((f"{key}={value}" for key, value in values.items()))
+ try:
+ notify_socket.send(payload.encode("utf8"))
+ except OSError:
+ logger.exception("Failed to send systemd notification to $NOTIFY_SOCKET at '%s'", socket_addr)
+
+ notify_socket.close()
@functools.lru_cache(maxsize=16)
def which(binary_name: str) -> Path:
"""
- Search $PATH and return the absolute path of that executable.
+ Get absolute path of an executable given name.
- The results of this function are LRU cached.
+ Searches in $PATH.
+ The result of the function is LRU cached.
+
+ Args:
+ binary_name (str): The name of the executable binary.
+
+ Returns:
+ Path: Absolute path of the executable.
+
+ Raises:
+ RuntimeError: If the executable was not found.
- If not found, throws an RuntimeError.
"""
possible_directories = os.get_exec_path()
for dr in possible_directories:
- p = Path(dr, binary_name)
- if p.exists():
- return p.absolute()
+ exec_path = Path(dr, binary_name)
+ if exec_path.exists():
+ return exec_path.absolute()
- raise RuntimeError(f"Executable {binary_name} was not found in $PATH")
+ msg = f"The executable '{binary_name}' was not found in $PATH"
+ raise RuntimeError(msg)