]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Sync asyncserver.py with the development branch
authorMichał Kępień <michal@isc.org>
Fri, 17 Apr 2026 15:57:05 +0000 (17:57 +0200)
committerMichał Kępień <michal@isc.org>
Thu, 7 May 2026 11:21:59 +0000 (13:21 +0200)
Import bin/tests/system/isctest/asyncserver.py as present in commit
ced002c4ab7b920c9528d315a611a477cb4a9409 on the "main" branch.  This
enables using newer asyncserver.py infrastructure code in system tests
that need to be backported to maintenance branches.

bin/tests/system/isctest/asyncserver.py

index d35710ba5d4004ae3b496ed9ab6ca5b75b5cbbe2..080c08c380684af2e497fcfd679b0261a48bfce6 100644 (file)
@@ -11,20 +11,9 @@ See the COPYRIGHT file distributed with this work for additional
 information regarding copyright ownership.
 """
 
+from collections.abc import AsyncGenerator, Callable, Coroutine, Sequence
 from dataclasses import dataclass, field
-from typing import (
-    Any,
-    AsyncGenerator,
-    Callable,
-    Coroutine,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-    cast,
-)
+from typing import Any, cast
 
 import abc
 import asyncio
@@ -52,11 +41,10 @@ import dns.rdataset
 import dns.rdatatype
 import dns.rrset
 import dns.tsig
-import dns.version
 import dns.zone
 
 _UdpHandler = Callable[
-    [bytes, Tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None]
+    [bytes, tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None]
 ]
 
 
@@ -74,7 +62,7 @@ class _AsyncUdpHandler(asyncio.DatagramProtocol):
         self,
         handler: _UdpHandler,
     ) -> None:
-        self._transport: Optional[asyncio.DatagramTransport] = None
+        self._transport: asyncio.DatagramTransport | None = None
         self._handler: _UdpHandler = handler
 
     def connection_made(self, transport: asyncio.BaseTransport) -> None:
@@ -83,7 +71,7 @@ class _AsyncUdpHandler(asyncio.DatagramProtocol):
         """
         self._transport = cast(asyncio.DatagramTransport, transport)
 
-    def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:
+    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
         """
         Called by asyncio when a datagram is received.
         """
