From: Michał Kępień Date: Wed, 20 Mar 2024 08:22:36 +0000 (+0100) Subject: Add an async DNS server for use in system tests X-Git-Tag: v9.19.23~23^2~1 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=6c010a5644324947c8c13b5600cd8d988ae7684f;p=thirdparty%2Fbind9.git Add an async DNS server for use in system tests Implement a new Python class, AsyncDnsServer, which can be used by ans.py scripts placed in ansX/ system test subdirectories. This enables conveniently starting a feature-limited, non-standards-compliant, custom DNS server instance. It can read and serve zone files, but it is also able to evaluate any user-provided query-processing logic, allowing query responses to be changed, delayed, or dropped altogether. These are all actions commonly taken by custom DNS servers written in Python that are used in BIND 9 system tests. Having a single "base" implementation of such a custom DNS server reduces code duplication, improving test maintainability. Co-authored-by: Tom Krizek --- diff --git a/bin/tests/system/conf.sh.in b/bin/tests/system/conf.sh.in index 7182db1f0a4..f09221ae5db 100644 --- a/bin/tests/system/conf.sh.in +++ b/bin/tests/system/conf.sh.in @@ -68,6 +68,8 @@ export KRB5_CONFIG=/dev/null # use local keytab instead of default /etc/krb5.keytab export KRB5_KTNAME=dns.keytab +export ANS_LOG_LEVEL=debug + # # Programs detected by configure # Variables will be empty if no program was found by configure diff --git a/bin/tests/system/isctest/asyncserver.py b/bin/tests/system/isctest/asyncserver.py new file mode 100644 index 00000000000..37eca8a672c --- /dev/null +++ b/bin/tests/system/isctest/asyncserver.py @@ -0,0 +1,799 @@ +""" +Copyright (C) Internet Systems Consortium, Inc. ("ISC") + +SPDX-License-Identifier: MPL-2.0 + +This Source Code Form is subject to the terms of the Mozilla Public +License, v. 2.0. If a copy of the MPL was not distributed with this +file, you can obtain one at https://mozilla.org/MPL/2.0/. + +See the COPYRIGHT file distributed with this work for additional +information regarding copyright ownership. +""" + +from dataclasses import dataclass, field +from typing import ( + Any, + AsyncGenerator, + Callable, + Coroutine, + List, + Optional, + Tuple, + Union, + cast, +) + +import abc +import asyncio +import enum +import functools +import logging +import os +import pathlib +import re +import signal +import struct +import sys + +import dns.flags +import dns.message +import dns.name +import dns.node +import dns.rcode +import dns.rdataclass +import dns.rdatatype +import dns.rrset +import dns.zone + +try: + RdataType = dns.rdatatype.RdataType + RdataClass = dns.rdataclass.RdataClass +except AttributeError: # dnspython < 2.0.0 compat + RdataType = int # type: ignore + RdataClass = int # type: ignore + + +_UdpHandler = Callable[ + [bytes, Tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None] +] + + +_TcpHandler = Callable[ + [asyncio.StreamReader, asyncio.StreamWriter], Coroutine[Any, Any, None] +] + + +class _AsyncUdpHandler(asyncio.DatagramProtocol): + """ + Protocol implementation for handling UDP traffic using asyncio. + """ + + def __init__( + self, + handler: _UdpHandler, + ) -> None: + self._transport: Optional[asyncio.DatagramTransport] = None + self._handler: _UdpHandler = handler + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + Called by asyncio when a connection is made. + """ + self._transport = cast(asyncio.DatagramTransport, transport) + + def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: + """ + Called by asyncio when a datagram is received. + """ + assert self._transport + handler_coroutine = self._handler(data, addr, self._transport) + try: + # Python >= 3.7 + asyncio.create_task(handler_coroutine) + except AttributeError: + # Python < 3.7 + loop = asyncio.get_event_loop() + loop.create_task(handler_coroutine) + + +# pylint: disable=too-few-public-methods +class AsyncServer: + """ + A generic asynchronous server which may handle UDP and/or TCP traffic. + + Once the server is executed as asyncio coroutine, it will keep running + until a SIGINT/SIGTERM signal is received. + """ + + def __init__( + self, + udp_handler: Optional[_UdpHandler], + tcp_handler: Optional[_TcpHandler], + pidfile: Optional[str] = None, + ) -> None: + logging.basicConfig( + format="%(asctime)s %(levelname)8s %(message)s", + level=os.environ.get("ANS_LOG_LEVEL", "INFO").upper(), + ) + try: + ipv4_address = sys.argv[1] + except IndexError: + ipv4_address = self._get_ipv4_address_from_directory_name() + + last_ipv4_address_octet = ipv4_address.split(".")[-1] + ipv6_address = f"fd92:7065:b8e:ffff::{last_ipv4_address_octet}" + + try: + port = int(sys.argv[2]) + except IndexError: + port = int(os.environ.get("PORT", 5300)) + + logging.info("Setting up IPv4 listener at %s:%d", ipv4_address, port) + logging.info("Setting up IPv6 listener at [%s]:%d", ipv6_address, port) + + self._ip_addresses: Tuple[str, str] = (ipv4_address, ipv6_address) + self._port: int = port + self._udp_handler: Optional[_UdpHandler] = udp_handler + self._tcp_handler: Optional[_TcpHandler] = tcp_handler + self._pidfile: Optional[str] = pidfile + self._work_done: Optional[asyncio.Future] = None + + def _get_ipv4_address_from_directory_name(self) -> str: + containing_directory = pathlib.Path().absolute().stem + match_result = re.match(r"ans(?P\d+)", containing_directory) + if not match_result: + raise RuntimeError("Unable to auto-determine the IPv4 address to use") + + return f"10.53.0.{match_result.group('index')}" + + def run(self) -> None: + """ + Start the server in an asynchronous coroutine. + """ + coroutine = self._run + try: + # Python >= 3.7 + asyncio.run(coroutine()) + except AttributeError: + # Python < 3.7 + loop = asyncio.get_event_loop() + loop.run_until_complete(coroutine()) + + async def _run(self) -> None: + self._setup_signals() + assert self._work_done + await self._listen_udp() + await self._listen_tcp() + self._write_pidfile() + await self._work_done + self._cleanup_pidfile() + + def _get_asyncio_loop(self) -> asyncio.AbstractEventLoop: + try: + # Python >= 3.7 + loop = asyncio.get_running_loop() + except AttributeError: + # Python < 3.7 + loop = asyncio.get_event_loop() + return loop + + def _setup_signals(self) -> None: + loop = self._get_asyncio_loop() + self._work_done = loop.create_future() + loop.add_signal_handler(signal.SIGINT, functools.partial(self._signal_done)) + loop.add_signal_handler(signal.SIGTERM, functools.partial(self._signal_done)) + + def _signal_done(self) -> None: + assert self._work_done + self._work_done.set_result(True) + + async def _listen_udp(self) -> None: + if not self._udp_handler: + return + loop = self._get_asyncio_loop() + for ip_address in self._ip_addresses: + await loop.create_datagram_endpoint( + lambda: _AsyncUdpHandler(cast(_UdpHandler, self._udp_handler)), + (ip_address, self._port), + ) + + async def _listen_tcp(self) -> None: + if not self._tcp_handler: + return + for ip_address in self._ip_addresses: + await asyncio.start_server( + self._tcp_handler, host=ip_address, port=self._port + ) + + def _write_pidfile(self) -> None: + if not self._pidfile: + return + logging.info("Writing PID to %s", self._pidfile) + with open(self._pidfile, "w", encoding="ascii") as pidfile: + print(f"{os.getpid()}", file=pidfile) + + def _cleanup_pidfile(self) -> None: + if not self._pidfile: + return + logging.info("Removing %s", self._pidfile) + os.unlink(self._pidfile) + + +class DnsProtocol(enum.Enum): + UDP = enum.auto() + TCP = enum.auto() + + +# pylint: disable=too-many-instance-attributes +@dataclass +class QueryContext: + """ + Context for the incoming query which may be used for preparing the response. + """ + + query: dns.message.Message + response: dns.message.Message + peer: Tuple[str, int] + protocol: DnsProtocol + zone: Optional[dns.zone.Zone] = None + soa: Optional[dns.rrset.RRset] = None + node: Optional[dns.node.Node] = None + answer: Optional[dns.rdataset.Rdataset] = None + + @property + def qname(self) -> dns.name.Name: + return self.query.question[0].name + + @property + def qclass(self) -> RdataClass: + return self.query.question[0].rdclass + + @property + def qtype(self) -> RdataType: + return self.query.question[0].rdtype + + +@dataclass +class ResponseAction(abc.ABC): + """ + Base class for actions that can be taken in response to a query. + """ + + @abc.abstractmethod + async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + """ + This method is expected to carry out arbitrary actions (e.g. wait for a + specific amount of time, modify the answer, etc.) and then return the + DNS response to send (a dns.message.Message, a raw bytes object, or + None, which prevents any response from being sent). + """ + raise NotImplementedError + + +@dataclass +class DnsResponseSend(ResponseAction): + """ + Action which yields a dns.message.Message response. + + The response may be sent with a delay if requested. + + Depending on the value of the `authoritative` property, this class may set + the AA bit in the response (True), clear it (False), or not touch it at all + (None). + """ + + response: dns.message.Message + authoritative: Optional[bool] = None + delay: float = 0.0 + + async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + """ + Yield a potentially delayed response that is a dns.message.Message. + """ + assert isinstance(self.response, dns.message.Message) + if self.authoritative is not None: + if self.authoritative: + self.response.flags |= dns.flags.AA + else: + self.response.flags &= ~dns.flags.AA + if self.delay > 0: + logging.info( + "Delaying response (ID=%d) by %d ms", + self.response.id, + self.delay * 1000, + ) + await asyncio.sleep(self.delay) + return self.response + + +@dataclass +class BytesResponseSend(ResponseAction): + """ + Action which yields a raw response that is a sequence of bytes. + + The response may be sent with a delay if requested. + """ + + response: bytes + delay: float = 0.0 + + async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + """ + Yield a potentially delayed response that is a sequence of bytes. + """ + assert isinstance(self.response, bytes) + if self.delay > 0: + logging.info("Delaying raw response by %d ms", self.delay * 1000) + await asyncio.sleep(self.delay) + return self.response + + +@dataclass +class ResponseDrop(ResponseAction): + """ + Action which does nothing - as if a packet was dropped. + """ + + async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + return None + + +class ResponseHandler(abc.ABC): + """ + Base class for generic response handlers. + + If a query passes the `match()` function logic, then it is handled by this + response handler and response(s) may be generated by the `get_responses()` + method. + """ + + @abc.abstractmethod + def match(self, qctx: QueryContext) -> bool: + """ + Matching logic - query is handled when it returns True. + """ + return True + + @abc.abstractmethod + async def get_responses( + self, qctx: QueryContext + ) -> AsyncGenerator[ResponseAction, None]: + """ + Custom handler which may produce response(s) to matching queries. + + The response prepared from zone data is passed to this method in + qctx.response. + """ + yield DnsResponseSend(qctx.response) + + +class DomainHandler(ResponseHandler): + """ + Base class used for deriving custom domain handlers. + + The derived class must specify a list of `domains` that it wants to handle. + Queries for any of these domains (and their subdomains) will then be passed + to the `get_response()` method in the derived class. + """ + + @property + @abc.abstractmethod + def domains(self) -> List[str]: + """ + A list of domain names handled by this class. + """ + raise NotImplementedError + + def __init__(self) -> None: + self._domains: List[dns.name.Name] = [ + dns.name.from_text(d) for d in self.domains + ] + + def __str__(self) -> str: + return f"{self.__class__.__name__}(domains: {', '.join(self.domains)})" + + def match(self, qctx: QueryContext) -> bool: + """ + Handle queries whose QNAME matches any of the domains handled by this + class. + """ + for domain in self._domains: + if qctx.qname.is_subdomain(domain): + return True + return False + + +@dataclass +class _ZoneTreeNode: + """ + A node representing a zone with one origin. + """ + + zone: Optional[dns.zone.Zone] + children: List["_ZoneTreeNode"] = field(default_factory=list) + + +class _ZoneTree: + """ + Tree with independent zones. + + This zone tree is used as a backing structure for the DNS server. The + individual zones are independent to allow the (single) server to serve both + the parent zone and a child zone if needed. + """ + + def __init__(self) -> None: + self._root: _ZoneTreeNode = _ZoneTreeNode(None) + + def add(self, zone: dns.zone.Zone) -> None: + """ + Add a zone to the tree and rearrange sub-zones if necessary. + """ + assert zone.origin + best_match = self._find_best_match(zone.origin, self._root) + added_node = _ZoneTreeNode(zone) + self._move_children(best_match, added_node) + best_match.children.append(added_node) + + def _find_best_match( + self, name: dns.name.Name, start_node: _ZoneTreeNode + ) -> _ZoneTreeNode: + for child in start_node.children: + assert child.zone + assert child.zone.origin + if name.is_subdomain(child.zone.origin): + return self._find_best_match(name, child) + return start_node + + def _move_children(self, node_from: _ZoneTreeNode, node_to: _ZoneTreeNode) -> None: + assert node_to.zone + assert node_to.zone.origin + + children_to_move = [] + for child in node_from.children: + assert child.zone + assert child.zone.origin + if child.zone.origin.is_subdomain(node_to.zone.origin): + children_to_move.append(child) + + for child in children_to_move: + node_from.children.remove(child) + node_to.children.append(child) + + def find_best_zone(self, name: dns.name.Name) -> Optional[dns.zone.Zone]: + """ + Return the closest matching zone (if any) for the domain name. + """ + node = self._find_best_match(name, self._root) + return node.zone if node != self._root else None + + +class AsyncDnsServer(AsyncServer): + """ + DNS server which responds to queries based on zone data and/or custom + handlers. + + The server may use custom handlers which allow arbitrary query processing. + These don't need to be standards-compliant and can be used for testing all + sorts of scenarios, including delaying responses, synthesizing them based + on query contents etc. + + The server also loads any zone files (*.db) found in its directory and + serves them. Responses prepared using zone data can then be modified, + replaced, or suppressed by query handlers. Query handlers can also generate + response from scratch, without using zone data at all. + """ + + def __init__(self, load_zones: bool = True): + super().__init__(self._handle_udp, self._handle_tcp, "ans.pid") + + self._zone_tree: _ZoneTree = _ZoneTree() + self._response_handlers: List[ResponseHandler] = [] + + if load_zones: + self._load_zones() + + def install_response_handler(self, handler: ResponseHandler) -> None: + """ + Add a response handler which will be used to handle matching queries. + + Response handlers can modify, replace, or suppress the answers prepared + from zone file contents. + """ + logging.info("Installing response handler: %s", handler) + self._response_handlers.append(handler) + + def _load_zones(self) -> None: + for entry in os.scandir(): + entry_path = pathlib.Path(entry.path) + if entry_path.suffix != ".db": + continue + origin = dns.name.from_text(entry_path.stem) + logging.info("Loading zone file %s", entry_path) + zone = dns.zone.from_file(entry.path, origin, relativize=False) + self._zone_tree.add(zone) + + async def _handle_udp( + self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport + ) -> None: + logging.debug("Received UDP message: %s", wire.hex()) + responses = self._handle_query(wire, peer, DnsProtocol.UDP) + async for response in responses: + transport.sendto(response, peer) + + async def _handle_tcp( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + wire_length_bytes = await reader.read(2) + (wire_length,) = struct.unpack("!H", wire_length_bytes) + logging.debug("Receiving TCP message (%d octets)...", wire_length) + + wire = await reader.read(wire_length) + full_message = wire_length_bytes + wire + logging.debug("Received complete TCP message: %s", full_message.hex()) + + peer = writer.get_extra_info("peername") + responses = self._handle_query(wire, peer, DnsProtocol.TCP) + async for response in responses: + writer.write(response) + try: + await writer.drain() + except ConnectionResetError: + logging.error( + "TCP connection from %s reset by peer", self._format_peer(peer) + ) + return + + writer.close() + await writer.wait_closed() + + def _format_peer(self, peer: Tuple[str, int]) -> str: + host = peer[0] + port = peer[1] + if "::" in host: + host = f"[{host}]" + return f"{host}:{port}" + + def _log_query( + self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol + ) -> None: + logging.info( + "Received %s/%s/%s (ID=%d) query from %s (%s)", + qctx.qname.to_text(omit_final_dot=True), + dns.rdataclass.to_text(qctx.qclass), + dns.rdatatype.to_text(qctx.qtype), + qctx.query.id, + self._format_peer(peer), + protocol.name, + ) + logging.debug( + "\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()]) + ) + + def _log_response( + self, + qctx: QueryContext, + response: Optional[Union[dns.message.Message, bytes]], + peer: Tuple[str, int], + protocol: DnsProtocol, + ) -> None: + if not response: + logging.info( + "Not sending a response to query (ID=%d) from %s (%s)", + qctx.query.id, + self._format_peer(peer), + protocol.name, + ) + return + + if isinstance(response, dns.message.Message): + try: + qname = response.question[0].name.to_text(omit_final_dot=True) + qclass = dns.rdataclass.to_text(response.question[0].rdclass) + qtype = dns.rdatatype.to_text(response.question[0].rdtype) + except IndexError: + qname = "" + qclass = "-" + qtype = "-" + + logging.info( + "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s (%s)", + qname, + qclass, + qtype, + response.id, + len(response.question), + len(response.answer), + len(response.authority), + len(response.additional), + qctx.query.id, + self._format_peer(peer), + protocol.name, + ) + logging.debug( + "\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()]) + ) + return + + logging.info( + "Sending response (%d bytes) to a query (ID=%d) from %s (%s)", + len(response), + qctx.query.id, + self._format_peer(peer), + protocol.name, + ) + logging.debug("[OUT] %s", response.hex()) + + async def _handle_query( + self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol + ) -> AsyncGenerator[bytes, None]: + """ + Yield wire data to send as a response over the established transport. + """ + query = dns.message.from_wire(wire) + response_stub = dns.message.make_response(query) + qctx = QueryContext(query, response_stub, peer, protocol) + self._log_query(qctx, peer, protocol) + responses = self._prepare_responses(qctx) + async for response in responses: + self._log_response(qctx, response, peer, protocol) + if response: + if isinstance(response, dns.message.Message): + response = response.to_wire(max_size=65535) + if protocol == DnsProtocol.UDP: + yield response + else: + response_length = struct.pack("!H", len(response)) + yield response_length + response + + async def _prepare_responses( + self, qctx: QueryContext + ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]: + """ + Yield response(s) either from response handlers or zone data. + """ + self._prepare_response_from_zone_data(qctx) + + response_handled = False + async for action in self._run_response_handlers(qctx): + yield await action.perform() + response_handled = True + + if not response_handled: + yield qctx.response + + def _prepare_response_from_zone_data(self, qctx: QueryContext) -> None: + """ + Prepare a response to the query based on the available zone data. + + The functionality is split across smaller functions that modify the + query context until a proper response is formed. + """ + if self._refused_response(qctx): + return + + if self._delegation_response(qctx): + return + + qctx.response.flags |= dns.flags.AA + + if self._ent_response(qctx): + return + + if self._nxdomain_response(qctx): + return + + if self._nodata_response(qctx): + return + + self._noerror_response(qctx) + + def _refused_response(self, qctx: QueryContext) -> bool: + qctx.zone = self._zone_tree.find_best_zone(qctx.qname) + if qctx.zone: + return False + + qctx.response.set_rcode(dns.rcode.REFUSED) + return True + + def _delegation_response(self, qctx: QueryContext) -> bool: + assert qctx.zone + + name = qctx.qname + delegation = None + + while name != qctx.zone.origin: + node = qctx.zone.get_node(name) + if node: + delegation = node.get_rdataset(qctx.qclass, dns.rdatatype.NS) + if delegation: + break + name = name.parent() + + if not delegation: + return False + + delegation_rrset = dns.rrset.RRset(name, qctx.qclass, dns.rdatatype.NS) + delegation_rrset.update(delegation) + + qctx.response.set_rcode(dns.rcode.NOERROR) + qctx.response.authority.append(delegation_rrset) + + self._delegation_response_additional(qctx) + + return True + + def _delegation_response_additional(self, qctx: QueryContext) -> None: + assert qctx.zone + assert qctx.response.authority[0] + + for nameserver in qctx.response.authority[0]: + if not nameserver.target.is_subdomain(qctx.response.authority[0].name): + continue + glue_a = qctx.zone.get_rrset(nameserver.target, dns.rdatatype.A) + if glue_a: + qctx.response.additional.append(glue_a) + glue_aaaa = qctx.zone.get_rrset(nameserver.target, dns.rdatatype.AAAA) + if glue_aaaa: + qctx.response.additional.append(glue_aaaa) + + def _ent_response(self, qctx: QueryContext) -> bool: + assert qctx.zone + assert qctx.zone.origin + + qctx.soa = qctx.zone.find_rrset(qctx.zone.origin, dns.rdatatype.SOA) + assert qctx.soa + + qctx.node = qctx.zone.get_node(qctx.qname) + if qctx.node or not any( + n for n in qctx.zone.nodes if n.is_subdomain(qctx.qname) + ): + return False + + qctx.response.set_rcode(dns.rcode.NOERROR) + qctx.response.authority.append(qctx.soa) + return True + + def _nxdomain_response(self, qctx: QueryContext) -> bool: + assert qctx.soa + + if qctx.node: + return False + + qctx.response.set_rcode(dns.rcode.NXDOMAIN) + qctx.response.authority.append(qctx.soa) + return True + + def _nodata_response(self, qctx: QueryContext) -> bool: + assert qctx.node + assert qctx.soa + + qctx.answer = qctx.node.get_rdataset(qctx.qclass, qctx.qtype) + if qctx.answer: + return False + + qctx.response.set_rcode(dns.rcode.NOERROR) + qctx.response.authority.append(qctx.soa) + return True + + def _noerror_response(self, qctx: QueryContext) -> None: + assert qctx.answer + + answer_rrset = dns.rrset.RRset(qctx.qname, qctx.qclass, qctx.qtype) + answer_rrset.update(qctx.answer) + + qctx.response.set_rcode(dns.rcode.NOERROR) + qctx.response.answer.append(answer_rrset) + + async def _run_response_handlers( + self, qctx: QueryContext + ) -> AsyncGenerator[ResponseAction, None]: + """ + Yield response(s) to the query from a matching query handler. + """ + for handler in self._response_handlers: + if handler.match(qctx): + async for response in handler.get_responses(qctx): + yield response + return diff --git a/bin/tests/system/start.pl b/bin/tests/system/start.pl index 32d20741688..67c33b86f24 100755 --- a/bin/tests/system/start.pl +++ b/bin/tests/system/start.pl @@ -323,6 +323,7 @@ sub construct_ans_command { } if (-e "$testdir/$server/ans.py") { + $ENV{'PYTHONPATH'} = $testdir . ":" . $ENV{'srcdir'}; $command = "$PYTHON -u ans.py 10.53.0.$n $queryport"; } elsif (-e "$testdir/$server/ans.pl") { $command = "$PERL ans.pl";