]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Trailing.py
dnsdist: Add documentation and a regression test for FFI functions
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Trailing.py
index f00d47bf8f23045370a81bf2f2d82bcb203fb472..b87b519f7ad7b9fe11f42ee82ec84383ab497cb8 100644 (file)
@@ -3,7 +3,7 @@ import threading
 import dns
 from dnsdisttests import DNSDistTest
 
-class TestTrailing(DNSDistTest):
+class TestTrailingDataToBackend(DNSDistTest):
 
     # this test suite uses a different responder port
     # because, contrary to the other ones, its
@@ -12,26 +12,58 @@ class TestTrailing(DNSDistTest):
     _testServerPort = 5360
     _config_template = """
     newServer{address="127.0.0.1:%s"}
-    addAction(AndRule({QTypeRule(dnsdist.AAAA), TrailingDataRule()}), DropAction())
+
+    function replaceTrailingData(dq)
+        local success = dq:setTrailingData("ABC")
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addAction("added.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
+
+    function fillBuffer(dq)
+        local available = dq.size - dq.len
+        local tail = string.rep("A", available)
+        local success = dq:setTrailingData(tail)
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addAction("max.trailing.tests.powerdns.com.", LuaAction(fillBuffer))
+
+    function exceedBuffer(dq)
+        local available = dq.size - dq.len
+        local tail = string.rep("A", available + 1)
+        local success = dq:setTrailingData(tail)
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addAction("limited.trailing.tests.powerdns.com.", LuaAction(exceedBuffer))
     """
     @classmethod
     def startResponders(cls):
         print("Launching responders..")
 
-        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True])
+        # Respond REFUSED to queries with trailing data.
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.REFUSED])
         cls._UDPResponder.setDaemon(True)
         cls._UDPResponder.start()
 
-        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True])
+        # Respond REFUSED to queries with trailing data.
+        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.REFUSED])
         cls._TCPResponder.setDaemon(True)
         cls._TCPResponder.start()
 
-    def testTrailingAllowed(self):
+    def testTrailingPassthrough(self):
         """
-        Trailing: Allowed
+        Trailing data: Pass through
 
         """
-        name = 'allowed.trailing.tests.powerdns.com.'
+        name = 'passthrough.trailing.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
         rrset = dns.rrset.from_text(name,
@@ -40,35 +72,311 @@ class TestTrailing(DNSDistTest):
                                     dns.rdatatype.A,
                                     '127.0.0.1')
         response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
 
         raw = query.to_wire()
-        raw = raw + 'A'* 20
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        raw = raw + b'A'* 20
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingCapacity(self):
+        """
+        Trailing data: Fill buffer
+
+        """
+        name = 'max.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingLimited(self):
+        """
+        Trailing data: Reject buffer overflows
+
+        """
+        name = 'limited.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.SERVFAIL)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingAdded(self):
+        """
+        Trailing data: Add
+
+        """
+        name = 'added.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, expectedResponse)
+
+class TestTrailingDataToDnsdist(DNSDistTest):
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+
+    addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())
+
+    function removeTrailingData(dq)
+        local success = dq:setTrailingData("")
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addAction("removed.trailing.tests.powerdns.com.", LuaAction(removeTrailingData))
+
+    function reportTrailingData(dq)
+        local tail = dq:getTrailingData()
+        return DNSAction.Spoof, "-" .. tail .. ".echoed.trailing.tests.powerdns.com."
+    end
+    addAction("echoed.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))
+
+    function replaceTrailingData(dq)
+        local success = dq:setTrailingData("ABC")
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
+    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))
+
+    function reportTrailingHex(dq)
+        local tail = dq:getTrailingData()
+        local hex = string.gsub(tail, ".", function(ch)
+            return string.sub(string.format("\\x2502X", string.byte(ch)), -2)
+        end)
+        return DNSAction.Spoof, "-0x" .. hex .. ".echoed-hex.trailing.tests.powerdns.com."
+    end
+    addAction("echoed-hex.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
+
+    function replaceTrailingData_unsafe(dq)
+        local success = dq:setTrailingData("\\xB0\\x00\\xDE\\xADB\\xF0\\x9F\\x91\\xBB\\xC3\\xBE")
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData_unsafe))
+    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
+    """
 
     def testTrailingDropped(self):
         """
-        Trailing: dropped
+        Trailing data: Drop query
 
         """
         name = 'dropped.trailing.tests.powerdns.com.'
-        query = dns.message.make_query(name, 'AAAA', 'IN')
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'A'* 20
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+
+            # Verify that queries with no trailing data make it through.
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
+
+            # Verify that queries with trailing data don't make it through.
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertEquals(receivedResponse, None)
+
+    def testTrailingRemoved(self):
+        """
+        Trailing data: Remove
+
+        """
+        name = 'removed.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'A'* 20
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, response)
+
+    def testTrailingRead(self):
+        """
+        Trailing data: Echo
+
+        """
+        name = 'echoed.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    '-TrailingData.echoed.trailing.tests.powerdns.com.')
+        expectedResponse.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'TrailingData'
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedResponse)
+            expectedResponse.flags = receivedResponse.flags
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingReplaced(self):
+        """
+        Trailing data: Replace
+
+        """
+        name = 'replaced.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    '-ABC.echoed.trailing.tests.powerdns.com.')
+        expectedResponse.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'TrailingData'
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedResponse)
+            expectedResponse.flags = receivedResponse.flags
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingReadUnsafe(self):
+        """
+        Trailing data: Echo as hex
+
+        """
+        name = 'echoed-hex.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    '-0x0000DEAD.echoed-hex.trailing.tests.powerdns.com.')
+        expectedResponse.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'\x00\x00\xDE\xAD'
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedResponse)
+            expectedResponse.flags = receivedResponse.flags
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingReplacedUnsafe(self):
+        """
+        Trailing data: Replace with null and/or non-ASCII bytes
+
+        """
+        name = 'replaced-unsafe.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    '-0xB000DEAD42F09F91BBC3BE.echoed-hex.trailing.tests.powerdns.com.')
+        expectedResponse.answer.append(rrset)
 
         raw = query.to_wire()
-        raw = raw + 'A'* 20
+        raw = raw + b'TrailingData'
 
-        (_, receivedResponse) = self.sendUDPQuery(raw, response=None, rawQuery=True)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(raw, response=None, rawQuery=True)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedResponse)
+            expectedResponse.flags = receivedResponse.flags
+            self.assertEquals(receivedResponse, expectedResponse)