@@ -108,9 +96,9 @@ class AsyncServer:
 
     def __init__(
         self,
-        udp_handler: Optional[_UdpHandler],
-        tcp_handler: Optional[_TcpHandler],
-        pidfile: Optional[str] = None,
+        udp_handler: _UdpHandler | None,
+        tcp_handler: _TcpHandler | None,
+        pidfile: str | None = None,
     ) -> None:
         logging.basicConfig(
             format="%(asctime)s %(levelname)8s  %(message)s",
@@ -132,12 +120,12 @@ class AsyncServer:
         logging.info("Setting up IPv4 listener at %s:%d", ipv4_address, port)
         logging.info("Setting up IPv6 listener at [%s]:%d", ipv6_address, port)
 
-        self._ip_addresses: Tuple[str, str] = (ipv4_address, ipv6_address)
+        self._ip_addresses: tuple[str, str] = (ipv4_address, ipv6_address)
         self._port: int = port
-        self._udp_handler: Optional[_UdpHandler] = udp_handler
-        self._tcp_handler: Optional[_TcpHandler] = tcp_handler
-        self._pidfile: Optional[str] = pidfile
-        self._work_done: Optional[asyncio.Future] = None
+        self._udp_handler: _UdpHandler | None = udp_handler
+        self._tcp_handler: _TcpHandler | None = tcp_handler
+        self._pidfile: str | None = pidfile
+        self._work_done: asyncio.Future | None = None
 
     def _get_ipv4_address_from_directory_name(self) -> str:
         containing_directory = pathlib.Path().absolute().stem
@@ -185,7 +173,7 @@ class AsyncServer:
         loop.set_exception_handler(self._handle_exception)
 
     def _handle_exception(
-        self, _: asyncio.AbstractEventLoop, context: Dict[str, Any]
+        self, _: asyncio.AbstractEventLoop, context: dict[str, Any]
     ) -> None:
         assert self._work_done
         exception = context.get("exception", RuntimeError(context["message"]))
@@ -265,17 +253,16 @@ class QueryContext:
 
     query: dns.message.Message
     response: dns.message.Message
+    socket: Peer
     peer: Peer
     protocol: DnsProtocol
-    zone: Optional[dns.zone.Zone] = field(default=None, init=False)
-    soa: Optional[dns.rrset.RRset] = field(default=None, init=False)
-    node: Optional[dns.node.Node] = field(default=None, init=False)
-    answer: Optional[dns.rdataset.Rdataset] = field(default=None, init=False)
-    alias: Optional[dns.name.Name] = field(default=None, init=False)
-    _initialized_response: Optional[dns.message.Message] = field(
-        default=None, init=False
-    )
-    _initialized_response_with_zone_data: Optional[dns.message.Message] = field(
+    zone: dns.zone.Zone | None = field(default=None, init=False)
+    soa: dns.rrset.RRset | None = field(default=None, init=False)
+    node: dns.node.Node | None = field(default=None, init=False)
+    answer: dns.rdataset.Rdataset | None = field(default=None, init=False)
+    alias: dns.name.Name | None = field(default=None, init=False)
+    _initialized_response: dns.message.Message | None = field(default=None, init=False)
+    _initialized_response_with_zone_data: dns.message.Message | None = field(
         default=None, init=False
     )
 
@@ -320,7 +307,7 @@ class ResponseAction(abc.ABC):
     """
 
     @abc.abstractmethod
-    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+    async def perform(self) -> dns.message.Message | bytes | None:
         """
         This method is expected to carry out arbitrary actions (e.g. wait for a
         specific amount of time, modify the answer, etc.) and then return the
@@ -343,14 +330,30 @@ class DnsResponseSend(ResponseAction):
     """
 
     response: dns.message.Message
-    authoritative: Optional[bool] = None
+    authoritative: bool | None = None
     delay: float = 0.0
+    acknowledge_hand_rolled_response: bool = False
 
-    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+    async def perform(self) -> dns.message.Message | bytes | None:
         """
         Yield a potentially delayed response that is a dns.message.Message.
         """
         assert isinstance(self.response, dns.message.Message)
+        if not (
+            _is_asyncserver_response(self.response)
+            or self.acknowledge_hand_rolled_response
+        ):
+            error = "The response you are trying to send was not created using "
+            error += "AsyncDnsServer's response preparation methods. "
+            error += "This will break features such as automatic AA flag "
+            error += "and RCODE handling. If you need a fresh copy of a "
+            error += "response, use `QueryContext.prepare_new_response` "
+            error += "instead of `dns.message.make_response`. "
+            error += "To acknowledge this and proceed anyway, set "
+            error += "`acknowledge_hand_rolled_response=True` in "
+            error += "DnsResponseSend's constructor."
+            raise RuntimeError(error)
+
         if self.authoritative is not None:
             if self.authoritative:
                 self.response.flags |= dns.flags.AA
@@ -377,7 +380,7 @@ class BytesResponseSend(ResponseAction):
     response: bytes
     delay: float = 0.0
 
-    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+    async def perform(self) -> dns.message.Message | bytes | None:
         """
         Yield a potentially delayed response that is a sequence of bytes.
         """
@@ -394,7 +397,7 @@ class ResponseDrop(ResponseAction):
     Action which does nothing - as if a packet was dropped.
     """
 
-    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+    async def perform(self) -> dns.message.Message | bytes | None:
         return None
 
 
@@ -403,17 +406,16 @@ class _ConnectionTeardownRequested(Exception):
 
 
 @dataclass
-class ResponseDropAndCloseConnection(ResponseAction):
+class CloseConnection(ResponseAction):
     """
-    Action which makes the server close the connection after the DNS query is
-    received by the server (TCP only).
+    Action which makes the server close the connection (TCP only).
 
     The connection may be closed with a delay if requested.
     """
 
     delay: float = 0.0
 
-    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+    async def perform(self) -> dns.message.Message | bytes | None:
         if self.delay > 0:
             logging.info("Waiting %.1fs before closing TCP connection", self.delay)
             await asyncio.sleep(self.delay)
@@ -495,7 +497,7 @@ class IgnoreAllConnections(ConnectionHandler):
     client socket, effectively ignoring all incoming connections.
     """
 
-    _connections: Set[asyncio.StreamWriter] = field(default_factory=set)
+    _connections: set[asyncio.StreamWriter] = field(default_factory=set)
 
     async def handle(
         self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer
@@ -529,8 +531,8 @@ class ConnectionReset(ConnectionHandler):
     make the server send an RST segment; this happens when the server closes a
     client's socket while there is still unread data in that socket's buffer.
     If closing the connection _after_ the query is read by the server is enough
-    for a given use case, the ResponseDropAndCloseConnection response handler
-    should be used instead.
+    for a given use case, the CloseConnection response handler should be used
+    instead.
     """
 
     delay: float = 0.0
@@ -606,14 +608,14 @@ class QnameHandler(ResponseHandler):
 
     @property
     @abc.abstractmethod
-    def qnames(self) -> List[str]:
+    def qnames(self) -> list[str]:
         """
         A list of QNAMEs handled by this class.
         """
         raise NotImplementedError
 
     def __init__(self) -> None:
-        self._qnames: List[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames]
+        self._qnames: list[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames]
 
     def __str__(self) -> str:
         return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)})"
