From: Michał Kępień Date: Fri, 17 Apr 2026 15:57:05 +0000 (+0200) Subject: Sync asyncserver.py with the development branch X-Git-Tag: v9.18.49~6^2~2 X-Git-Url: http://git.ipfire.org/gitweb/?a=commitdiff_plain;h=b0e8966647e744482edc06e48bc9ff5079a1c541;p=thirdparty%2Fbind9.git Sync asyncserver.py with the development branch Import bin/tests/system/isctest/asyncserver.py as present in commit ced002c4ab7b920c9528d315a611a477cb4a9409 on the "main" branch. This enables using newer asyncserver.py infrastructure code in system tests that need to be backported to maintenance branches. --- diff --git a/bin/tests/system/isctest/asyncserver.py b/bin/tests/system/isctest/asyncserver.py index d35710ba5d4..080c08c3806 100644 --- a/bin/tests/system/isctest/asyncserver.py +++ b/bin/tests/system/isctest/asyncserver.py @@ -11,20 +11,9 @@ See the COPYRIGHT file distributed with this work for additional information regarding copyright ownership. """ +from collections.abc import AsyncGenerator, Callable, Coroutine, Sequence from dataclasses import dataclass, field -from typing import ( - Any, - AsyncGenerator, - Callable, - Coroutine, - Dict, - List, - Optional, - Set, - Tuple, - Union, - cast, -) +from typing import Any, cast import abc import asyncio @@ -52,11 +41,10 @@ import dns.rdataset import dns.rdatatype import dns.rrset import dns.tsig -import dns.version import dns.zone _UdpHandler = Callable[ - [bytes, Tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None] + [bytes, tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None] ] @@ -74,7 +62,7 @@ class _AsyncUdpHandler(asyncio.DatagramProtocol): self, handler: _UdpHandler, ) -> None: - self._transport: Optional[asyncio.DatagramTransport] = None + self._transport: asyncio.DatagramTransport | None = None self._handler: _UdpHandler = handler def connection_made(self, transport: asyncio.BaseTransport) -> None: @@ -83,7 +71,7 @@ class _AsyncUdpHandler(asyncio.DatagramProtocol): """ self._transport = cast(asyncio.DatagramTransport, transport) - def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None: + def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: """ Called by asyncio when a datagram is received. """ @@ -108,9 +96,9 @@ class AsyncServer: def __init__( self, - udp_handler: Optional[_UdpHandler], - tcp_handler: Optional[_TcpHandler], - pidfile: Optional[str] = None, + udp_handler: _UdpHandler | None, + tcp_handler: _TcpHandler | None, + pidfile: str | None = None, ) -> None: logging.basicConfig( format="%(asctime)s %(levelname)8s %(message)s", @@ -132,12 +120,12 @@ class AsyncServer: logging.info("Setting up IPv4 listener at %s:%d", ipv4_address, port) logging.info("Setting up IPv6 listener at [%s]:%d", ipv6_address, port) - self._ip_addresses: Tuple[str, str] = (ipv4_address, ipv6_address) + self._ip_addresses: tuple[str, str] = (ipv4_address, ipv6_address) self._port: int = port - self._udp_handler: Optional[_UdpHandler] = udp_handler - self._tcp_handler: Optional[_TcpHandler] = tcp_handler - self._pidfile: Optional[str] = pidfile - self._work_done: Optional[asyncio.Future] = None + self._udp_handler: _UdpHandler | None = udp_handler + self._tcp_handler: _TcpHandler | None = tcp_handler + self._pidfile: str | None = pidfile + self._work_done: asyncio.Future | None = None def _get_ipv4_address_from_directory_name(self) -> str: containing_directory = pathlib.Path().absolute().stem @@ -185,7 +173,7 @@ class AsyncServer: loop.set_exception_handler(self._handle_exception) def _handle_exception( - self, _: asyncio.AbstractEventLoop, context: Dict[str, Any] + self, _: asyncio.AbstractEventLoop, context: dict[str, Any] ) -> None: assert self._work_done exception = context.get("exception", RuntimeError(context["message"])) @@ -265,17 +253,16 @@ class QueryContext: query: dns.message.Message response: dns.message.Message + socket: Peer peer: Peer protocol: DnsProtocol - zone: Optional[dns.zone.Zone] = field(default=None, init=False) - soa: Optional[dns.rrset.RRset] = field(default=None, init=False) - node: Optional[dns.node.Node] = field(default=None, init=False) - answer: Optional[dns.rdataset.Rdataset] = field(default=None, init=False) - alias: Optional[dns.name.Name] = field(default=None, init=False) - _initialized_response: Optional[dns.message.Message] = field( - default=None, init=False - ) - _initialized_response_with_zone_data: Optional[dns.message.Message] = field( + zone: dns.zone.Zone | None = field(default=None, init=False) + soa: dns.rrset.RRset | None = field(default=None, init=False) + node: dns.node.Node | None = field(default=None, init=False) + answer: dns.rdataset.Rdataset | None = field(default=None, init=False) + alias: dns.name.Name | None = field(default=None, init=False) + _initialized_response: dns.message.Message | None = field(default=None, init=False) + _initialized_response_with_zone_data: dns.message.Message | None = field( default=None, init=False ) @@ -320,7 +307,7 @@ class ResponseAction(abc.ABC): """ @abc.abstractmethod - async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + async def perform(self) -> dns.message.Message | bytes | None: """ This method is expected to carry out arbitrary actions (e.g. wait for a specific amount of time, modify the answer, etc.) and then return the @@ -343,14 +330,30 @@ class DnsResponseSend(ResponseAction): """ response: dns.message.Message - authoritative: Optional[bool] = None + authoritative: bool | None = None delay: float = 0.0 + acknowledge_hand_rolled_response: bool = False - async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + async def perform(self) -> dns.message.Message | bytes | None: """ Yield a potentially delayed response that is a dns.message.Message. """ assert isinstance(self.response, dns.message.Message) + if not ( + _is_asyncserver_response(self.response) + or self.acknowledge_hand_rolled_response + ): + error = "The response you are trying to send was not created using " + error += "AsyncDnsServer's response preparation methods. " + error += "This will break features such as automatic AA flag " + error += "and RCODE handling. If you need a fresh copy of a " + error += "response, use `QueryContext.prepare_new_response` " + error += "instead of `dns.message.make_response`. " + error += "To acknowledge this and proceed anyway, set " + error += "`acknowledge_hand_rolled_response=True` in " + error += "DnsResponseSend's constructor." + raise RuntimeError(error) + if self.authoritative is not None: if self.authoritative: self.response.flags |= dns.flags.AA @@ -377,7 +380,7 @@ class BytesResponseSend(ResponseAction): response: bytes delay: float = 0.0 - async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + async def perform(self) -> dns.message.Message | bytes | None: """ Yield a potentially delayed response that is a sequence of bytes. """ @@ -394,7 +397,7 @@ class ResponseDrop(ResponseAction): Action which does nothing - as if a packet was dropped. """ - async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + async def perform(self) -> dns.message.Message | bytes | None: return None @@ -403,17 +406,16 @@ class _ConnectionTeardownRequested(Exception): @dataclass -class ResponseDropAndCloseConnection(ResponseAction): +class CloseConnection(ResponseAction): """ - Action which makes the server close the connection after the DNS query is - received by the server (TCP only). + Action which makes the server close the connection (TCP only). The connection may be closed with a delay if requested. """ delay: float = 0.0 - async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + async def perform(self) -> dns.message.Message | bytes | None: if self.delay > 0: logging.info("Waiting %.1fs before closing TCP connection", self.delay) await asyncio.sleep(self.delay) @@ -495,7 +497,7 @@ class IgnoreAllConnections(ConnectionHandler): client socket, effectively ignoring all incoming connections. """ - _connections: Set[asyncio.StreamWriter] = field(default_factory=set) + _connections: set[asyncio.StreamWriter] = field(default_factory=set) async def handle( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer @@ -529,8 +531,8 @@ class ConnectionReset(ConnectionHandler): make the server send an RST segment; this happens when the server closes a client's socket while there is still unread data in that socket's buffer. If closing the connection _after_ the query is read by the server is enough - for a given use case, the ResponseDropAndCloseConnection response handler - should be used instead. + for a given use case, the CloseConnection response handler should be used + instead. """ delay: float = 0.0 @@ -606,14 +608,14 @@ class QnameHandler(ResponseHandler): @property @abc.abstractmethod - def qnames(self) -> List[str]: + def qnames(self) -> list[str]: """ A list of QNAMEs handled by this class. """ raise NotImplementedError def __init__(self) -> None: - self._qnames: List[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames] + self._qnames: list[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames] def __str__(self) -> str: return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)})" @@ -626,6 +628,105 @@ class QnameHandler(ResponseHandler): return qctx.qname in self._qnames +class QnameQtypeHandler(QnameHandler): + """ + Handle queries for which both of the following conditions are true: + + - the query's QNAME is present in `self.qnames`, + - the query's QTYPE is present in `self.qtypes`. + """ + + @property + @abc.abstractmethod + def qtypes(self) -> list[dns.rdatatype.RdataType]: + """ + A list of QTYPEs handled by this class. + """ + raise NotImplementedError + + def __init__(self) -> None: + super().__init__() + self._qtypes: list[dns.rdatatype.RdataType] = self.qtypes + + def __str__(self) -> str: + return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)}; QTYPEs: {', '.join(map(str, self.qtypes))})" + + def match(self, qctx: QueryContext) -> bool: + """ + Handle queries whose QNAME and QTYPE match any of the QNAMEs and + QTYPEs handled by this class. + """ + return qctx.qtype in self._qtypes and super().match(qctx) + + +class StaticResponseHandler(ResponseHandler): + """ + Base class used for deriving custom static response handlers. + + The derived class can specify the RRsets to be included in the answer, + authority, and additional sections of the response, whether to set the AA + bit in the response, and a delay before sending the response. + + The default implementation of `get_responses()` uses these properties to + prepare and yield a single response. + """ + + @property + def rcode(self) -> dns.rcode.Rcode | None: + """ + Optional RCODE to be set in the response. + """ + return None + + @property + def answer(self) -> Sequence[dns.rrset.RRset]: + """ + RRsets to be included in the answer section of the response. + """ + return [] + + @property + def authority(self) -> Sequence[dns.rrset.RRset]: + """ + RRsets to be included in the authority section of the response. + """ + return [] + + @property + def additional(self) -> Sequence[dns.rrset.RRset]: + """ + RRsets to be included in the additional section of the response. + """ + return [] + + @property + def authoritative(self) -> bool | None: + """ + Whether to set the AA bit in the response. + """ + return None + + @property + def delay(self) -> float: + """ + Delay before sending the response. + """ + return 0.0 + + async def get_responses( + self, qctx: QueryContext + ) -> AsyncGenerator[DnsResponseSend, None]: + qctx.prepare_new_response(with_zone_data=False) + qctx.response.answer.extend(self.answer) + qctx.response.authority.extend(self.authority) + qctx.response.additional.extend(self.additional) + if self.rcode is not None: + qctx.response.set_rcode(self.rcode) + yield DnsResponseSend( + qctx.response, authoritative=self.authoritative, delay=self.delay + ) + + class DomainHandler(ResponseHandler): """ Base class used for deriving custom domain handlers. @@ -633,20 +734,28 @@ class DomainHandler(ResponseHandler): The derived class must specify a list of `domains` that it wants to handle. Queries for any of these domains (and their subdomains) will then be passed to the `get_response()` method in the derived class. + + The most specific matching domain is stored in the `matched_domain` attribute. """ @property @abc.abstractmethod - def domains(self) -> List[str]: + def domains(self) -> list[str]: """ A list of domain names handled by this class. """ raise NotImplementedError def __init__(self) -> None: - self._domains: List[dns.name.Name] = [ - dns.name.from_text(d) for d in self.domains - ] + self._domains: list[dns.name.Name] = sorted( + [dns.name.from_text(d) for d in self.domains], reverse=True + ) + self._matched_domain: dns.name.Name | None = None + + @property + def matched_domain(self) -> dns.name.Name: + assert self._matched_domain is not None + return self._matched_domain def __str__(self) -> str: return f"{self.__class__.__name__}(domains: {', '.join(self.domains)})" @@ -656,20 +765,124 @@ class DomainHandler(ResponseHandler): Handle queries whose QNAME matches any of the domains handled by this class. """ + self._matched_domain = None for domain in self._domains: if qctx.qname.is_subdomain(domain): + self._matched_domain = domain return True return False +class ForwarderHandler(ResponseHandler): + """ + A handler forwarding all received queries to another DNS server with an + optional delay and then relaying the responses back to the original client. + + Queries are currently always forwarded via UDP. + """ + + @property + @abc.abstractmethod + def target(self) -> str: + """ + The address of the DNS server to forward queries to. + """ + raise NotImplementedError + + @property + def port(self) -> int: + """ + The port of the DNS server to forward queries to. + + The default value of 0 causes the same port as the one used by this + server for listening to be used. + """ + return 0 + + @property + def delay(self) -> float: + """ + The number of seconds to wait before forwarding each query. + """ + return 0.0 + + def __str__(self) -> str: + return f"{self.__class__.__name__}(target: {self.target}:{self.port})" + + class ForwarderProtocol(asyncio.DatagramProtocol): + def __init__(self, query: bytes, response: asyncio.Future) -> None: + self._query = query + self._response = response + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + logging.debug("[OUT] %s", self._query.hex()) + cast(asyncio.DatagramTransport, transport).sendto(self._query) + + def datagram_received(self, data: bytes, _: tuple[str, int]) -> None: + logging.debug("[IN] %s", data.hex()) + self._response.set_result(data) + + async def get_responses( + self, qctx: QueryContext + ) -> AsyncGenerator[ResponseAction, None]: + loop = asyncio.get_running_loop() + response = loop.create_future() + forwarding_target = f"{self.target}:{self.port or qctx.socket.port}" + + if self.delay > 0: + logging.info( + "Waiting %.1fs before forwarding %s query from %s to %s over UDP", + self.delay, + qctx.protocol.name, + qctx.peer, + forwarding_target, + ) + await asyncio.sleep(self.delay) + + logging.info( + "Forwarding %s query from %s to %s over UDP", + qctx.protocol.name, + qctx.peer, + forwarding_target, + ) + + transport, _ = await loop.create_datagram_endpoint( + lambda: self.ForwarderProtocol(qctx.query.to_wire(), response), + local_addr=(qctx.socket.host, 0), + remote_addr=(self.target, self.port or qctx.socket.port), + ) + + try: + await response + finally: + transport.close() + + logging.info( + "Relaying UDP response from %s to %s over %s", + forwarding_target, + qctx.peer, + qctx.protocol.name, + ) + + try: + message = _DnsMessageWithTsigDisabled.from_wire(response.result()) + yield DnsResponseSend(message, acknowledge_hand_rolled_response=True) + except dns.exception.DNSException: + logging.warning( + "Failed to parse response from %s as a DNS message, relaying it as raw bytes", + forwarding_target, + ) + yield BytesResponseSend(response.result()) + + @dataclass class _ZoneTreeNode: """ A node representing a zone with one origin. """ - zone: Optional[dns.zone.Zone] - children: List["_ZoneTreeNode"] = field(default_factory=list) + zone: dns.zone.Zone | None + children: list["_ZoneTreeNode"] = field(default_factory=list) class _ZoneTree: @@ -719,7 +932,7 @@ class _ZoneTree: node_from.children.remove(child) node_to.children.append(child) - def find_best_zone(self, name: dns.name.Name) -> Optional[dns.zone.Zone]: + def find_best_zone(self, name: dns.name.Name) -> dns.zone.Zone | None: """ Return the closest matching zone (if any) for the domain name. """ @@ -737,7 +950,7 @@ class _DnsMessageWithTsigDisabled(dns.message.Message): """ class _DisableTsigHandling(contextlib.ContextDecorator): - def __init__(self, message: Optional[dns.message.Message] = None) -> None: + def __init__(self, message: dns.message.Message | None = None) -> None: self.original_tsig_sign = dns.tsig.sign self.original_tsig_validate = dns.tsig.validate if message: @@ -749,7 +962,7 @@ class _DnsMessageWithTsigDisabled(dns.message.Message): from failing on messages initialized with `dns.message.from_wire(keyring=False)`. """ - def sign(*_: Any, **__: Any) -> Tuple[dns.rdata.Rdata, None]: + def sign(*_: Any, **__: Any) -> tuple[dns.rdata.Rdata, None]: assert self.tsig return self.tsig[0], None @@ -792,6 +1005,19 @@ class _NoKeyringType: pass +_ASYNCSERVER_RESPONSE_MARKER = "__is_asyncserver_response__" + + +def _make_asyncserver_response(query: dns.message.Message) -> dns.message.Message: + response = dns.message.make_response(query) + setattr(response, _ASYNCSERVER_RESPONSE_MARKER, True) + return response + + +def _is_asyncserver_response(message: dns.message.Message) -> bool: + return getattr(message, _ASYNCSERVER_RESPONSE_MARKER, False) + + class AsyncDnsServer(AsyncServer): """ DNS server which responds to queries based on zone data and/or custom @@ -812,17 +1038,17 @@ class AsyncDnsServer(AsyncServer): self, /, default_rcode: dns.rcode.Rcode = dns.rcode.REFUSED, - default_aa: bool = True, - keyring: Union[ - Dict[dns.name.Name, dns.tsig.Key], None, _NoKeyringType - ] = _NoKeyringType(), + default_aa: bool = False, + keyring: ( + dict[dns.name.Name, dns.tsig.Key] | None | _NoKeyringType + ) = _NoKeyringType(), acknowledge_manual_dname_handling: bool = False, ) -> None: super().__init__(self._handle_udp, self._handle_tcp, "ans.pid") self._zone_tree: _ZoneTree = _ZoneTree() - self._connection_handler: Optional[ConnectionHandler] = None - self._response_handlers: List[ResponseHandler] = [] + self._connection_handler: ConnectionHandler | None = None + self._response_handlers: list[ResponseHandler] = [] self._default_rcode = default_rcode self._default_aa = default_aa self._keyring = keyring @@ -849,10 +1075,18 @@ class AsyncDnsServer(AsyncServer): else: self._response_handlers.append(handler) - def install_response_handlers(self, handlers: List[ResponseHandler]) -> None: + def install_response_handlers(self, *handlers: ResponseHandler) -> None: for handler in handlers: self.install_response_handler(handler) + def replace_response_handlers(self, *new_handlers: ResponseHandler) -> None: + """ + Uninstall all currently installed handlers and install the provided ones. + """ + logging.info("Uninstalling response handlers: %s", str(self._response_handlers)) + self._response_handlers.clear() + self.install_response_handlers(*new_handlers) + def uninstall_response_handler(self, handler: ResponseHandler) -> None: """ Remove the specified handler from the list of response handlers. @@ -923,11 +1157,13 @@ class AsyncDnsServer(AsyncServer): raise ValueError(error) async def _handle_udp( - self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport + self, wire: bytes, addr: tuple[str, int], transport: asyncio.DatagramTransport ) -> None: logging.debug("Received UDP message: %s", wire.hex()) + socket_info = transport.get_extra_info("sockname") + socket = Peer(socket_info[0], socket_info[1]) peer = Peer(addr[0], addr[1]) - responses = self._handle_query(wire, peer, DnsProtocol.UDP) + responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP) async for response in responses: logging.debug("Sending UDP message: %s", response.hex()) transport.sendto(response, addr) @@ -964,7 +1200,7 @@ class AsyncDnsServer(AsyncServer): async def _read_tcp_query( self, reader: asyncio.StreamReader, peer: Peer - ) -> Optional[bytes]: + ) -> bytes | None: wire_length = await self._read_tcp_query_wire_length(reader, peer) if not wire_length: return None @@ -973,7 +1209,7 @@ class AsyncDnsServer(AsyncServer): async def _read_tcp_query_wire_length( self, reader: asyncio.StreamReader, peer: Peer - ) -> Optional[int]: + ) -> int | None: logging.debug("Receiving TCP message length from %s...", peer) wire_length_bytes = await self._read_tcp_octets(reader, peer, 2) @@ -986,7 +1222,7 @@ class AsyncDnsServer(AsyncServer): async def _read_tcp_query_wire( self, reader: asyncio.StreamReader, peer: Peer, wire_length: int - ) -> Optional[bytes]: + ) -> bytes | None: logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer) wire = await self._read_tcp_octets(reader, peer, wire_length) @@ -999,7 +1235,7 @@ class AsyncDnsServer(AsyncServer): async def _read_tcp_octets( self, reader: asyncio.StreamReader, peer: Peer, expected: int - ) -> Optional[bytes]: + ) -> bytes | None: buffer = b"" while len(buffer) < expected: @@ -1024,39 +1260,39 @@ class AsyncDnsServer(AsyncServer): async def _send_tcp_response( self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes ) -> None: - responses = self._handle_query(wire, peer, DnsProtocol.TCP) + socket_info = writer.get_extra_info("sockname") + socket = Peer(socket_info[0], socket_info[1]) + responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP) async for response in responses: logging.debug("Sending TCP response: %s", response.hex()) writer.write(response) await writer.drain() - def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None: + def _log_query(self, qctx: QueryContext) -> None: logging.info( - "Received %s/%s/%s (ID=%d) query from %s (%s)", + "Received %s/%s/%s (ID=%d) query from %s on %s (%s)", qctx.qname.to_text(omit_final_dot=True), dns.rdataclass.to_text(qctx.qclass), dns.rdatatype.to_text(qctx.qtype), qctx.query.id, - peer, - protocol.name, + qctx.peer, + qctx.socket, + qctx.protocol.name, ) logging.debug( "\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()]) ) def _log_response( - self, - qctx: QueryContext, - response: Optional[Union[dns.message.Message, bytes]], - peer: Peer, - protocol: DnsProtocol, + self, qctx: QueryContext, response: dns.message.Message | bytes | None ) -> None: if not response: logging.info( - "Not sending a response to query (ID=%d) from %s (%s)", + "Not sending a response to query (ID=%d) from %s on %s (%s)", qctx.query.id, - peer, - protocol.name, + qctx.peer, + qctx.socket, + qctx.protocol.name, ) return @@ -1071,7 +1307,7 @@ class AsyncDnsServer(AsyncServer): qtype = "-" logging.info( - "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s (%s)", + "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s on %s (%s)", qname, qclass, qtype, @@ -1081,8 +1317,9 @@ class AsyncDnsServer(AsyncServer): len(response.authority), len(response.additional), qctx.query.id, - peer, - protocol.name, + qctx.peer, + qctx.socket, + qctx.protocol.name, ) logging.debug( "\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()]) @@ -1090,16 +1327,17 @@ class AsyncDnsServer(AsyncServer): return logging.info( - "Sending response (%d bytes) to a query (ID=%d) from %s (%s)", + "Sending response (%d bytes) to a query (ID=%d) from %s on %s (%s)", len(response), qctx.query.id, - peer, - protocol.name, + qctx.peer, + qctx.socket, + qctx.protocol.name, ) logging.debug("[OUT] %s", response.hex()) async def _handle_query( - self, wire: bytes, peer: Peer, protocol: DnsProtocol + self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol ) -> AsyncGenerator[bytes, None]: """ Yield wire data to send as a response over the established transport. @@ -1109,12 +1347,12 @@ class AsyncDnsServer(AsyncServer): except dns.exception.DNSException as exc: logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc) return - response_stub = dns.message.make_response(query) - qctx = QueryContext(query, response_stub, peer, protocol) - self._log_query(qctx, peer, protocol) + response_stub = _make_asyncserver_response(query) + qctx = QueryContext(query, response_stub, socket, peer, protocol) + self._log_query(qctx) responses = self._prepare_responses(qctx) async for response in responses: - self._log_response(qctx, response, peer, protocol) + self._log_response(qctx, response) if response: if isinstance(response, dns.message.Message): response = response.to_wire(max_size=65535) @@ -1146,7 +1384,7 @@ class AsyncDnsServer(AsyncServer): async def _prepare_responses( self, qctx: QueryContext - ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]: + ) -> AsyncGenerator[dns.message.Message | bytes | None, None]: """ Yield response(s) either from response handlers or zone data. """ @@ -1339,10 +1577,10 @@ class ControllableAsyncDnsServer(AsyncDnsServer): return dns.name.from_text(self._CONTROL_DOMAIN) @functools.cached_property - def _commands(self) -> Dict[dns.name.Name, "ControlCommand"]: + def _commands(self) -> dict[dns.name.Name, "ControlCommand"]: return {} - def install_control_commands(self, commands: List["ControlCommand"]) -> None: + def install_control_commands(self, *commands: "ControlCommand") -> None: for command in commands: self.install_control_command(command) @@ -1360,7 +1598,7 @@ class ControllableAsyncDnsServer(AsyncDnsServer): async def _prepare_responses( self, qctx: QueryContext - ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]: + ) -> AsyncGenerator[dns.message.Message | bytes | None, None]: """ Detect and handle control queries, falling back to normal processing for non-control queries. @@ -1373,9 +1611,7 @@ class ControllableAsyncDnsServer(AsyncDnsServer): async for response in super()._prepare_responses(qctx): yield response - def _handle_control_command( - self, qctx: QueryContext - ) -> Optional[dns.message.Message]: + def _handle_control_command(self, qctx: QueryContext) -> dns.message.Message | None: """ Detect and handle control queries. @@ -1450,8 +1686,8 @@ class ControlCommand(abc.ABC): @abc.abstractmethod def handle( - self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext - ) -> Optional[str]: + self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext + ) -> str | None: """ This method is expected to carry out arbitrary actions in response to a control query. Note that it is invoked synchronously (it is not a @@ -1489,11 +1725,11 @@ class ToggleResponsesCommand(ControlCommand): control_subdomain = "send-responses" def __init__(self) -> None: - self._current_handler: Optional[IgnoreAllQueries] = None + self._current_handler: IgnoreAllQueries | None = None def handle( - self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext - ) -> Optional[str]: + self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext + ) -> str | None: if len(args) != 1: logging.error("Invalid %s query %s", self, qctx.qname) qctx.response.set_rcode(dns.rcode.SERVFAIL) @@ -1518,3 +1754,30 @@ class ToggleResponsesCommand(ControlCommand): logging.error("Unrecognized response sending mode '%s'", mode) qctx.response.set_rcode(dns.rcode.SERVFAIL) return f"unrecognized response sending mode '{mode}'" + + +class SwitchControlCommand(ControlCommand): + """ + Switch the server's response handlers based on the control query. + + A sequence of response handlers is associated with each key. When a + control query is received, the server's response handlers are replaced + with the sequence associated with the key extracted from the control + query. + """ + + control_subdomain = "switch" + + def __init__(self, handler_mapping: dict[str, Sequence[ResponseHandler]]): + self._handler_mapping = handler_mapping + + def handle( + self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext + ) -> str | None: + if len(args) != 1 or args[0] not in self._handler_mapping: + logging.error("Invalid %s query %s", self, qctx.qname) + qctx.response.set_rcode(dns.rcode.SERVFAIL) + return f"invalid query; exactly one of {list(self._handler_mapping.keys())} is expected in QNAME" + + server.replace_response_handlers(*self._handler_mapping[args[0]]) + return f"switched to handler set '{args[0]}'"