]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Implement a response handler that forwards queries
authorMichał Kępień <michal@isc.org>
Fri, 13 Feb 2026 13:27:10 +0000 (14:27 +0100)
committerMichał Kępień <michal@isc.org>
Fri, 13 Feb 2026 13:27:10 +0000 (14:27 +0100)
Add a new response handler, ForwarderHandler, which enables forwarding
all queries to another DNS server.  To simplify implementation, always
forward queries to the target server via UDP, even if they are
originally received using a different transport protocol.

bin/tests/system/isctest/asyncserver.py

index c894d16f00ac86ebb2aaee454c1cbd80a130a028..8e4ea245e5ac9fcca94441177bab552ff372902b 100644 (file)
@@ -788,6 +788,108 @@ class DomainHandler(ResponseHandler):
         return False
 
 
+class ForwarderHandler(ResponseHandler):
+    """
+    A handler forwarding all received queries to another DNS server with an
+    optional delay and then relaying the responses back to the original client.
+
+    Queries are currently always forwarded via UDP.
+    """
+
+    @property
+    @abc.abstractmethod
+    def target(self) -> str:
+        """
+        The address of the DNS server to forward queries to.
+        """
+        raise NotImplementedError
+
+    @property
+    def port(self) -> int:
+        """
+        The port of the DNS server to forward queries to.
+
+        The default value of 0 causes the same port as the one used by this
+        server for listening to be used.
+        """
+        return 0
+
+    @property
+    def delay(self) -> float:
+        """
+        The number of seconds to wait before forwarding each query.
+        """
+        return 0.0
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}(target: {self.target}:{self.port})"
+
+    class ForwarderProtocol(asyncio.DatagramProtocol):
+        def __init__(self, query: bytes, response: asyncio.Future) -> None:
+            self._query = query
+            self._response = response
+
+        def connection_made(self, transport: asyncio.BaseTransport) -> None:
+            logging.debug("[OUT] %s", self._query.hex())
+            cast(asyncio.DatagramTransport, transport).sendto(self._query)
+
+        def datagram_received(self, data: bytes, _: Tuple[str, int]) -> None:
+            logging.debug("[IN] %s", data.hex())
+            self._response.set_result(data)
+
+    async def get_responses(
+        self, qctx: QueryContext
+    ) -> AsyncGenerator[ResponseAction, None]:
+        loop = asyncio.get_running_loop()
+        response = loop.create_future()
+        forwarding_target = f"{self.target}:{self.port or qctx.socket.port}"
+
+        if self.delay > 0:
+            logging.info(
+                "Waiting %.1fs before forwarding %s query from %s to %s over UDP",
+                self.delay,
+                qctx.protocol.name,
+                qctx.peer,
+                forwarding_target,
+            )
+            await asyncio.sleep(self.delay)
+
+        logging.info(
+            "Forwarding %s query from %s to %s over UDP",
+            qctx.protocol.name,
+            qctx.peer,
+            forwarding_target,
+        )
+
+        transport, _ = await loop.create_datagram_endpoint(
+            lambda: self.ForwarderProtocol(qctx.query.to_wire(), response),
+            local_addr=(qctx.socket.host, 0),
+            remote_addr=(self.target, self.port or qctx.socket.port),
+        )
+
+        try:
+            await response
+        finally:
+            transport.close()
+
+        logging.info(
+            "Relaying UDP response from %s to %s over %s",
+            forwarding_target,
+            qctx.peer,
+            qctx.protocol.name,
+        )
+
+        try:
+            message = _DnsMessageWithTsigDisabled.from_wire(response.result())
+            yield DnsResponseSend(message, acknowledge_hand_rolled_response=True)
+        except dns.exception.DNSException:
+            logging.warning(
+                "Failed to parse response from %s as a DNS message, relaying it as raw bytes",
+                forwarding_target,
+            )
+            yield BytesResponseSend(response.result())
+
+
 @dataclass
 class _ZoneTreeNode:
     """