--- /dev/null
+"""
+Copyright (C) Internet Systems Consortium, Inc. ("ISC")
+
+SPDX-License-Identifier: MPL-2.0
+
+This Source Code Form is subject to the terms of the Mozilla Public
+License, v. 2.0. If a copy of the MPL was not distributed with this
+file, you can obtain one at https://mozilla.org/MPL/2.0/.
+
+See the COPYRIGHT file distributed with this work for additional
+information regarding copyright ownership.
+"""
+
+from collections.abc import AsyncGenerator
+
+import asyncio
+import functools
+
+import dns.flags
+import dns.name
+import dns.rcode
+import dns.rdataclass
+import dns.rdatatype
+import dns.rrset
+
+from isctest.asyncserver import (
+ ControlCommand,
+ ControllableAsyncDnsServer,
+ DnsResponseSend,
+ QnameHandler,
+ QueryContext,
+ ResponseAction,
+ ResponseHandler,
+ StaticResponseHandler,
+)
+
+
+class ReclimitStateHandler(QnameHandler):
+ """
+ Handler for shared state of all the handlers in one server
+ """
+
+ qnames = ["count.", "reset."]
+
+ def __init__(self, indirect_send_response_default: bool = True) -> None:
+ self._count = 0
+ self._count_lock = asyncio.Lock()
+ self._indirect_send_response_default = indirect_send_response_default
+ self._indirect_send_response = indirect_send_response_default
+ self._indirect_send_response_lock = asyncio.Lock()
+ self._limit = 0
+ super().__init__()
+
+ async def get_responses(
+ self, qctx: QueryContext
+ ) -> AsyncGenerator[DnsResponseSend, None]:
+ if f"{qctx.qname}" == "count.":
+ await self.increment_count()
+ qctx.response.answer.append(await self._count_txt_rrset())
+ yield DnsResponseSend(qctx.response, authoritative=True)
+ elif f"{qctx.qname}" == "reset.":
+ await self.set_indirect_send_response(self._indirect_send_response_default)
+ await self.reset_count()
+ yield DnsResponseSend(qctx.response, authoritative=False)
+
+ async def reset_count(self) -> None:
+ async with self._count_lock:
+ self._count = 0
+
+ async def increment_count(self) -> None:
+ async with self._count_lock:
+ self._count += 1
+
+ async def _count_txt_rrset(self) -> dns.rrset.RRset:
+ async with self._count_lock:
+ count = self._count
+ return dns.rrset.from_text(
+ "count.", 0, dns.rdataclass.IN, dns.rdatatype.TXT, f"{count}"
+ )
+
+ async def get_indirect_send_response(self) -> bool:
+ async with self._indirect_send_response_lock:
+ return self._indirect_send_response
+
+ async def set_indirect_send_response(self, value: bool) -> None:
+ async with self._indirect_send_response_lock:
+ self._indirect_send_response = value
+
+ def set_limit(self, limit: int) -> None:
+ self._limit = limit
+
+ def get_limit(self) -> int:
+ return self._limit
+
+
+class ReclimitHandler(ResponseHandler):
+ """
+ Base class for handlers in this test
+
+ Provides access to the shared state through the state handler and automatically increments the count on each query.
+ """
+
+ _COUNTED_HANDLER_WRAPPED_ATTR = "__counted_handler_wrapped__"
+
+ def __init__(self, state_handler: ReclimitStateHandler) -> None:
+ self._state_handler = state_handler
+ super().__init__()
+
+ def __init_subclass__(cls) -> None:
+ """
+ Wrap the get_responses method of all subclasses to increment the count
+ in the state handler whenever they are called.
+ """
+
+ super().__init_subclass__()
+ original = cls.get_responses
+ if original is ResponseHandler.get_responses:
+ return
+ if getattr(original, ReclimitHandler._COUNTED_HANDLER_WRAPPED_ATTR, False):
+ return
+
+ @functools.wraps(original)
+ async def wrapped_get_responses(
+ self: "ReclimitHandler", qctx: QueryContext
+ ) -> AsyncGenerator[DnsResponseSend, None]:
+ await self._state_handler.increment_count() # pylint: disable=protected-access
+ async for response in original(self, qctx):
+ yield response
+
+ setattr(
+ wrapped_get_responses, ReclimitHandler._COUNTED_HANDLER_WRAPPED_ATTR, True
+ )
+ cls.get_responses = wrapped_get_responses
+
+ def _get_limit(self) -> int:
+ return self._state_handler.get_limit()
+
+ async def _get_indirect_send_response(self) -> bool:
+ return await self._state_handler.get_indirect_send_response()
+
+ async def _set_indirect_send_response(self, value: bool) -> None:
+ await self._state_handler.set_indirect_send_response(value)
+
+
+class LimitControlCommand(ControlCommand):
+ control_subdomain = "limit"
+
+ def __init__(self, state_handler: ReclimitStateHandler) -> None:
+ self._state_handler = state_handler
+ super().__init__()
+
+ def handle(
+ self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
+ ) -> str | None:
+ if len(args) != 1:
+ return "Expected exactly one label"
+
+ try:
+ limit = int(args[0])
+ except ValueError:
+ return "Expected an integer"
+
+ self._state_handler.set_limit(limit)
+ return f"Limit set to {limit}"
+
+
+def a(owner: str | dns.name.Name, ns_number: int) -> dns.rrset.RRset:
+ return dns.rrset.from_text(
+ f"{owner}", 3600, dns.rdataclass.IN, dns.rdatatype.A, f"10.53.0.{ns_number}"
+ )
+
+
+def ns(owner: str | dns.name.Name, target: str | dns.name.Name) -> dns.rrset.RRset:
+ return dns.rrset.from_text(
+ f"{owner}", 86400, dns.rdataclass.IN, dns.rdatatype.NS, f"{target}"
+ )
+
+
+class DirectExampleHandler(ReclimitHandler, QnameHandler):
+ qnames = ["direct.example.org", "direct.example.net"]
+
+ async def get_responses(
+ self, qctx: QueryContext
+ ) -> AsyncGenerator[DnsResponseSend, None]:
+ if qctx.qtype == dns.rdatatype.A:
+ qctx.response.answer.append(a(qctx.qname, 4))
+ yield DnsResponseSend(qctx.response)
+
+
+class IndirectExampleOrgHandler(ReclimitHandler, QnameHandler):
+ qnames = [f"indirect{i}.example.org" for i in range(1, 9)]
+
+ async def get_responses(
+ self, qctx: QueryContext
+ ) -> AsyncGenerator[DnsResponseSend, None]:
+ if not await self._get_indirect_send_response():
+ qctx.response.authority.append(ns(f"{qctx.qname}", "ns1.1.example.org."))
+ qctx.response.flags &= ~dns.flags.AA
+ elif qctx.qtype == dns.rdatatype.A:
+ qctx.response.answer.append(a(qctx.qname, 4))
+ yield DnsResponseSend(qctx.response)
+
+
+def is_ns1_example(qname: dns.name.Name, tld: str) -> bool:
+ labels = qname.labels
+ return (
+ len(labels) == 5
+ and labels[3] == tld.encode()
+ and labels[2] == b"example"
+ and labels[1].isdigit()
+ and labels[0] == b"ns1"
+ )
+
+
+class Ns1ExampleOrgHandler(ReclimitHandler):
+ def __init__(self, state_handler: ReclimitStateHandler) -> None:
+ self._waiting_responses_locks: dict[dns.name.Name, asyncio.Lock] = {}
+ super().__init__(state_handler)
+
+ def match(self, qctx: QueryContext) -> bool:
+ return is_ns1_example(qctx.qname, "org") and qctx.qtype in (
+ dns.rdatatype.A,
+ dns.rdatatype.AAAA,
+ )
+
+ async def get_responses(
+ self, qctx: QueryContext
+ ) -> AsyncGenerator[ResponseAction, None]:
+ ns_number = int(qctx.qname.labels[1])
+ next_ns_number = ns_number + 1
+ if not self._get_limit() or (
+ not await self._get_indirect_send_response()
+ and next_ns_number <= self._get_limit()
+ ):
+ qctx.response.authority.append(
+ ns(f"{ns_number}.example.org.", f"ns1.{next_ns_number}.example.org.")
+ )
+ qctx.response.flags &= ~dns.flags.AA
+ else:
+ await self._set_indirect_send_response(True)
+ if qctx.qtype == dns.rdatatype.A:
+ qctx.response.answer.append(a(qctx.qname, 4))
+
+ # XXX: The original Perl implementation doesn't set AA on empty responses,
+ # let's do the same for packet-for-packet compatibility and drop it later.
+ elif qctx.qtype == dns.rdatatype.AAAA:
+ qctx.response.flags &= ~dns.flags.AA
+
+ if qctx.qname in self._waiting_responses_locks:
+ # Second query arrived, release the first response.
+ self._waiting_responses_locks[qctx.qname].release()
+ await asyncio.sleep(0) # Yield to allow the first response to be sent.
+ yield DnsResponseSend(qctx.response)
+ else:
+ lock = asyncio.Lock()
+ self._waiting_responses_locks[qctx.qname] = lock
+ await lock.acquire()
+
+ # Release the lock forcefully after 500 ms
+ asyncio.get_event_loop().call_later(
+ 0.5, lambda: lock.release() if lock.locked() else None
+ )
+
+ # Wait until the second query for the same query arrives.
+ async with lock:
+ yield DnsResponseSend(qctx.response)
+ del self._waiting_responses_locks[qctx.qname]
+
+
+class FallbackNxdomainHandler(ReclimitHandler, StaticResponseHandler):
+ rcode = dns.rcode.NXDOMAIN