@@ -626,6 +628,105 @@ class QnameHandler(ResponseHandler):
         return qctx.qname in self._qnames
 
 
+class QnameQtypeHandler(QnameHandler):
+    """
+    Handle queries for which both of the following conditions are true:
+
+    - the query's QNAME is present in `self.qnames`,
+    - the query's QTYPE is present in `self.qtypes`.
+    """
+
+    @property
+    @abc.abstractmethod
+    def qtypes(self) -> list[dns.rdatatype.RdataType]:
+        """
+        A list of QTYPEs handled by this class.
+        """
+        raise NotImplementedError
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._qtypes: list[dns.rdatatype.RdataType] = self.qtypes
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)}; QTYPEs: {', '.join(map(str, self.qtypes))})"
+
+    def match(self, qctx: QueryContext) -> bool:
+        """
+        Handle queries whose QNAME and QTYPE match any of the QNAMEs and
+        QTYPEs handled by this class.
+        """
+        return qctx.qtype in self._qtypes and super().match(qctx)
+
+
+class StaticResponseHandler(ResponseHandler):
+    """
+    Base class used for deriving custom static response handlers.
+
+    The derived class can specify the RRsets to be included in the answer,
+    authority, and additional sections of the response, whether to set the AA
+    bit in the response, and a delay before sending the response.
+
+    The default implementation of `get_responses()` uses these properties to
+    prepare and yield a single response.
+    """
+
+    @property
+    def rcode(self) -> dns.rcode.Rcode | None:
+        """
+        Optional RCODE to be set in the response.
+        """
+        return None
+
+    @property
+    def answer(self) -> Sequence[dns.rrset.RRset]:
+        """
+        RRsets to be included in the answer section of the response.
+        """
+        return []
+
+    @property
+    def authority(self) -> Sequence[dns.rrset.RRset]:
+        """
+        RRsets to be included in the authority section of the response.
+        """
+        return []
+
+    @property
+    def additional(self) -> Sequence[dns.rrset.RRset]:
+        """
+        RRsets to be included in the additional section of the response.
+        """
+        return []
+
+    @property
+    def authoritative(self) -> bool | None:
+        """
+        Whether to set the AA bit in the response.
+        """
+        return None
+
+    @property
+    def delay(self) -> float:
+        """
+        Delay before sending the response.
+        """
+        return 0.0
+
+    async def get_responses(
+        self, qctx: QueryContext
+    ) -> AsyncGenerator[DnsResponseSend, None]:
+        qctx.prepare_new_response(with_zone_data=False)
+        qctx.response.answer.extend(self.answer)
+        qctx.response.authority.extend(self.authority)
+        qctx.response.additional.extend(self.additional)
+        if self.rcode is not None:
+            qctx.response.set_rcode(self.rcode)
+        yield DnsResponseSend(
+            qctx.response, authoritative=self.authoritative, delay=self.delay
+        )
+
+
 class DomainHandler(ResponseHandler):
     """
     Base class used for deriving custom domain handlers.
@@ -633,20 +734,28 @@ class DomainHandler(ResponseHandler):
     The derived class must specify a list of `domains` that it wants to handle.
     Queries for any of these domains (and their subdomains) will then be passed
     to the `get_response()` method in the derived class.
+
+    The most specific matching domain is stored in the `matched_domain` attribute.
     """
 
     @property
     @abc.abstractmethod
