]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Responses.py
Merge pull request #13860 from Habbie/auth-lua-dblookup-qtype
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Responses.py
index fb9276493fd647e014ae787771d059f48a05d127..af16a644bf3bc4527632eb696dd0be4a18d63be0 100644 (file)
@@ -29,8 +29,8 @@ class TestResponseRuleNXDelayed(DNSDistTest):
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         end = datetime.now()
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(response, receivedResponse)
         self.assertTrue((end - begin) > timedelta(0, 1))
 
         # NoError over UDP
@@ -39,8 +39,8 @@ class TestResponseRuleNXDelayed(DNSDistTest):
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         end = datetime.now()
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(response, receivedResponse)
         self.assertTrue((end - begin) < timedelta(0, 1))
 
         # NX over TCP
@@ -49,12 +49,13 @@ class TestResponseRuleNXDelayed(DNSDistTest):
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
         end = datetime.now()
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(response, receivedResponse)
         self.assertTrue((end - begin) < timedelta(0, 1))
 
 class TestResponseRuleERCode(DNSDistTest):
 
+    _extraStartupSleep = 1
     _config_template = """
     newServer{address="127.0.0.1:%s"}
     addResponseAction(ERCodeRule(DNSRCode.BADVERS), DelayResponseAction(1000))
@@ -80,8 +81,8 @@ class TestResponseRuleERCode(DNSDistTest):
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         end = datetime.now()
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(response, receivedResponse)
         self.assertTrue((end - begin) > timedelta(0, 1))
 
         # BADKEY (17, an ERCode) over UDP
@@ -90,8 +91,8 @@ class TestResponseRuleERCode(DNSDistTest):
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         end = datetime.now()
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(response, receivedResponse)
         self.assertTrue((end - begin) < timedelta(0, 1))
 
         # NoError (non-ERcode, basic RCode bits match BADVERS) over UDP
@@ -100,8 +101,8 @@ class TestResponseRuleERCode(DNSDistTest):
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         end = datetime.now()
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(response, receivedResponse)
         self.assertTrue((end - begin) < timedelta(0, 1))
 
 class TestResponseRuleQNameDropped(DNSDistTest):
@@ -126,8 +127,8 @@ class TestResponseRuleQNameDropped(DNSDistTest):
             sender = getattr(self, method)
             (receivedQuery, receivedResponse) = sender(query, response)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, None)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(receivedResponse, None)
 
     def testNotDropped(self):
         """
@@ -144,8 +145,8 @@ class TestResponseRuleQNameDropped(DNSDistTest):
             sender = getattr(self, method)
             (receivedQuery, receivedResponse) = sender(query, response)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
 
 class TestResponseRuleQNameAllowed(DNSDistTest):
 
