]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add test for setting trailing data
authorRichard Gibson <richard.gibson@gmail.com>
Thu, 30 Aug 2018 20:37:20 +0000 (16:37 -0400)
committerRichard Gibson <richard.gibson@gmail.com>
Tue, 16 Oct 2018 21:45:34 +0000 (17:45 -0400)
regression-tests.dnsdist/test_Trailing.py

index 0a26d5f5dacfd6e1b27f6fa16365ce2e505d3693..33d23b1e9e7372da260ebc4c9663732372599dc6 100644 (file)
@@ -12,18 +12,64 @@ class TestTrailingDataToBackend(DNSDistTest):
     _testServerPort = 5360
     _config_template = """
     newServer{address="127.0.0.1:%s"}
+
+    function replaceTrailingData(dq)
+        local success = dq:setTrailingData({65, 66, 67}) -- "ABC"
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addLuaAction("added.trailing.tests.powerdns.com.", replaceTrailingData)
+
+    function fillBuffer(dq)
+        local available = dq.size - dq.len
+        local tail = extendTableBy({}, available)
+        local success = dq:setTrailingData(tail)
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addLuaAction("max.trailing.tests.powerdns.com.", fillBuffer)
+
+    function exceedBuffer(dq)
+        local available = dq.size - dq.len
+        local tail = extendTableBy({}, available + 1)
+        local success = dq:setTrailingData(tail)
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addLuaAction("limited.trailing.tests.powerdns.com.", exceedBuffer)
+
+    function extendTableBy(t, n)
+        if n <= 1 then
+            if n == 1 then
+                t[#t + 1] = 0
+            end
+            return t
+        end
+
+        local lower = math.floor(n / 2)
+        local upper = n - lower
+        t = extendTableBy(t, lower)
+        t = extendTableBy(t, upper)
+        return t
+    end
     """
     @classmethod
     def startResponders(cls):
         print("Launching responders..")
 
-        # Respond SERVFAIL to queries with trailing data.
-        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.SERVFAIL])
+        # Respond REFUSED to queries with trailing data.
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.REFUSED])
         cls._UDPResponder.setDaemon(True)
         cls._UDPResponder.start()
 
-        # Respond SERVFAIL to queries with trailing data.
-        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.SERVFAIL])
+        # Respond REFUSED to queries with trailing data.
+        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.REFUSED])
         cls._TCPResponder.setDaemon(True)
         cls._TCPResponder.start()
 
@@ -42,7 +88,7 @@ class TestTrailingDataToBackend(DNSDistTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
         expectedResponse = dns.message.make_response(query)
-        expectedResponse.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
 
         raw = query.to_wire()
         raw = raw + b'A'* 20
@@ -58,16 +104,118 @@ class TestTrailingDataToBackend(DNSDistTest):
             self.assertEquals(receivedQuery, query)
             self.assertEquals(receivedResponse, expectedResponse)
 
+    def testTrailingCapacity(self):
+        """
+        Trailing data: Fill buffer
+
+        """
+        name = 'max.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingLimited(self):
+        """
+        Trailing data: Reject buffer overflows
+
+        """
+        name = 'limited.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.SERVFAIL)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+            (_, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingAdded(self):
+        """
+        Trailing data: Add
+
+        """
+        name = 'added.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, expectedResponse)
+
 class TestTrailingDataToDnsdist(DNSDistTest):
     _config_template = """
     newServer{address="127.0.0.1:%s"}
+
+    addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())
+
+    function removeTrailingData(dq)
+        local success = dq:setTrailingData({})
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addLuaAction("removed.trailing.tests.powerdns.com.", removeTrailingData)
+
     function reportTrailingData(dq)
         local tailBytes = dq:getTrailingData()
         local tailChars = string.char(unpack(tailBytes))
-        return DNSAction.Spoof, tailChars .. ".echoed.trailing.tests.powerdns.com."
+        return DNSAction.Spoof, "-" .. tailChars .. ".echoed.trailing.tests.powerdns.com."
     end
-    addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())
     addLuaAction("echoed.trailing.tests.powerdns.com.", reportTrailingData)
+
+    function replaceTrailingData(dq)
+        local success = dq:setTrailingData({65, 66, 67}) -- "ABC"
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addLuaAction("replaced.trailing.tests.powerdns.com.", replaceTrailingData)
+    addLuaAction("replaced.trailing.tests.powerdns.com.", reportTrailingData)
     """
 
     def testTrailingDropped(self):
@@ -107,6 +255,35 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertEquals(receivedResponse, None)
 
+    def testTrailingRemoved(self):
+        """
+        Trailing data: Remove
+
+        """
+        name = 'removed.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'A'* 20
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
+            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
+            (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
+
     def testTrailingRead(self):
         """
         Trailing data: Count
@@ -121,7 +298,36 @@ class TestTrailingDataToDnsdist(DNSDistTest):
                                     60,
                                     dns.rdataclass.IN,
                                     dns.rdatatype.CNAME,
-                                    'TrailingData.echoed.trailing.tests.powerdns.com.')
+                                    '-TrailingData.echoed.trailing.tests.powerdns.com.')
+        expectedResponse.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'TrailingData'
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
+            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedResponse)
+            expectedResponse.flags = receivedResponse.flags
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingReplaced(self):
+        """
+        Trailing data: Replace
+
+        """
+        name = 'replaced.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    '-ABC.echoed.trailing.tests.powerdns.com.')
         expectedResponse.answer.append(rrset)
 
         raw = query.to_wire()