-    def domains(self) -> List[str]:
+    def domains(self) -> list[str]:
         """
         A list of domain names handled by this class.
         """
         raise NotImplementedError
 
     def __init__(self) -> None:
-        self._domains: List[dns.name.Name] = [
-            dns.name.from_text(d) for d in self.domains
-        ]
+        self._domains: list[dns.name.Name] = sorted(
+            [dns.name.from_text(d) for d in self.domains], reverse=True
+        )
+        self._matched_domain: dns.name.Name | None = None
+
+    @property
+    def matched_domain(self) -> dns.name.Name:
+        assert self._matched_domain is not None
+        return self._matched_domain
 
     def __str__(self) -> str:
         return f"{self.__class__.__name__}(domains: {', '.join(self.domains)})"
@@ -656,20 +765,124 @@ class DomainHandler(ResponseHandler):
         Handle queries whose QNAME matches any of the domains handled by this
         class.
         """
+        self._matched_domain = None
         for domain in self._domains:
             if qctx.qname.is_subdomain(domain):
+                self._matched_domain = domain
                 return True
         return False
 
 
+class ForwarderHandler(ResponseHandler):
+    """
+    A handler forwarding all received queries to another DNS server with an
+    optional delay and then relaying the responses back to the original client.
+
+    Queries are currently always forwarded via UDP.
+    """
+
+    @property
+    @abc.abstractmethod
+    def target(self) -> str:
+        """
+        The address of the DNS server to forward queries to.
+        """
+        raise NotImplementedError
+
+    @property
+    def port(self) -> int:
+        """
+        The port of the DNS server to forward queries to.
+
+        The default value of 0 causes the same port as the one used by this
+        server for listening to be used.
+        """
+        return 0
+
+    @property
+    def delay(self) -> float:
+        """
+        The number of seconds to wait before forwarding each query.
+        """
+        return 0.0
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}(target: {self.target}:{self.port})"
+
+    class ForwarderProtocol(asyncio.DatagramProtocol):
+        def __init__(self, query: bytes, response: asyncio.Future) -> None:
+            self._query = query
+            self._response = response
+
+        def connection_made(self, transport: asyncio.BaseTransport) -> None:
+            logging.debug("[OUT] %s", self._query.hex())
+            cast(asyncio.DatagramTransport, transport).sendto(self._query)
+
+        def datagram_received(self, data: bytes, _: tuple[str, int]) -> None:
+            logging.debug("[IN] %s", data.hex())
+            self._response.set_result(data)
+
+    async def get_responses(
+        self, qctx: QueryContext
+    ) -> AsyncGenerator[ResponseAction, None]:
+        loop = asyncio.get_running_loop()
+        response = loop.create_future()
+        forwarding_target = f"{self.target}:{self.port or qctx.socket.port}"
+
+        if self.delay > 0:
+            logging.info(
+                "Waiting %.1fs before forwarding %s query from %s to %s over UDP",
+                self.delay,
+                qctx.protocol.name,
+                qctx.peer,
+                forwarding_target,
+            )
+            await asyncio.sleep(self.delay)
+
+        logging.info(
+            "Forwarding %s query from %s to %s over UDP",
+            qctx.protocol.name,
+            qctx.peer,
+            forwarding_target,
+        )
+
+        transport, _ = await loop.create_datagram_endpoint(
+            lambda: self.ForwarderProtocol(qctx.query.to_wire(), response),
+            local_addr=(qctx.socket.host, 0),
+            remote_addr=(self.target, self.port or qctx.socket.port),
+        )
+
+        try:
+            await response
+        finally:
+            transport.close()
+
+        logging.info(
+            "Relaying UDP response from %s to %s over %s",
+            forwarding_target,
+            qctx.peer,
+            qctx.protocol.name,
+        )
+
+        try:
+            message = _DnsMessageWithTsigDisabled.from_wire(response.result())
+            yield DnsResponseSend(message, acknowledge_hand_rolled_response=True)
+        except dns.exception.DNSException:
+            logging.warning(
+                "Failed to parse response from %s as a DNS message, relaying it as raw bytes",
+                forwarding_target,
+            )
+            yield BytesResponseSend(response.result())
+
+
 @dataclass
 class _ZoneTreeNode:
     """
     A node representing a zone with one origin.
     """
 
-    zone: Optional[dns.zone.Zone]
-    children: List["_ZoneTreeNode"] = field(default_factory=list)
+    zone: dns.zone.Zone | None
+    children: list["_ZoneTreeNode"] = field(default_factory=list)
 
 
 class _ZoneTree:
@@ -719,7 +932,7 @@ class _ZoneTree:
             node_from.children.remove(child)
             node_to.children.append(child)
 
-    def find_best_zone(self, name: dns.name.Name) -> Optional[dns.zone.Zone]:
+    def find_best_zone(self, name: dns.name.Name) -> dns.zone.Zone | None:
         """
         Return the closest matching zone (if any) for the domain name.
         """
@@ -737,7 +950,7 @@ class _DnsMessageWithTsigDisabled(dns.message.Message):
     """
 
     class _DisableTsigHandling(contextlib.ContextDecorator):