@@ -170,8 +171,8 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
             sender = getattr(self, method)
             (receivedQuery, receivedResponse) = sender(query, response)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
 
     def testNotAllowed(self):
         """
@@ -188,8 +189,8 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
             sender = getattr(self, method)
             (receivedQuery, receivedResponse) = sender(query, response)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, None)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(receivedResponse, None)
 
 class TestResponseRuleEditTTL(DNSDistTest):
 
@@ -228,10 +229,155 @@ class TestResponseRuleEditTTL(DNSDistTest):
             sender = getattr(self, method)
             (receivedQuery, receivedResponse) = sender(query, response)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
-            self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
-            self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+            self.assertNotEqual(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEqual(receivedResponse.answer[0].ttl, self._ttl)
+
+class TestResponseRuleLimitTTL(DNSDistTest):
+
+    _lowttl = 60
+    _defaulttl = 3600
+    _highttl = 18000
+    _config_params = ['_lowttl', '_highttl', '_testServerPort']
+    _config_template = """
+    local ffi = require("ffi")
+    local lowttl = %d
+    local highttl = %d
+
+    function luaFFISetMinTTL(dr)
+      ffi.C.dnsdist_ffi_dnsresponse_set_min_ttl(dr, highttl)
+      return DNSResponseAction.None, ""
+    end
+    function luaFFISetMaxTTL(dr)
+      ffi.C.dnsdist_ffi_dnsresponse_set_max_ttl(dr, lowttl)
+      return DNSResponseAction.None, ""
+    end
+
+    newServer{address="127.0.0.1:%s"}
+
+    addResponseAction("min.responses.tests.powerdns.com.", SetMinTTLResponseAction(highttl))
+    addResponseAction("max.responses.tests.powerdns.com.", SetMaxTTLResponseAction(lowttl))
+    addResponseAction("ffi.min.limitttl.responses.tests.powerdns.com.", LuaFFIResponseAction(luaFFISetMinTTL))
+    addResponseAction("ffi.max.limitttl.responses.tests.powerdns.com.", LuaFFIResponseAction(luaFFISetMaxTTL))
+    """
+
+    def testLimitTTL(self):
+        """
+        Responses: Alter the TTLs via Limiter
+        """
+        name = 'min.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+            self.assertNotEqual(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEqual(receivedResponse.answer[0].ttl, self._highttl)
+
+        name = 'max.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+            self.assertNotEqual(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEqual(receivedResponse.answer[0].ttl, self._lowttl)
+
+    def testLimitTTLFFI(self):
+        """
+        Responses: Alter the TTLs via Limiter
+        """
+        name = 'ffi.min.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+            self.assertNotEqual(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEqual(receivedResponse.answer[0].ttl, self._highttl)
+
+        name = 'ffi.max.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+            self.assertNotEqual(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEqual(receivedResponse.answer[0].ttl, self._lowttl)
+
+class TestSetReducedTTL(DNSDistTest):
+
+    _percentage = 42
+    _initialTTL = 100
+    _config_params = ['_percentage', '_testServerPort']
+    _config_template = """
+    addResponseAction(AllRule(), SetReducedTTLResponseAction(%d))
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testLimitTTL(self):
+        """
+        Responses: Reduce TTL to 42%
+        """
+        name = 'reduced-ttl.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    self._initialTTL,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+            self.assertNotEqual(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+            self.assertEqual(receivedResponse.answer[0].ttl, self._percentage)
 
 class TestResponseLuaActionReturnSyntax(DNSDistTest):
 
@@ -265,8 +411,8 @@ class TestResponseLuaActionReturnSyntax(DNSDistTest):
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
         end = datetime.now()
         receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(response, receivedResponse)
         self.assertTrue((end - begin) > timedelta(0, 1))
 
     def testDropped(self):
@@ -284,5 +430,78 @@ class TestResponseLuaActionReturnSyntax(DNSDistTest):
             sender = getattr(self, method)
             (receivedQuery, receivedResponse) = sender(query, response)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(receivedResponse, None)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(receivedResponse, None)
+
+class TestResponseClearRecordsType(DNSDistTest):
+
+    _config_params = ['_testServerPort']
+    _config_template = """
+    local ffi = require("ffi")
+
+    function luafct(dr)
+      ffi.C.dnsdist_ffi_dnsresponse_clear_records_type(dr, DNSQType.AAAA)
+      return DNSResponseAction.HeaderModify, ""
+    end
+
+    newServer{address="127.0.0.1:%s"}
+
+    addResponseAction("ffi.clear-records-type.responses.tests.powerdns.com.", LuaFFIResponseAction(luafct))
+    addResponseAction("clear-records-type.responses.tests.powerdns.com.", ClearRecordTypesResponseAction(DNSQType.AAAA))
+    """
+
+    def testClearedFFI(self):
+        """
+        Responses: Removes records of a given type (FFI API)
+        """
+        name = 'ffi.clear-records-type.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+        rrset = dns.rrset.from_text(name,
+                                    3660,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1', '2001:DB8::2')
+        response.answer.append(rrset)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+    def testCleared(self):
+        """
+        Responses: Removes records of a given type
+        """
+        name = 'clear-records-type.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+        rrset = dns.rrset.from_text(name,
+                                    3660,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1', '2001:DB8::2')
+        response.answer.append(rrset)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(expectedResponse, receivedResponse)