]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add a regression test rewriting the response code via Lua 14946/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Dec 2024 13:39:03 +0000 (14:39 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Dec 2024 13:39:03 +0000 (14:39 +0100)
regression-tests.dnsdist/test_Responses.py

index af16a644bf3bc4527632eb696dd0be4a18d63be0..26c1042c0a03111d20f85b9de570576178f6fd67 100644 (file)
@@ -505,3 +505,64 @@ class TestResponseClearRecordsType(DNSDistTest):
             receivedQuery.id = query.id
             self.assertEqual(query, receivedQuery)
             self.assertEqual(expectedResponse, receivedResponse)
+
+class TestResponseRewriteServFail(DNSDistTest):
+
+    _config_params = ['_testServerPort']
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+
+    function rewriteServFail(dq)
+      if dq.rcode == DNSRCode.SERVFAIL then
+         dq.rcode = DNSRCode.NOERROR
+        return DNSResponseAction.HeaderModify
+      end
+      return DNSResponse.None
+    end
+    addResponseAction(AndRule({QTypeRule(DNSQType.AAAA),RCodeRule(DNSRCode.SERVFAIL)}), LuaResponseAction(rewriteServFail))
+    """
+
+    def testRewriteServFail(self):
+        """
+        Responses: Rewrite AAAA ServFails as NoError (don't ask)
+        """
+        name = 'rewrite-servfail.responses.tests.powerdns.com.'
+
+        query = dns.message.make_query(name, 'AAAA', 'IN')
+        response = dns.message.make_response(query)
+        expectedResponse = dns.message.make_response(query)
+
+        response.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse.set_rcode(dns.rcode.NOERROR)
+
+        rrset = dns.rrset.from_text(name,
+                                    3660,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1', '2001:DB8::2')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+        # but ServFail for a different type should stay the same
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.SERVFAIL)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)