-        def __init__(self, message: Optional[dns.message.Message] = None) -> None:
+        def __init__(self, message: dns.message.Message | None = None) -> None:
             self.original_tsig_sign = dns.tsig.sign
             self.original_tsig_validate = dns.tsig.validate
             if message:
@@ -749,7 +962,7 @@ class _DnsMessageWithTsigDisabled(dns.message.Message):
             from failing on messages initialized with `dns.message.from_wire(keyring=False)`.
             """
 
-            def sign(*_: Any, **__: Any) -> Tuple[dns.rdata.Rdata, None]:
+            def sign(*_: Any, **__: Any) -> tuple[dns.rdata.Rdata, None]:
                 assert self.tsig
                 return self.tsig[0], None
 
@@ -792,6 +1005,19 @@ class _NoKeyringType:
     pass
 
 
+_ASYNCSERVER_RESPONSE_MARKER = "__is_asyncserver_response__"
+
+
+def _make_asyncserver_response(query: dns.message.Message) -> dns.message.Message:
+    response = dns.message.make_response(query)
+    setattr(response, _ASYNCSERVER_RESPONSE_MARKER, True)
+    return response
+
+
+def _is_asyncserver_response(message: dns.message.Message) -> bool:
+    return getattr(message, _ASYNCSERVER_RESPONSE_MARKER, False)
+
+
 class AsyncDnsServer(AsyncServer):
     """
     DNS server which responds to queries based on zone data and/or custom
