]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Let queries with TSIG parse in isctest.asyncserver.AsyncDnsServer
authorŠtěpán Balážik <stepan@isc.org>
Mon, 23 Jun 2025 14:43:56 +0000 (16:43 +0200)
committerŠtěpán Balážik <stepan@isc.org>
Sun, 13 Jul 2025 08:57:04 +0000 (10:57 +0200)
Previously, upon receiving a query with TSIG, the server would log
an error and timeout. As there is no way to set up the keyring in the
class anyway (and I believe we don't need it), this commit lets such
queries parse but logs the fact that the query has TSIG.

However, there is a bug [1] in dnspython, which causes `make_response`
and `to_wire` to crash on messages constructed by `from_wire` with
`keyring=False`, so the hack with `message.__class__` is needed to work
around this.

This makes just enough changes for the tsig system test to work with
dnspython >= 2.0.0. On older version the server gives up.

[1] https://github.com/rthalley/dnspython/issues/1205

bin/tests/system/isctest/asyncserver.py

index 2341f0311048d797e67eb09cac538a4d111db2ec..784ee8efd226305007cbc65d944a5e6424d96749 100644 (file)
@@ -28,6 +28,7 @@ from typing import (
 
 import abc
 import asyncio
+import contextlib
 import enum
 import functools
 import logging
@@ -46,6 +47,8 @@ import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 import dns.rrset
+import dns.tsig
+import dns.version
 import dns.zone
 
 try:
@@ -517,6 +520,67 @@ class _ZoneTree:
         return node.zone if node != self._root else None
 
 
+class _DnsMessageWithTsigDisabled(dns.message.Message):
+    """
+    A wrapper for `dns.message.Message` that works around a dnspython bug
+    causing exceptions to be raised when `make_response()` or `to_wire()` are
+    called for a message created using `dns.message.from_wire(keyring=False)`.
+
+    See https://github.com/rthalley/dnspython/issues/1205 for more details.
+    """
+
+    class _DisableTsigHandling(contextlib.ContextDecorator):
+        def __init__(self, message: Optional[dns.message.Message] = None) -> None:
+            self.original_tsig_sign = dns.tsig.sign
+            self.original_tsig_validate = dns.tsig.validate
+            if message:
+                self.tsig = message.tsig
+
+        def __enter__(self) -> None:
+            """
+            Override the `dns.tsig.sign` and `dns.tsig.validate` functions to prevent them
+            from failing on messages initialized with `dns.message.from_wire(keyring=False)`.
+            """
+
+            def sign(*_: Any, **__: Any) -> Tuple[dns.rdata.Rdata, None]:
+                assert self.tsig
+                return self.tsig[0], None
+
+            def validate(*_: Any, **__: Any) -> None:
+                return None
+
+            dns.tsig.sign = sign
+            dns.tsig.validate = validate
+
+        def __exit__(self, *_: Any, **__: Any) -> None:
+            dns.tsig.sign = self.original_tsig_sign
+            dns.tsig.validate = self.original_tsig_validate
+
+    @classmethod
+    def from_wire(cls, wire: bytes) -> "_DnsMessageWithTsigDisabled":
+        with cls._DisableTsigHandling():
+            message = dns.message.from_wire(wire, keyring=False)
+            message.__class__ = _DnsMessageWithTsigDisabled
+
+        return cast(_DnsMessageWithTsigDisabled, message)
+
+    @property
+    def had_tsig(self) -> bool:
+        """
+        Override the `had_tsig()` method to always return False, to prevent
+        `make_response()` from crashing.
+        """
+        return False
+
+    def to_wire(self, *args: Any, **kwargs: Any) -> bytes:
+        """
+        Override the `to_wire()` method to prevent it from trying to sign
+        the message with TSIG.
+        """
+        with self._DisableTsigHandling(self):
+            return super().to_wire(*args, **kwargs)
+
+
 class AsyncDnsServer(AsyncServer):
     """
     DNS server which responds to queries based on zone data and/or custom
@@ -533,12 +597,17 @@ class AsyncDnsServer(AsyncServer):
     response from scratch, without using zone data at all.
     """
 
-    def __init__(self, acknowledge_manual_dname_handling: bool = False) -> None:
+    def __init__(
+        self,
+        acknowledge_manual_dname_handling: bool = False,
+        acknowledge_tsig_dnspython_hacks: bool = False,
+    ) -> None:
         super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
 
         self._zone_tree: _ZoneTree = _ZoneTree()
         self._response_handlers: List[ResponseHandler] = []
         self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling
+        self._acknowledge_tsig_dnspython_hacks = acknowledge_tsig_dnspython_hacks
 
         self._load_zones()
 
@@ -778,6 +847,10 @@ class AsyncDnsServer(AsyncServer):
         """
         try:
             query = dns.message.from_wire(wire)
+        except dns.message.UnknownTSIGKey:
+            self._abort_if_on_dnspython_version_less_than_2_0_0()
+            self._abort_if_tsig_signed_query_received_unless_acknowledged()
+            query = _DnsMessageWithTsigDisabled.from_wire(wire)
         except dns.exception.DNSException as exc:
             logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
             return
@@ -796,6 +869,26 @@ class AsyncDnsServer(AsyncServer):
                     response_length = struct.pack("!H", len(response))
                     yield response_length + response
 
+    def _abort_if_on_dnspython_version_less_than_2_0_0(self) -> None:
+        if dns.version.MAJOR < 2:
+            error = "Receiving TSIG signed queries requires dnspython >= 2.0.0; "
+            error += 'add `pytest.importorskip("dns", minversion="2.0.0")` '
+            error += "to the test module to skip this test."
+            raise RuntimeError(error)
+
+    def _abort_if_tsig_signed_query_received_unless_acknowledged(self) -> None:
+        if self._acknowledge_tsig_dnspython_hacks:
+            return
+
+        error = "TSIG-signed query received; "
+        error += "due to a bug in dnspython, this requires some hacking around; "
+        error += "you may experience unexpected behavior when dealing with TSIG; "
+        error += "TSIG validation is disabled, so any TSIG handling must be done "
+        error += "manually; pass `acknowledge_tsig_dnspython_hacks=True` to the "
+        error += "AsyncDnsServer constructor to acknowledge this and continue."
+
+        raise ValueError(error)
+
     async def _prepare_responses(
         self, qctx: QueryContext
     ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]: