]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Allow ResponseHandlers to roll back changes made to a response
authorŠtěpán Balážik <stepan@isc.org>
Wed, 12 Nov 2025 15:19:25 +0000 (16:19 +0100)
committerŠtěpán Balážik <stepan@isc.org>
Thu, 18 Dec 2025 12:13:59 +0000 (13:13 +0100)
Previously, this was only possible by making a new response by calling
make_response on qctx.query. This however ignored the `default_aa` and
`default_rcode` parameters of AsyncDnsServer.

Add prepare_new_response and save_initialized_response methods to
QueryContext.

bin/tests/system/isctest/asyncserver.py

index 8300ff2b014ee523aee855329fef01f08a012cac..98fec6b6634e1330e97e1ac766cf54620af6c711 100644 (file)
@@ -28,6 +28,7 @@ from typing import (
 import abc
 import asyncio
 import contextlib
+import copy
 import enum
 import functools
 import logging
@@ -269,11 +270,17 @@ class QueryContext:
     response: dns.message.Message
     peer: Peer
     protocol: DnsProtocol
-    zone: Optional[dns.zone.Zone] = None
-    soa: Optional[dns.rrset.RRset] = None
-    node: Optional[dns.node.Node] = None
-    answer: Optional[dns.rdataset.Rdataset] = None
-    alias: Optional[dns.name.Name] = None
+    zone: Optional[dns.zone.Zone] = field(default=None, init=False)
+    soa: Optional[dns.rrset.RRset] = field(default=None, init=False)
+    node: Optional[dns.node.Node] = field(default=None, init=False)
+    answer: Optional[dns.rdataset.Rdataset] = field(default=None, init=False)
+    alias: Optional[dns.name.Name] = field(default=None, init=False)
+    _initialized_response: Optional[dns.message.Message] = field(
+        default=None, init=False
+    )
+    _initialized_response_with_zone_data: Optional[dns.message.Message] = field(
+        default=None, init=False
+    )
 
     @property
     def qname(self) -> dns.name.Name:
@@ -291,6 +298,23 @@ class QueryContext:
     def qtype(self) -> dns.rdatatype.RdataType:
         return self.query.question[0].rdtype
 
+    def prepare_new_response(
+        self, /, with_zone_data: bool = True
+    ) -> dns.message.Message:
+        if with_zone_data:
+            assert self._initialized_response_with_zone_data
+            self.response = copy.deepcopy(self._initialized_response_with_zone_data)
+        else:
+            assert self._initialized_response
+            self.response = copy.deepcopy(self._initialized_response)
+        return self.response
+
+    def save_initialized_response(self, /, with_zone_data: bool) -> None:
+        if with_zone_data:
+            self._initialized_response_with_zone_data = copy.deepcopy(self.response)
+        else:
+            self._initialized_response = copy.deepcopy(self.response)
+
 
 @dataclass
 class ResponseAction(abc.ABC):
@@ -1116,8 +1140,10 @@ class AsyncDnsServer(AsyncServer):
         qctx.response.set_rcode(self._default_rcode)
         if self._default_aa:
             qctx.response.flags |= dns.flags.AA
+        qctx.save_initialized_response(with_zone_data=False)
 
         self._prepare_response_from_zone_data(qctx)
+        qctx.save_initialized_response(with_zone_data=True)
 
         response_handled = False
         async for action in self._run_response_handlers(qctx):