@@ -812,17 +1038,17 @@ class AsyncDnsServer(AsyncServer):
         self,
         /,
         default_rcode: dns.rcode.Rcode = dns.rcode.REFUSED,
-        default_aa: bool = True,
-        keyring: Union[
-            Dict[dns.name.Name, dns.tsig.Key], None, _NoKeyringType
-        ] = _NoKeyringType(),
+        default_aa: bool = False,
+        keyring: (
+            dict[dns.name.Name, dns.tsig.Key] | None | _NoKeyringType
+        ) = _NoKeyringType(),
         acknowledge_manual_dname_handling: bool = False,
     ) -> None:
         super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
 
         self._zone_tree: _ZoneTree = _ZoneTree()
-        self._connection_handler: Optional[ConnectionHandler] = None
-        self._response_handlers: List[ResponseHandler] = []
+        self._connection_handler: ConnectionHandler | None = None
+        self._response_handlers: list[ResponseHandler] = []
         self._default_rcode = default_rcode
         self._default_aa = default_aa
         self._keyring = keyring
@@ -849,10 +1075,18 @@ class AsyncDnsServer(AsyncServer):
         else:
             self._response_handlers.append(handler)
 
-    def install_response_handlers(self, handlers: List[ResponseHandler]) -> None:
+    def install_response_handlers(self, *handlers: ResponseHandler) -> None:
         for handler in handlers:
             self.install_response_handler(handler)
 
+    def replace_response_handlers(self, *new_handlers: ResponseHandler) -> None:
+        """
+        Uninstall all currently installed handlers and install the provided ones.
+        """
+        logging.info("Uninstalling response handlers: %s", str(self._response_handlers))
+        self._response_handlers.clear()
+        self.install_response_handlers(*new_handlers)
+
     def uninstall_response_handler(self, handler: ResponseHandler) -> None:
         """
         Remove the specified handler from the list of response handlers.
@@ -923,11 +1157,13 @@ class AsyncDnsServer(AsyncServer):
                     raise ValueError(error)
 
     async def _handle_udp(
-        self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
+        self, wire: bytes, addr: tuple[str, int], transport: asyncio.DatagramTransport
     ) -> None:
         logging.debug("Received UDP message: %s", wire.hex())
+        socket_info = transport.get_extra_info("sockname")
+        socket = Peer(socket_info[0], socket_info[1])
         peer = Peer(addr[0], addr[1])
-        responses = self._handle_query(wire, peer, DnsProtocol.UDP)
+        responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP)
         async for response in responses:
             logging.debug("Sending UDP message: %s", response.hex())
             transport.sendto(response, addr)
@@ -964,7 +1200,7 @@ class AsyncDnsServer(AsyncServer):
 
     async def _read_tcp_query(
         self, reader: asyncio.StreamReader, peer: Peer
-    ) -> Optional[bytes]:
+    ) -> bytes | None:
         wire_length = await self._read_tcp_query_wire_length(reader, peer)
         if not wire_length:
             return None
@@ -973,7 +1209,7 @@ class AsyncDnsServer(AsyncServer):
 
     async def _read_tcp_query_wire_length(
         self, reader: asyncio.StreamReader, peer: Peer
-    ) -> Optional[int]:
+    ) -> int | None:
         logging.debug("Receiving TCP message length from %s...", peer)
 
         wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
@@ -986,7 +1222,7 @@ class AsyncDnsServer(AsyncServer):
 
     async def _read_tcp_query_wire(
         self, reader: asyncio.StreamReader, peer: Peer, wire_length: int
-    ) -> Optional[bytes]:
+    ) -> bytes | None:
         logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
 
         wire = await self._read_tcp_octets(reader, peer, wire_length)
@@ -999,7 +1235,7 @@ class AsyncDnsServer(AsyncServer):
 
     async def _read_tcp_octets(
         self, reader: asyncio.StreamReader, peer: Peer, expected: int
-    ) -> Optional[bytes]:
+    ) -> bytes | None:
         buffer = b""
 
         while len(buffer) < expected:
@@ -1024,39 +1260,39 @@ class AsyncDnsServer(AsyncServer):
     async def _send_tcp_response(
         self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
     ) -> None:
-        responses = self._handle_query(wire, peer, DnsProtocol.TCP)
+        socket_info = writer.get_extra_info("sockname")
+        socket = Peer(socket_info[0], socket_info[1])
+        responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP)
         async for response in responses:
             logging.debug("Sending TCP response: %s", response.hex())
             writer.write(response)
             await writer.drain()
 
-    def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
+    def _log_query(self, qctx: QueryContext) -> None:
         logging.info(
-            "Received %s/%s/%s (ID=%d) query from %s (%s)",
+            "Received %s/%s/%s (ID=%d) query from %s on %s (%s)",
             qctx.qname.to_text(omit_final_dot=True),
             dns.rdataclass.to_text(qctx.qclass),
             dns.rdatatype.to_text(qctx.qtype),
             qctx.query.id,
-            peer,
-            protocol.name,
+            qctx.peer,
+            qctx.socket,
+            qctx.protocol.name,
         )
         logging.debug(
             "\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()])
         )
 
     def _log_response(
-        self,
-        qctx: QueryContext,
-        response: Optional[Union[dns.message.Message, bytes]],
-        peer: Peer,
-        protocol: DnsProtocol,
+        self, qctx: QueryContext, response: dns.message.Message | bytes | None
     ) -> None:
         if not response:
             logging.info(
-                "Not sending a response to query (ID=%d) from %s (%s)",
+                "Not sending a response to query (ID=%d) from %s on %s (%s)",
                 qctx.query.id,
-                peer,
-                protocol.name,
+                qctx.peer,
+                qctx.socket,
+                qctx.protocol.name,
             )
             return
 
@@ -1071,7 +1307,7 @@ class AsyncDnsServer(AsyncServer):
                 qtype = "-"
 
             logging.info(
-                "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s (%s)",
+                "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s on %s (%s)",
                 qname,
                 qclass,
                 qtype,
@@ -1081,8 +1317,9 @@ class AsyncDnsServer(AsyncServer):
                 len(response.authority),
                 len(response.additional),
                 qctx.query.id,
-                peer,
-                protocol.name,
+                qctx.peer,
+                qctx.socket,
+                qctx.protocol.name,
             )
             logging.debug(
                 "\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()])
@@ -1090,16 +1327,17 @@ class AsyncDnsServer(AsyncServer):
             return
 
         logging.info(
-            "Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
+            "Sending response (%d bytes) to a query (ID=%d) from %s on %s (%s)",
             len(response),
             qctx.query.id,
-            peer,
-            protocol.name,
+            qctx.peer,
+            qctx.socket,
+            qctx.protocol.name,
         )
         logging.debug("[OUT] %s", response.hex())
 
     async def _handle_query(
-        self, wire: bytes, peer: Peer, protocol: DnsProtocol
+        self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol
     ) -> AsyncGenerator[bytes, None]:
         """
         Yield wire data to send as a response over the established transport.
@@ -1109,12 +1347,12 @@ class AsyncDnsServer(AsyncServer):
         except dns.exception.DNSException as exc:
             logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
             return
-        response_stub = dns.message.make_response(query)
-        qctx = QueryContext(query, response_stub, peer, protocol)
-        self._log_query(qctx, peer, protocol)
+        response_stub = _make_asyncserver_response(query)
+        qctx = QueryContext(query, response_stub, socket, peer, protocol)
+        self._log_query(qctx)
         responses = self._prepare_responses(qctx)
         async for response in responses:
