]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Refactor AxfrHandler and hoist it to isctest.asyncserver
authorŠtěpán Balážik <stepan@isc.org>
Mon, 19 Jan 2026 21:04:32 +0000 (22:04 +0100)
committerŠtěpán Balážik <stepan@isc.org>
Fri, 17 Apr 2026 14:26:07 +0000 (14:26 +0000)
It will be useful in the xfer system test as well.

bin/tests/system/isctest/asyncserver.py
bin/tests/system/ixfr/ans2/ans.py

index 080c08c380684af2e497fcfd679b0261a48bfce6..e29d6f7e9a0e98284d1ab27ecab270aa0ddcebd4 100644 (file)
@@ -11,7 +11,7 @@ See the COPYRIGHT file distributed with this work for additional
 information regarding copyright ownership.
 """
 
-from collections.abc import AsyncGenerator, Callable, Coroutine, Sequence
+from collections.abc import AsyncGenerator, Callable, Collection, Coroutine, Sequence
 from dataclasses import dataclass, field
 from typing import Any, cast
 
@@ -875,6 +875,63 @@ class ForwarderHandler(ResponseHandler):
             yield BytesResponseSend(response.result())
 
 
+class AxfrHandler(ResponseHandler):
+    """
+    Base class for AXFR response handlers.
+
+    Subclasses must define the `initial_soa`, `zone_contents`, and `final_soa`
+    properties to specify the content of the AXFR responses.
+
+    The responses are constructed without any regard to zone data.
+    """
+
+    @property
+    @abc.abstractmethod
+    def initial_soa(self) -> dns.rrset.RRset:
+        """
+        Initial SOA record of response packets sent in response to
+        AXFR queries.
+        """
+        raise NotImplementedError
+
+    @property
+    @abc.abstractmethod
+    def zone_contents(self) -> Collection[dns.rrset.RRset]:
+        """
+        Answer section of the second response packet sent in response to
+        AXFR queries.
+        """
+        raise NotImplementedError
+
+    @property
+    @abc.abstractmethod
+    def final_soa(self) -> dns.rrset.RRset:
+        """
+        Final SOA record of response packets sent in response to
+        AXFR queries.
+        """
+        raise NotImplementedError
+
+    def match(self, qctx: QueryContext) -> bool:
+        return qctx.qtype == dns.rdatatype.AXFR
+
+    async def get_responses(
+        self, qctx: QueryContext
+    ) -> AsyncGenerator[DnsResponseSend, None]:
+        qctx.prepare_new_response(with_zone_data=False)
+        qctx.response.answer.append(self.initial_soa)
+        yield DnsResponseSend(qctx.response)
+
+        qctx.prepare_new_response(with_zone_data=False)
+        for rrset_ in self.zone_contents:
+            qctx.response.answer.append(rrset_)
+        yield DnsResponseSend(qctx.response)
+
+        qctx.prepare_new_response(with_zone_data=False)
+        qctx.response.answer.append(self.final_soa)
+        yield DnsResponseSend(qctx.response)
+
+
 @dataclass
 class _ZoneTreeNode:
     """
index b6a052646901dc2284fee706704a3418f274782d..87f7dbef48375a99930cdb423f380efaab72a5bb 100644 (file)
@@ -11,7 +11,7 @@ See the COPYRIGHT file distributed with this work for additional
 information regarding copyright ownership.
 """
 
-from collections.abc import AsyncGenerator, Collection, Iterable
+from collections.abc import AsyncGenerator, Collection
 
 import abc
 
@@ -21,6 +21,7 @@ import dns.rdatatype
 import dns.rrset
 
 from isctest.asyncserver import (
+    AxfrHandler,
     ControllableAsyncDnsServer,
     DnsResponseSend,
     QueryContext,
@@ -85,29 +86,6 @@ class SoaHandler(ResponseHandler):
         yield DnsResponseSend(qctx.response)
 
 
-class AxfrHandler(ResponseHandler):
-    @property
-    @abc.abstractmethod
-    def answers(self) -> Iterable[Collection[dns.rrset.RRset]]:
-        """
-        Answer sections of response packets sent in response to
-        AXFR queries.
-        """
-        raise NotImplementedError
-
-    def match(self, qctx: QueryContext) -> bool:
-        return qctx.qtype == dns.rdatatype.AXFR
-
-    async def get_responses(
-        self, qctx: QueryContext
-    ) -> AsyncGenerator[DnsResponseSend, None]:
-        for answer in self.answers:
-            response = qctx.prepare_new_response()
-            for rrset_ in answer:
-                response.answer.append(rrset_)
-            yield DnsResponseSend(response)
-
-
 class IxfrHandler(ResponseHandler):
     @property
     @abc.abstractmethod
@@ -130,16 +108,14 @@ class IxfrHandler(ResponseHandler):
 
 
 class InitialAfxrHandler(AxfrHandler):
-    answers = (
-        (soa(1),),
-        (
-            ns(),
-            txt("initial AXFR"),
-            a("10.0.0.61", owner="a.nil."),
-            a("10.0.0.62", owner="b.nil."),
-        ),
-        (soa(1),),
+    initial_soa = soa(1)
+    zone_contents = (
+        ns(),
+        txt("initial AXFR"),
+        a("10.0.0.61", owner="a.nil."),
+        a("10.0.0.62", owner="b.nil."),
     )
+    final_soa = soa(1)
 
 
 class SuccessfulIfxrHandler(IxfrHandler):
@@ -169,14 +145,12 @@ class NotExactIxfrHandler(IxfrHandler):
 
 
 class FallbackNotExactAxfrHandler(AxfrHandler):
-    answers = (
-        (soa(3),),
-        (
-            ns(),
-            txt("fallback AXFR"),
-        ),
-        (soa(3),),
+    initial_soa = soa(3)
+    zone_contents = (
+        ns(),
+        txt("fallback AXFR"),
     )
+    final_soa = soa(3)
 
 
 class TooManyRecordsIxfrHandler(IxfrHandler):
@@ -195,14 +169,12 @@ class TooManyRecordsIxfrHandler(IxfrHandler):
 
 
 class FallbackTooManyRecordsAxfrHandler(AxfrHandler):
-    answers = (
-        (soa(3),),
-        (
-            ns(),
-            txt("fallback AXFR on too many records"),
-        ),
-        (soa(3),),
+    initial_soa = soa(3)
+    zone_contents = (
+        ns(),
+        txt("fallback AXFR on too many records"),
     )
+    final_soa = soa(3)
 
 
 class BadSoaOwnerIxfrHandler(IxfrHandler):
@@ -216,14 +188,12 @@ class BadSoaOwnerIxfrHandler(IxfrHandler):
 
 
 class FallbackBadSoaOwnerAxfrHandler(AxfrHandler):
-    answers = (
-        (soa(4),),
-        (
-            ns(),
-            txt("serial 4, fallback AXFR", owner="test.nil."),
-        ),
-        (soa(4),),
+    initial_soa = soa(4)
+    zone_contents = (
+        ns(),
+        txt("serial 4, fallback AXFR", owner="test.nil."),
     )
+    final_soa = soa(4)
 
 
 def main() -> None: