]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Responses.py
Merge pull request #8945 from rgacogne/ddist-x-forwarded-for
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Responses.py
index a1f9072929c01b92beb2b17ef8317f88e963c224..fb9276493fd647e014ae787771d059f48a05d127 100644 (file)
@@ -8,7 +8,7 @@ class TestResponseRuleNXDelayed(DNSDistTest):
 
     _config_template = """
     newServer{address="127.0.0.1:%s"}
-    addResponseAction(RCodeRule(dnsdist.NXDOMAIN), DelayResponseAction(1000))
+    addResponseAction(RCodeRule(DNSRCode.NXDOMAIN), DelayResponseAction(1000))
     """
 
     def testNXDelayed(self):
@@ -53,6 +53,57 @@ class TestResponseRuleNXDelayed(DNSDistTest):
         self.assertEquals(response, receivedResponse)
         self.assertTrue((end - begin) < timedelta(0, 1))
 
+class TestResponseRuleERCode(DNSDistTest):
+
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+    addResponseAction(ERCodeRule(DNSRCode.BADVERS), DelayResponseAction(1000))
+    """
+
+    def testBADVERSDelayed(self):
+        """
+        Responses: Delayed on BADVERS
+
+        Send an A query to "delayed.responses.tests.powerdns.com.",
+        check that the response delay is longer than 1000 ms
+        for a BADVERS response over UDP, shorter for BADKEY and NoError.
+        """
+        name = 'delayed.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.use_edns(edns=True)
+
+        # BADVERS over UDP
+        # BADVERS == 16, so rcode==0, ercode==1
+        response.set_rcode(dns.rcode.BADVERS)
+        begin = datetime.now()
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        end = datetime.now()
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        self.assertTrue((end - begin) > timedelta(0, 1))
+
+        # BADKEY (17, an ERCode) over UDP
+        response.set_rcode(17)
+        begin = datetime.now()
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        end = datetime.now()
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        self.assertTrue((end - begin) < timedelta(0, 1))
+
+        # NoError (non-ERcode, basic RCode bits match BADVERS) over UDP
+        response.set_rcode(dns.rcode.NOERROR)
+        begin = datetime.now()
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        end = datetime.now()
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        self.assertTrue((end - begin) < timedelta(0, 1))
+
 class TestResponseRuleQNameDropped(DNSDistTest):
 
     _config_template = """
@@ -71,15 +122,12 @@ class TestResponseRuleQNameDropped(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
     def testNotDropped(self):
         """
@@ -92,15 +140,12 @@ class TestResponseRuleQNameDropped(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
 class TestResponseRuleQNameAllowed(DNSDistTest):
 
@@ -121,15 +166,12 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
 
     def testNotAllowed(self):
         """
@@ -142,15 +184,12 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)
 
 class TestResponseRuleEditTTL(DNSDistTest):
 
@@ -185,16 +224,65 @@ class TestResponseRuleEditTTL(DNSDistTest):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
+            self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
+
+class TestResponseLuaActionReturnSyntax(DNSDistTest):
+
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+    function customDelay(dr)
+      return DNSResponseAction.Delay, "1000"
+    end
+    function customDrop(dr)
+      return DNSResponseAction.Drop
+    end
+    addResponseAction("drop.responses.tests.powerdns.com.", LuaResponseAction(customDrop))
+    addResponseAction(RCodeRule(DNSRCode.NXDOMAIN), LuaResponseAction(customDelay))
+    """
+
+    def testResponseActionDelayed(self):
+        """
+        Responses: Delayed via LuaResponseAction
+
+        Send an A query to "delayed.responses.tests.powerdns.com.",
+        check that the response delay is longer than 1000 ms
+        for a NXDomain response over UDP, shorter for a NoError one.
+        """
+        name = 'delayed.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        # NX over UDP
+        response.set_rcode(dns.rcode.NXDOMAIN)
+        begin = datetime.now()
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        end = datetime.now()
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
         self.assertEquals(response, receivedResponse)
-        self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
-        self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
+        self.assertTrue((end - begin) > timedelta(0, 1))
 
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-        self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
-        self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
+    def testDropped(self):
+        """
+        Responses: Dropped via user defined LuaResponseAction
+
+        Send an A query to "drop.responses.tests.powerdns.com.",
+        check that the response (not the query) is dropped.
+        """
+        name = 'drop.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(receivedResponse, None)