]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Add common parts of reclimit test custom servers
authorŠtěpán Balážik <stepan@isc.org>
Fri, 16 Jan 2026 09:14:04 +0000 (10:14 +0100)
committerŠtěpán Balážik <stepan@isc.org>
Thu, 9 Apr 2026 00:28:13 +0000 (02:28 +0200)
These will be shared by all the ans*/ans.py files.

bin/tests/system/reclimit/reclimit_ans.py [new file with mode: 0644]

diff --git a/bin/tests/system/reclimit/reclimit_ans.py b/bin/tests/system/reclimit/reclimit_ans.py
new file mode 100644 (file)
index 0000000..b08d5f1
--- /dev/null
@@ -0,0 +1,271 @@
+"""
+Copyright (C) Internet Systems Consortium, Inc. ("ISC")
+
+SPDX-License-Identifier: MPL-2.0
+
+This Source Code Form is subject to the terms of the Mozilla Public
+License, v. 2.0.  If a copy of the MPL was not distributed with this
+file, you can obtain one at https://mozilla.org/MPL/2.0/.
+
+See the COPYRIGHT file distributed with this work for additional
+information regarding copyright ownership.
+"""
+
+from collections.abc import AsyncGenerator
+
+import asyncio
+import functools
+
+import dns.flags
+import dns.name
+import dns.rcode
+import dns.rdataclass
+import dns.rdatatype
+import dns.rrset
+
+from isctest.asyncserver import (
+    ControlCommand,
+    ControllableAsyncDnsServer,
+    DnsResponseSend,
+    QnameHandler,
+    QueryContext,
+    ResponseAction,
+    ResponseHandler,
+    StaticResponseHandler,
+)
+
+
+class ReclimitStateHandler(QnameHandler):
+    """
+    Handler for shared state of all the handlers in one server
+    """
+
+    qnames = ["count.", "reset."]
+
+    def __init__(self, indirect_send_response_default: bool = True) -> None:
+        self._count = 0
+        self._count_lock = asyncio.Lock()
+        self._indirect_send_response_default = indirect_send_response_default
+        self._indirect_send_response = indirect_send_response_default
+        self._indirect_send_response_lock = asyncio.Lock()
+        self._limit = 0
+        super().__init__()
+
+    async def get_responses(
+        self, qctx: QueryContext
+    ) -> AsyncGenerator[DnsResponseSend, None]:
+        if f"{qctx.qname}" == "count.":
+            await self.increment_count()
+            qctx.response.answer.append(await self._count_txt_rrset())
+            yield DnsResponseSend(qctx.response, authoritative=True)
+        elif f"{qctx.qname}" == "reset.":
+            await self.set_indirect_send_response(self._indirect_send_response_default)
+            await self.reset_count()
+            yield DnsResponseSend(qctx.response, authoritative=False)
+
+    async def reset_count(self) -> None:
+        async with self._count_lock:
+            self._count = 0
+
+    async def increment_count(self) -> None:
+        async with self._count_lock:
+            self._count += 1
+
+    async def _count_txt_rrset(self) -> dns.rrset.RRset:
+        async with self._count_lock:
+            count = self._count
+        return dns.rrset.from_text(
+            "count.", 0, dns.rdataclass.IN, dns.rdatatype.TXT, f"{count}"
+        )
+
+    async def get_indirect_send_response(self) -> bool:
+        async with self._indirect_send_response_lock:
+            return self._indirect_send_response
+
+    async def set_indirect_send_response(self, value: bool) -> None:
+        async with self._indirect_send_response_lock:
+            self._indirect_send_response = value
+
+    def set_limit(self, limit: int) -> None:
+        self._limit = limit
+
+    def get_limit(self) -> int:
+        return self._limit
+
+
+class ReclimitHandler(ResponseHandler):
+    """
+    Base class for handlers in this test
+
+    Provides access to the shared state through the state handler and automatically increments the count on each query.
+    """
+
+    _COUNTED_HANDLER_WRAPPED_ATTR = "__counted_handler_wrapped__"
+
+    def __init__(self, state_handler: ReclimitStateHandler) -> None:
+        self._state_handler = state_handler
+        super().__init__()
+
+    def __init_subclass__(cls) -> None:
+        """
+        Wrap the get_responses method of all subclasses to increment the count
+        in the state handler whenever they are called.
+        """
+
+        super().__init_subclass__()
+        original = cls.get_responses
+        if original is ResponseHandler.get_responses:
+            return
+        if getattr(original, ReclimitHandler._COUNTED_HANDLER_WRAPPED_ATTR, False):
+            return
+
+        @functools.wraps(original)
+        async def wrapped_get_responses(
+            self: "ReclimitHandler", qctx: QueryContext
+        ) -> AsyncGenerator[DnsResponseSend, None]:
+            await self._state_handler.increment_count()  # pylint: disable=protected-access
+            async for response in original(self, qctx):
+                yield response
+
+        setattr(
+            wrapped_get_responses, ReclimitHandler._COUNTED_HANDLER_WRAPPED_ATTR, True
+        )
+        cls.get_responses = wrapped_get_responses
+
+    def _get_limit(self) -> int:
+        return self._state_handler.get_limit()
+
+    async def _get_indirect_send_response(self) -> bool:
+        return await self._state_handler.get_indirect_send_response()
+
+    async def _set_indirect_send_response(self, value: bool) -> None:
+        await self._state_handler.set_indirect_send_response(value)
+
+
+class LimitControlCommand(ControlCommand):
+    control_subdomain = "limit"
+
+    def __init__(self, state_handler: ReclimitStateHandler) -> None:
+        self._state_handler = state_handler
+        super().__init__()
+
+    def handle(
+        self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
+    ) -> str | None:
+        if len(args) != 1:
+            return "Expected exactly one label"
+
+        try:
+            limit = int(args[0])
+        except ValueError:
+            return "Expected an integer"
+
+        self._state_handler.set_limit(limit)
+        return f"Limit set to {limit}"
+
+
+def a(owner: str | dns.name.Name, ns_number: int) -> dns.rrset.RRset:
+    return dns.rrset.from_text(
+        f"{owner}", 3600, dns.rdataclass.IN, dns.rdatatype.A, f"10.53.0.{ns_number}"
+    )
+
+
+def ns(owner: str | dns.name.Name, target: str | dns.name.Name) -> dns.rrset.RRset:
+    return dns.rrset.from_text(
+        f"{owner}", 86400, dns.rdataclass.IN, dns.rdatatype.NS, f"{target}"
+    )
+
+
+class DirectExampleHandler(ReclimitHandler, QnameHandler):
+    qnames = ["direct.example.org", "direct.example.net"]
+
+    async def get_responses(
+        self, qctx: QueryContext
+    ) -> AsyncGenerator[DnsResponseSend, None]:
+        if qctx.qtype == dns.rdatatype.A:
+            qctx.response.answer.append(a(qctx.qname, 4))
+        yield DnsResponseSend(qctx.response)
+
+
+class IndirectExampleOrgHandler(ReclimitHandler, QnameHandler):
+    qnames = [f"indirect{i}.example.org" for i in range(1, 9)]
+
+    async def get_responses(
+        self, qctx: QueryContext
+    ) -> AsyncGenerator[DnsResponseSend, None]:
+        if not await self._get_indirect_send_response():
+            qctx.response.authority.append(ns(f"{qctx.qname}", "ns1.1.example.org."))
+            qctx.response.flags &= ~dns.flags.AA
+        elif qctx.qtype == dns.rdatatype.A:
+            qctx.response.answer.append(a(qctx.qname, 4))
+        yield DnsResponseSend(qctx.response)
+
+
+def is_ns1_example(qname: dns.name.Name, tld: str) -> bool:
+    labels = qname.labels
+    return (
+        len(labels) == 5
+        and labels[3] == tld.encode()
+        and labels[2] == b"example"
+        and labels[1].isdigit()
+        and labels[0] == b"ns1"
+    )
+
+
+class Ns1ExampleOrgHandler(ReclimitHandler):
+    def __init__(self, state_handler: ReclimitStateHandler) -> None:
+        self._waiting_responses_locks: dict[dns.name.Name, asyncio.Lock] = {}
+        super().__init__(state_handler)
+
+    def match(self, qctx: QueryContext) -> bool:
+        return is_ns1_example(qctx.qname, "org") and qctx.qtype in (
+            dns.rdatatype.A,
+            dns.rdatatype.AAAA,
+        )
+
+    async def get_responses(
+        self, qctx: QueryContext
+    ) -> AsyncGenerator[ResponseAction, None]:
+        ns_number = int(qctx.qname.labels[1])
+        next_ns_number = ns_number + 1
+        if not self._get_limit() or (
+            not await self._get_indirect_send_response()
+            and next_ns_number <= self._get_limit()
+        ):
+            qctx.response.authority.append(
+                ns(f"{ns_number}.example.org.", f"ns1.{next_ns_number}.example.org.")
+            )
+            qctx.response.flags &= ~dns.flags.AA
+        else:
+            await self._set_indirect_send_response(True)
+            if qctx.qtype == dns.rdatatype.A:
+                qctx.response.answer.append(a(qctx.qname, 4))
+
+            # XXX: The original Perl implementation doesn't set AA on empty responses,
+            #      let's do the same for packet-for-packet compatibility and drop it later.
+            elif qctx.qtype == dns.rdatatype.AAAA:
+                qctx.response.flags &= ~dns.flags.AA
+
+        if qctx.qname in self._waiting_responses_locks:
+            # Second query arrived, release the first response.
+            self._waiting_responses_locks[qctx.qname].release()
+            await asyncio.sleep(0)  # Yield to allow the first response to be sent.
+            yield DnsResponseSend(qctx.response)
+        else:
+            lock = asyncio.Lock()
+            self._waiting_responses_locks[qctx.qname] = lock
+            await lock.acquire()
+
+            # Release the lock forcefully after 500 ms
+            asyncio.get_event_loop().call_later(
+                0.5, lambda: lock.release() if lock.locked() else None
+            )
+
+            # Wait until the second query for the same query arrives.
+            async with lock:
+                yield DnsResponseSend(qctx.response)
+            del self._waiting_responses_locks[qctx.qname]
+
+
+class FallbackNxdomainHandler(ReclimitHandler, StaticResponseHandler):
+    rcode = dns.rcode.NXDOMAIN