import abc
import asyncio
+import contextlib
import enum
import functools
import logging
import dns.rdataclass
import dns.rdatatype
import dns.rrset
+import dns.tsig
+import dns.version
import dns.zone
try:
return node.zone if node != self._root else None
+class _DnsMessageWithTsigDisabled(dns.message.Message):
+ """
+ A wrapper for `dns.message.Message` that works around a dnspython bug
+ causing exceptions to be raised when `make_response()` or `to_wire()` are
+ called for a message created using `dns.message.from_wire(keyring=False)`.
+
+ See https://github.com/rthalley/dnspython/issues/1205 for more details.
+ """
+
+ class _DisableTsigHandling(contextlib.ContextDecorator):
+ def __init__(self, message: Optional[dns.message.Message] = None) -> None:
+ self.original_tsig_sign = dns.tsig.sign
+ self.original_tsig_validate = dns.tsig.validate
+ if message:
+ self.tsig = message.tsig
+
+ def __enter__(self) -> None:
+ """
+ Override the `dns.tsig.sign` and `dns.tsig.validate` functions to prevent them
+ from failing on messages initialized with `dns.message.from_wire(keyring=False)`.
+ """
+
+ def sign(*_: Any, **__: Any) -> Tuple[dns.rdata.Rdata, None]:
+ assert self.tsig
+ return self.tsig[0], None
+
+ def validate(*_: Any, **__: Any) -> None:
+ return None
+
+ dns.tsig.sign = sign
+ dns.tsig.validate = validate
+
+ def __exit__(self, *_: Any, **__: Any) -> None:
+ dns.tsig.sign = self.original_tsig_sign
+ dns.tsig.validate = self.original_tsig_validate
+
+ @classmethod
+ def from_wire(cls, wire: bytes) -> "_DnsMessageWithTsigDisabled":
+ with cls._DisableTsigHandling():
+ message = dns.message.from_wire(wire, keyring=False)
+ message.__class__ = _DnsMessageWithTsigDisabled
+
+ return cast(_DnsMessageWithTsigDisabled, message)
+
+ @property
+ def had_tsig(self) -> bool:
+ """
+ Override the `had_tsig()` method to always return False, to prevent
+ `make_response()` from crashing.
+ """
+ return False
+
+ def to_wire(self, *args: Any, **kwargs: Any) -> bytes:
+ """
+ Override the `to_wire()` method to prevent it from trying to sign
+ the message with TSIG.
+ """
+ with self._DisableTsigHandling(self):
+ return super().to_wire(*args, **kwargs)
+
+
class AsyncDnsServer(AsyncServer):
"""
DNS server which responds to queries based on zone data and/or custom
response from scratch, without using zone data at all.
"""
- def __init__(self, acknowledge_manual_dname_handling: bool = False) -> None:
+ def __init__(
+ self,
+ acknowledge_manual_dname_handling: bool = False,
+ acknowledge_tsig_dnspython_hacks: bool = False,
+ ) -> None:
super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
self._zone_tree: _ZoneTree = _ZoneTree()
self._response_handlers: List[ResponseHandler] = []
self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling
+ self._acknowledge_tsig_dnspython_hacks = acknowledge_tsig_dnspython_hacks
self._load_zones()
"""
try:
query = dns.message.from_wire(wire)
+ except dns.message.UnknownTSIGKey:
+ self._abort_if_on_dnspython_version_less_than_2_0_0()
+ self._abort_if_tsig_signed_query_received_unless_acknowledged()
+ query = _DnsMessageWithTsigDisabled.from_wire(wire)
except dns.exception.DNSException as exc:
logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
return
response_length = struct.pack("!H", len(response))
yield response_length + response
+ def _abort_if_on_dnspython_version_less_than_2_0_0(self) -> None:
+ if dns.version.MAJOR < 2:
+ error = "Receiving TSIG signed queries requires dnspython >= 2.0.0; "
+ error += 'add `pytest.importorskip("dns", minversion="2.0.0")` '
+ error += "to the test module to skip this test."
+ raise RuntimeError(error)
+
+ def _abort_if_tsig_signed_query_received_unless_acknowledged(self) -> None:
+ if self._acknowledge_tsig_dnspython_hacks:
+ return
+
+ error = "TSIG-signed query received; "
+ error += "due to a bug in dnspython, this requires some hacking around; "
+ error += "you may experience unexpected behavior when dealing with TSIG; "
+ error += "TSIG validation is disabled, so any TSIG handling must be done "
+ error += "manually; pass `acknowledge_tsig_dnspython_hacks=True` to the "
+ error += "AsyncDnsServer constructor to acknowledge this and continue."
+
+ raise ValueError(error)
+
async def _prepare_responses(
self, qctx: QueryContext
) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]: