information regarding copyright ownership.
"""
+from collections.abc import AsyncGenerator, Callable, Coroutine, Sequence
from dataclasses import dataclass, field
-from typing import (
- Any,
- AsyncGenerator,
- Callable,
- Coroutine,
- Dict,
- List,
- Optional,
- Set,
- Tuple,
- Union,
- cast,
-)
+from typing import Any, cast
import abc
import asyncio
import dns.rdatatype
import dns.rrset
import dns.tsig
-import dns.version
import dns.zone
_UdpHandler = Callable[
- [bytes, Tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None]
+ [bytes, tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None]
]
self,
handler: _UdpHandler,
) -> None:
- self._transport: Optional[asyncio.DatagramTransport] = None
+ self._transport: asyncio.DatagramTransport | None = None
self._handler: _UdpHandler = handler
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""
self._transport = cast(asyncio.DatagramTransport, transport)
- def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:
+ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""
Called by asyncio when a datagram is received.
"""
def __init__(
self,
- udp_handler: Optional[_UdpHandler],
- tcp_handler: Optional[_TcpHandler],
- pidfile: Optional[str] = None,
+ udp_handler: _UdpHandler | None,
+ tcp_handler: _TcpHandler | None,
+ pidfile: str | None = None,
) -> None:
logging.basicConfig(
format="%(asctime)s %(levelname)8s %(message)s",
logging.info("Setting up IPv4 listener at %s:%d", ipv4_address, port)
logging.info("Setting up IPv6 listener at [%s]:%d", ipv6_address, port)
- self._ip_addresses: Tuple[str, str] = (ipv4_address, ipv6_address)
+ self._ip_addresses: tuple[str, str] = (ipv4_address, ipv6_address)
self._port: int = port
- self._udp_handler: Optional[_UdpHandler] = udp_handler
- self._tcp_handler: Optional[_TcpHandler] = tcp_handler
- self._pidfile: Optional[str] = pidfile
- self._work_done: Optional[asyncio.Future] = None
+ self._udp_handler: _UdpHandler | None = udp_handler
+ self._tcp_handler: _TcpHandler | None = tcp_handler
+ self._pidfile: str | None = pidfile
+ self._work_done: asyncio.Future | None = None
def _get_ipv4_address_from_directory_name(self) -> str:
containing_directory = pathlib.Path().absolute().stem
loop.set_exception_handler(self._handle_exception)
def _handle_exception(
- self, _: asyncio.AbstractEventLoop, context: Dict[str, Any]
+ self, _: asyncio.AbstractEventLoop, context: dict[str, Any]
) -> None:
assert self._work_done
exception = context.get("exception", RuntimeError(context["message"]))
query: dns.message.Message
response: dns.message.Message
+ socket: Peer
peer: Peer
protocol: DnsProtocol
- zone: Optional[dns.zone.Zone] = field(default=None, init=False)
- soa: Optional[dns.rrset.RRset] = field(default=None, init=False)
- node: Optional[dns.node.Node] = field(default=None, init=False)
- answer: Optional[dns.rdataset.Rdataset] = field(default=None, init=False)
- alias: Optional[dns.name.Name] = field(default=None, init=False)
- _initialized_response: Optional[dns.message.Message] = field(
- default=None, init=False
- )
- _initialized_response_with_zone_data: Optional[dns.message.Message] = field(
+ zone: dns.zone.Zone | None = field(default=None, init=False)
+ soa: dns.rrset.RRset | None = field(default=None, init=False)
+ node: dns.node.Node | None = field(default=None, init=False)
+ answer: dns.rdataset.Rdataset | None = field(default=None, init=False)
+ alias: dns.name.Name | None = field(default=None, init=False)
+ _initialized_response: dns.message.Message | None = field(default=None, init=False)
+ _initialized_response_with_zone_data: dns.message.Message | None = field(
default=None, init=False
)
"""
@abc.abstractmethod
- async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+ async def perform(self) -> dns.message.Message | bytes | None:
"""
This method is expected to carry out arbitrary actions (e.g. wait for a
specific amount of time, modify the answer, etc.) and then return the
"""
response: dns.message.Message
- authoritative: Optional[bool] = None
+ authoritative: bool | None = None
delay: float = 0.0
+ acknowledge_hand_rolled_response: bool = False
- async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+ async def perform(self) -> dns.message.Message | bytes | None:
"""
Yield a potentially delayed response that is a dns.message.Message.
"""
assert isinstance(self.response, dns.message.Message)
+ if not (
+ _is_asyncserver_response(self.response)
+ or self.acknowledge_hand_rolled_response
+ ):
+ error = "The response you are trying to send was not created using "
+ error += "AsyncDnsServer's response preparation methods. "
+ error += "This will break features such as automatic AA flag "
+ error += "and RCODE handling. If you need a fresh copy of a "
+ error += "response, use `QueryContext.prepare_new_response` "
+ error += "instead of `dns.message.make_response`. "
+ error += "To acknowledge this and proceed anyway, set "
+ error += "`acknowledge_hand_rolled_response=True` in "
+ error += "DnsResponseSend's constructor."
+ raise RuntimeError(error)
+
if self.authoritative is not None:
if self.authoritative:
self.response.flags |= dns.flags.AA
response: bytes
delay: float = 0.0
- async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+ async def perform(self) -> dns.message.Message | bytes | None:
"""
Yield a potentially delayed response that is a sequence of bytes.
"""
Action which does nothing - as if a packet was dropped.
"""
- async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+ async def perform(self) -> dns.message.Message | bytes | None:
return None
@dataclass
-class ResponseDropAndCloseConnection(ResponseAction):
+class CloseConnection(ResponseAction):
"""
- Action which makes the server close the connection after the DNS query is
- received by the server (TCP only).
+ Action which makes the server close the connection (TCP only).
The connection may be closed with a delay if requested.
"""
delay: float = 0.0
- async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
+ async def perform(self) -> dns.message.Message | bytes | None:
if self.delay > 0:
logging.info("Waiting %.1fs before closing TCP connection", self.delay)
await asyncio.sleep(self.delay)
client socket, effectively ignoring all incoming connections.
"""
- _connections: Set[asyncio.StreamWriter] = field(default_factory=set)
+ _connections: set[asyncio.StreamWriter] = field(default_factory=set)
async def handle(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer
make the server send an RST segment; this happens when the server closes a
client's socket while there is still unread data in that socket's buffer.
If closing the connection _after_ the query is read by the server is enough
- for a given use case, the ResponseDropAndCloseConnection response handler
- should be used instead.
+ for a given use case, the CloseConnection response handler should be used
+ instead.
"""
delay: float = 0.0
@property
@abc.abstractmethod
- def qnames(self) -> List[str]:
+ def qnames(self) -> list[str]:
"""
A list of QNAMEs handled by this class.
"""
raise NotImplementedError
def __init__(self) -> None:
- self._qnames: List[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames]
+ self._qnames: list[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames]
def __str__(self) -> str:
return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)})"
return qctx.qname in self._qnames
+class QnameQtypeHandler(QnameHandler):
+ """
+ Handle queries for which both of the following conditions are true:
+
+ - the query's QNAME is present in `self.qnames`,
+ - the query's QTYPE is present in `self.qtypes`.
+ """
+
+ @property
+ @abc.abstractmethod
+ def qtypes(self) -> list[dns.rdatatype.RdataType]:
+ """
+ A list of QTYPEs handled by this class.
+ """
+ raise NotImplementedError
+
+ def __init__(self) -> None:
+ super().__init__()
+ self._qtypes: list[dns.rdatatype.RdataType] = self.qtypes
+
+ def __str__(self) -> str:
+ return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)}; QTYPEs: {', '.join(map(str, self.qtypes))})"
+
+ def match(self, qctx: QueryContext) -> bool:
+ """
+ Handle queries whose QNAME and QTYPE match any of the QNAMEs and
+ QTYPEs handled by this class.
+ """
+ return qctx.qtype in self._qtypes and super().match(qctx)
+
+
+class StaticResponseHandler(ResponseHandler):
+ """
+ Base class used for deriving custom static response handlers.
+
+ The derived class can specify the RRsets to be included in the answer,
+ authority, and additional sections of the response, whether to set the AA
+ bit in the response, and a delay before sending the response.
+
+ The default implementation of `get_responses()` uses these properties to
+ prepare and yield a single response.
+ """
+
+ @property
+ def rcode(self) -> dns.rcode.Rcode | None:
+ """
+ Optional RCODE to be set in the response.
+ """
+ return None
+
+ @property
+ def answer(self) -> Sequence[dns.rrset.RRset]:
+ """
+ RRsets to be included in the answer section of the response.
+ """
+ return []
+
+ @property
+ def authority(self) -> Sequence[dns.rrset.RRset]:
+ """
+ RRsets to be included in the authority section of the response.
+ """
+ return []
+
+ @property
+ def additional(self) -> Sequence[dns.rrset.RRset]:
+ """
+ RRsets to be included in the additional section of the response.
+ """
+ return []
+
+ @property
+ def authoritative(self) -> bool | None:
+ """
+ Whether to set the AA bit in the response.
+ """
+ return None
+
+ @property
+ def delay(self) -> float:
+ """
+ Delay before sending the response.
+ """
+ return 0.0
+
+ async def get_responses(
+ self, qctx: QueryContext
+ ) -> AsyncGenerator[DnsResponseSend, None]:
+ qctx.prepare_new_response(with_zone_data=False)
+ qctx.response.answer.extend(self.answer)
+ qctx.response.authority.extend(self.authority)
+ qctx.response.additional.extend(self.additional)
+ if self.rcode is not None:
+ qctx.response.set_rcode(self.rcode)
+ yield DnsResponseSend(
+ qctx.response, authoritative=self.authoritative, delay=self.delay
+ )
+
+
class DomainHandler(ResponseHandler):
"""
Base class used for deriving custom domain handlers.
The derived class must specify a list of `domains` that it wants to handle.
Queries for any of these domains (and their subdomains) will then be passed
to the `get_response()` method in the derived class.
+
+ The most specific matching domain is stored in the `matched_domain` attribute.
"""
@property
@abc.abstractmethod
- def domains(self) -> List[str]:
+ def domains(self) -> list[str]:
"""
A list of domain names handled by this class.
"""
raise NotImplementedError
def __init__(self) -> None:
- self._domains: List[dns.name.Name] = [
- dns.name.from_text(d) for d in self.domains
- ]
+ self._domains: list[dns.name.Name] = sorted(
+ [dns.name.from_text(d) for d in self.domains], reverse=True
+ )
+ self._matched_domain: dns.name.Name | None = None
+
+ @property
+ def matched_domain(self) -> dns.name.Name:
+ assert self._matched_domain is not None
+ return self._matched_domain
def __str__(self) -> str:
return f"{self.__class__.__name__}(domains: {', '.join(self.domains)})"
Handle queries whose QNAME matches any of the domains handled by this
class.
"""
+ self._matched_domain = None
for domain in self._domains:
if qctx.qname.is_subdomain(domain):
+ self._matched_domain = domain
return True
return False
+class ForwarderHandler(ResponseHandler):
+ """
+ A handler forwarding all received queries to another DNS server with an
+ optional delay and then relaying the responses back to the original client.
+
+ Queries are currently always forwarded via UDP.
+ """
+
+ @property
+ @abc.abstractmethod
+ def target(self) -> str:
+ """
+ The address of the DNS server to forward queries to.
+ """
+ raise NotImplementedError
+
+ @property
+ def port(self) -> int:
+ """
+ The port of the DNS server to forward queries to.
+
+ The default value of 0 causes the same port as the one used by this
+ server for listening to be used.
+ """
+ return 0
+
+ @property
+ def delay(self) -> float:
+ """
+ The number of seconds to wait before forwarding each query.
+ """
+ return 0.0
+
+ def __str__(self) -> str:
+ return f"{self.__class__.__name__}(target: {self.target}:{self.port})"
+
+ class ForwarderProtocol(asyncio.DatagramProtocol):
+ def __init__(self, query: bytes, response: asyncio.Future) -> None:
+ self._query = query
+ self._response = response
+
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ logging.debug("[OUT] %s", self._query.hex())
+ cast(asyncio.DatagramTransport, transport).sendto(self._query)
+
+ def datagram_received(self, data: bytes, _: tuple[str, int]) -> None:
+ logging.debug("[IN] %s", data.hex())
+ self._response.set_result(data)
+
+ async def get_responses(
+ self, qctx: QueryContext
+ ) -> AsyncGenerator[ResponseAction, None]:
+ loop = asyncio.get_running_loop()
+ response = loop.create_future()
+ forwarding_target = f"{self.target}:{self.port or qctx.socket.port}"
+
+ if self.delay > 0:
+ logging.info(
+ "Waiting %.1fs before forwarding %s query from %s to %s over UDP",
+ self.delay,
+ qctx.protocol.name,
+ qctx.peer,
+ forwarding_target,
+ )
+ await asyncio.sleep(self.delay)
+
+ logging.info(
+ "Forwarding %s query from %s to %s over UDP",
+ qctx.protocol.name,
+ qctx.peer,
+ forwarding_target,
+ )
+
+ transport, _ = await loop.create_datagram_endpoint(
+ lambda: self.ForwarderProtocol(qctx.query.to_wire(), response),
+ local_addr=(qctx.socket.host, 0),
+ remote_addr=(self.target, self.port or qctx.socket.port),
+ )
+
+ try:
+ await response
+ finally:
+ transport.close()
+
+ logging.info(
+ "Relaying UDP response from %s to %s over %s",
+ forwarding_target,
+ qctx.peer,
+ qctx.protocol.name,
+ )
+
+ try:
+ message = _DnsMessageWithTsigDisabled.from_wire(response.result())
+ yield DnsResponseSend(message, acknowledge_hand_rolled_response=True)
+ except dns.exception.DNSException:
+ logging.warning(
+ "Failed to parse response from %s as a DNS message, relaying it as raw bytes",
+ forwarding_target,
+ )
+ yield BytesResponseSend(response.result())
+
+
@dataclass
class _ZoneTreeNode:
"""
A node representing a zone with one origin.
"""
- zone: Optional[dns.zone.Zone]
- children: List["_ZoneTreeNode"] = field(default_factory=list)
+ zone: dns.zone.Zone | None
+ children: list["_ZoneTreeNode"] = field(default_factory=list)
class _ZoneTree:
node_from.children.remove(child)
node_to.children.append(child)
- def find_best_zone(self, name: dns.name.Name) -> Optional[dns.zone.Zone]:
+ def find_best_zone(self, name: dns.name.Name) -> dns.zone.Zone | None:
"""
Return the closest matching zone (if any) for the domain name.
"""
"""
class _DisableTsigHandling(contextlib.ContextDecorator):
- def __init__(self, message: Optional[dns.message.Message] = None) -> None:
+ def __init__(self, message: dns.message.Message | None = None) -> None:
self.original_tsig_sign = dns.tsig.sign
self.original_tsig_validate = dns.tsig.validate
if message:
from failing on messages initialized with `dns.message.from_wire(keyring=False)`.
"""
- def sign(*_: Any, **__: Any) -> Tuple[dns.rdata.Rdata, None]:
+ def sign(*_: Any, **__: Any) -> tuple[dns.rdata.Rdata, None]:
assert self.tsig
return self.tsig[0], None
pass
+_ASYNCSERVER_RESPONSE_MARKER = "__is_asyncserver_response__"
+
+
+def _make_asyncserver_response(query: dns.message.Message) -> dns.message.Message:
+ response = dns.message.make_response(query)
+ setattr(response, _ASYNCSERVER_RESPONSE_MARKER, True)
+ return response
+
+
+def _is_asyncserver_response(message: dns.message.Message) -> bool:
+ return getattr(message, _ASYNCSERVER_RESPONSE_MARKER, False)
+
+
class AsyncDnsServer(AsyncServer):
"""
DNS server which responds to queries based on zone data and/or custom
self,
/,
default_rcode: dns.rcode.Rcode = dns.rcode.REFUSED,
- default_aa: bool = True,
- keyring: Union[
- Dict[dns.name.Name, dns.tsig.Key], None, _NoKeyringType
- ] = _NoKeyringType(),
+ default_aa: bool = False,
+ keyring: (
+ dict[dns.name.Name, dns.tsig.Key] | None | _NoKeyringType
+ ) = _NoKeyringType(),
acknowledge_manual_dname_handling: bool = False,
) -> None:
super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
self._zone_tree: _ZoneTree = _ZoneTree()
- self._connection_handler: Optional[ConnectionHandler] = None
- self._response_handlers: List[ResponseHandler] = []
+ self._connection_handler: ConnectionHandler | None = None
+ self._response_handlers: list[ResponseHandler] = []
self._default_rcode = default_rcode
self._default_aa = default_aa
self._keyring = keyring
else:
self._response_handlers.append(handler)
- def install_response_handlers(self, handlers: List[ResponseHandler]) -> None:
+ def install_response_handlers(self, *handlers: ResponseHandler) -> None:
for handler in handlers:
self.install_response_handler(handler)
+ def replace_response_handlers(self, *new_handlers: ResponseHandler) -> None:
+ """
+ Uninstall all currently installed handlers and install the provided ones.
+ """
+ logging.info("Uninstalling response handlers: %s", str(self._response_handlers))
+ self._response_handlers.clear()
+ self.install_response_handlers(*new_handlers)
+
def uninstall_response_handler(self, handler: ResponseHandler) -> None:
"""
Remove the specified handler from the list of response handlers.
raise ValueError(error)
async def _handle_udp(
- self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
+ self, wire: bytes, addr: tuple[str, int], transport: asyncio.DatagramTransport
) -> None:
logging.debug("Received UDP message: %s", wire.hex())
+ socket_info = transport.get_extra_info("sockname")
+ socket = Peer(socket_info[0], socket_info[1])
peer = Peer(addr[0], addr[1])
- responses = self._handle_query(wire, peer, DnsProtocol.UDP)
+ responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP)
async for response in responses:
logging.debug("Sending UDP message: %s", response.hex())
transport.sendto(response, addr)
async def _read_tcp_query(
self, reader: asyncio.StreamReader, peer: Peer
- ) -> Optional[bytes]:
+ ) -> bytes | None:
wire_length = await self._read_tcp_query_wire_length(reader, peer)
if not wire_length:
return None
async def _read_tcp_query_wire_length(
self, reader: asyncio.StreamReader, peer: Peer
- ) -> Optional[int]:
+ ) -> int | None:
logging.debug("Receiving TCP message length from %s...", peer)
wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
async def _read_tcp_query_wire(
self, reader: asyncio.StreamReader, peer: Peer, wire_length: int
- ) -> Optional[bytes]:
+ ) -> bytes | None:
logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
wire = await self._read_tcp_octets(reader, peer, wire_length)
async def _read_tcp_octets(
self, reader: asyncio.StreamReader, peer: Peer, expected: int
- ) -> Optional[bytes]:
+ ) -> bytes | None:
buffer = b""
while len(buffer) < expected:
async def _send_tcp_response(
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
) -> None:
- responses = self._handle_query(wire, peer, DnsProtocol.TCP)
+ socket_info = writer.get_extra_info("sockname")
+ socket = Peer(socket_info[0], socket_info[1])
+ responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP)
async for response in responses:
logging.debug("Sending TCP response: %s", response.hex())
writer.write(response)
await writer.drain()
- def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
+ def _log_query(self, qctx: QueryContext) -> None:
logging.info(
- "Received %s/%s/%s (ID=%d) query from %s (%s)",
+ "Received %s/%s/%s (ID=%d) query from %s on %s (%s)",
qctx.qname.to_text(omit_final_dot=True),
dns.rdataclass.to_text(qctx.qclass),
dns.rdatatype.to_text(qctx.qtype),
qctx.query.id,
- peer,
- protocol.name,
+ qctx.peer,
+ qctx.socket,
+ qctx.protocol.name,
)
logging.debug(
"\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()])
)
def _log_response(
- self,
- qctx: QueryContext,
- response: Optional[Union[dns.message.Message, bytes]],
- peer: Peer,
- protocol: DnsProtocol,
+ self, qctx: QueryContext, response: dns.message.Message | bytes | None
) -> None:
if not response:
logging.info(
- "Not sending a response to query (ID=%d) from %s (%s)",
+ "Not sending a response to query (ID=%d) from %s on %s (%s)",
qctx.query.id,
- peer,
- protocol.name,
+ qctx.peer,
+ qctx.socket,
+ qctx.protocol.name,
)
return
qtype = "-"
logging.info(
- "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s (%s)",
+ "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s on %s (%s)",
qname,
qclass,
qtype,
len(response.authority),
len(response.additional),
qctx.query.id,
- peer,
- protocol.name,
+ qctx.peer,
+ qctx.socket,
+ qctx.protocol.name,
)
logging.debug(
"\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()])
return
logging.info(
- "Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
+ "Sending response (%d bytes) to a query (ID=%d) from %s on %s (%s)",
len(response),
qctx.query.id,
- peer,
- protocol.name,
+ qctx.peer,
+ qctx.socket,
+ qctx.protocol.name,
)
logging.debug("[OUT] %s", response.hex())
async def _handle_query(
- self, wire: bytes, peer: Peer, protocol: DnsProtocol
+ self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol
) -> AsyncGenerator[bytes, None]:
"""
Yield wire data to send as a response over the established transport.
except dns.exception.DNSException as exc:
logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
return
- response_stub = dns.message.make_response(query)
- qctx = QueryContext(query, response_stub, peer, protocol)
- self._log_query(qctx, peer, protocol)
+ response_stub = _make_asyncserver_response(query)
+ qctx = QueryContext(query, response_stub, socket, peer, protocol)
+ self._log_query(qctx)
responses = self._prepare_responses(qctx)
async for response in responses:
- self._log_response(qctx, response, peer, protocol)
+ self._log_response(qctx, response)
if response:
if isinstance(response, dns.message.Message):
response = response.to_wire(max_size=65535)
async def _prepare_responses(
self, qctx: QueryContext
- ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]:
+ ) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
"""
Yield response(s) either from response handlers or zone data.
"""
return dns.name.from_text(self._CONTROL_DOMAIN)
@functools.cached_property
- def _commands(self) -> Dict[dns.name.Name, "ControlCommand"]:
+ def _commands(self) -> dict[dns.name.Name, "ControlCommand"]:
return {}
- def install_control_commands(self, commands: List["ControlCommand"]) -> None:
+ def install_control_commands(self, *commands: "ControlCommand") -> None:
for command in commands:
self.install_control_command(command)
async def _prepare_responses(
self, qctx: QueryContext
- ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]:
+ ) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
"""
Detect and handle control queries, falling back to normal processing
for non-control queries.
async for response in super()._prepare_responses(qctx):
yield response
- def _handle_control_command(
- self, qctx: QueryContext
- ) -> Optional[dns.message.Message]:
+ def _handle_control_command(self, qctx: QueryContext) -> dns.message.Message | None:
"""
Detect and handle control queries.
@abc.abstractmethod
def handle(
- self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext
- ) -> Optional[str]:
+ self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
+ ) -> str | None:
"""
This method is expected to carry out arbitrary actions in response to a
control query. Note that it is invoked synchronously (it is not a
control_subdomain = "send-responses"
def __init__(self) -> None:
- self._current_handler: Optional[IgnoreAllQueries] = None
+ self._current_handler: IgnoreAllQueries | None = None
def handle(
- self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext
- ) -> Optional[str]:
+ self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
+ ) -> str | None:
if len(args) != 1:
logging.error("Invalid %s query %s", self, qctx.qname)
qctx.response.set_rcode(dns.rcode.SERVFAIL)
logging.error("Unrecognized response sending mode '%s'", mode)
qctx.response.set_rcode(dns.rcode.SERVFAIL)
return f"unrecognized response sending mode '{mode}'"
+
+
+class SwitchControlCommand(ControlCommand):
+ """
+ Switch the server's response handlers based on the control query.
+
+ A sequence of response handlers is associated with each key. When a
+ control query is received, the server's response handlers are replaced
+ with the sequence associated with the key extracted from the control
+ query.
+ """
+
+ control_subdomain = "switch"
+
+ def __init__(self, handler_mapping: dict[str, Sequence[ResponseHandler]]):
+ self._handler_mapping = handler_mapping
+
+ def handle(
+ self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
+ ) -> str | None:
+ if len(args) != 1 or args[0] not in self._handler_mapping:
+ logging.error("Invalid %s query %s", self, qctx.qname)
+ qctx.response.set_rcode(dns.rcode.SERVFAIL)
+ return f"invalid query; exactly one of {list(self._handler_mapping.keys())} is expected in QNAME"
+
+ server.replace_response_handlers(*self._handler_mapping[args[0]])
+ return f"switched to handler set '{args[0]}'"