-            self._log_response(qctx, response, peer, protocol)
+            self._log_response(qctx, response)
             if response:
                 if isinstance(response, dns.message.Message):
                     response = response.to_wire(max_size=65535)
@@ -1146,7 +1384,7 @@ class AsyncDnsServer(AsyncServer):
 
     async def _prepare_responses(
         self, qctx: QueryContext
-    ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]:
+    ) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
         """
         Yield response(s) either from response handlers or zone data.
         """
@@ -1339,10 +1577,10 @@ class ControllableAsyncDnsServer(AsyncDnsServer):
         return dns.name.from_text(self._CONTROL_DOMAIN)
 
     @functools.cached_property
-    def _commands(self) -> Dict[dns.name.Name, "ControlCommand"]:
+    def _commands(self) -> dict[dns.name.Name, "ControlCommand"]:
         return {}
 
-    def install_control_commands(self, commands: List["ControlCommand"]) -> None:
+    def install_control_commands(self, *commands: "ControlCommand") -> None:
         for command in commands:
             self.install_control_command(command)
 
@@ -1360,7 +1598,7 @@ class ControllableAsyncDnsServer(AsyncDnsServer):
 
     async def _prepare_responses(
         self, qctx: QueryContext
-    ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]:
+    ) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
         """
         Detect and handle control queries, falling back to normal processing
         for non-control queries.
@@ -1373,9 +1611,7 @@ class ControllableAsyncDnsServer(AsyncDnsServer):
         async for response in super()._prepare_responses(qctx):
             yield response
 
-    def _handle_control_command(
-        self, qctx: QueryContext
-    ) -> Optional[dns.message.Message]:
+    def _handle_control_command(self, qctx: QueryContext) -> dns.message.Message | None:
         """
         Detect and handle control queries.
 
@@ -1450,8 +1686,8 @@ class ControlCommand(abc.ABC):
 
     @abc.abstractmethod
     def handle(
-        self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext
-    ) -> Optional[str]:
+        self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
+    ) -> str | None:
         """
         This method is expected to carry out arbitrary actions in response to a
         control query.  Note that it is invoked synchronously (it is not a
@@ -1489,11 +1725,11 @@ class ToggleResponsesCommand(ControlCommand):
     control_subdomain = "send-responses"
 
     def __init__(self) -> None:
-        self._current_handler: Optional[IgnoreAllQueries] = None
+        self._current_handler: IgnoreAllQueries | None = None
 
     def handle(
-        self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext
-    ) -> Optional[str]:
+        self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
+    ) -> str | None:
         if len(args) != 1:
             logging.error("Invalid %s query %s", self, qctx.qname)
             qctx.response.set_rcode(dns.rcode.SERVFAIL)
@@ -1518,3 +1754,30 @@ class ToggleResponsesCommand(ControlCommand):
         logging.error("Unrecognized response sending mode '%s'", mode)
         qctx.response.set_rcode(dns.rcode.SERVFAIL)
         return f"unrecognized response sending mode '{mode}'"
+
+
+class SwitchControlCommand(ControlCommand):
+    """
+    Switch the server's response handlers based on the control query.
+
+    A sequence of response handlers is associated with each key.  When a
+    control query is received, the server's response handlers are replaced
+    with the sequence associated with the key extracted from the control
+    query.
+    """
+
+    control_subdomain = "switch"
+
+    def __init__(self, handler_mapping: dict[str, Sequence[ResponseHandler]]):
+        self._handler_mapping = handler_mapping
+
+    def handle(
+        self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
+    ) -> str | None:
+        if len(args) != 1 or args[0] not in self._handler_mapping:
+            logging.error("Invalid %s query %s", self, qctx.qname)
+            qctx.response.set_rcode(dns.rcode.SERVFAIL)
+            return f"invalid query; exactly one of {list(self._handler_mapping.keys())} is expected in QNAME"
+
+        server.replace_response_handlers(*self._handler_mapping[args[0]])
+        return f"switched to handler set '{args[0]}'"