]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Trailing.py
Make sure we can install unsigned packages.
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Trailing.py
index 33d23b1e9e7372da260ebc4c9663732372599dc6..b87b519f7ad7b9fe11f42ee82ec84383ab497cb8 100644 (file)
@@ -14,50 +14,35 @@ class TestTrailingDataToBackend(DNSDistTest):
     newServer{address="127.0.0.1:%s"}
 
     function replaceTrailingData(dq)
-        local success = dq:setTrailingData({65, 66, 67}) -- "ABC"
+        local success = dq:setTrailingData("ABC")
         if not success then
             return DNSAction.ServFail, ""
         end
         return DNSAction.None, ""
     end
-    addLuaAction("added.trailing.tests.powerdns.com.", replaceTrailingData)
+    addAction("added.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
 
     function fillBuffer(dq)
         local available = dq.size - dq.len
-        local tail = extendTableBy({}, available)
+        local tail = string.rep("A", available)
         local success = dq:setTrailingData(tail)
         if not success then
             return DNSAction.ServFail, ""
         end
         return DNSAction.None, ""
     end
-    addLuaAction("max.trailing.tests.powerdns.com.", fillBuffer)
+    addAction("max.trailing.tests.powerdns.com.", LuaAction(fillBuffer))
 
     function exceedBuffer(dq)
         local available = dq.size - dq.len
-        local tail = extendTableBy({}, available + 1)
+        local tail = string.rep("A", available + 1)
         local success = dq:setTrailingData(tail)
         if not success then
             return DNSAction.ServFail, ""
         end
         return DNSAction.None, ""
     end
-    addLuaAction("limited.trailing.tests.powerdns.com.", exceedBuffer)
-
-    function extendTableBy(t, n)
-        if n <= 1 then
-            if n == 1 then
-                t[#t + 1] = 0
-            end
-            return t
-        end
-
-        local lower = math.floor(n / 2)
-        local upper = n - lower
-        t = extendTableBy(t, lower)
-        t = extendTableBy(t, upper)
-        return t
-    end
+    addAction("limited.trailing.tests.powerdns.com.", LuaAction(exceedBuffer))
     """
     @classmethod
     def startResponders(cls):
@@ -95,8 +80,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -123,8 +106,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -151,8 +132,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (_, receivedResponse) = sender(query, response)
             self.assertTrue(receivedResponse)
             self.assertEquals(receivedResponse, expectedResponse)
@@ -176,8 +155,6 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -192,30 +169,48 @@ class TestTrailingDataToDnsdist(DNSDistTest):
     addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())
 
     function removeTrailingData(dq)
-        local success = dq:setTrailingData({})
+        local success = dq:setTrailingData("")
         if not success then
             return DNSAction.ServFail, ""
         end
         return DNSAction.None, ""
     end
-    addLuaAction("removed.trailing.tests.powerdns.com.", removeTrailingData)
+    addAction("removed.trailing.tests.powerdns.com.", LuaAction(removeTrailingData))
 
     function reportTrailingData(dq)
-        local tailBytes = dq:getTrailingData()
-        local tailChars = string.char(unpack(tailBytes))
-        return DNSAction.Spoof, "-" .. tailChars .. ".echoed.trailing.tests.powerdns.com."
+        local tail = dq:getTrailingData()
+        return DNSAction.Spoof, "-" .. tail .. ".echoed.trailing.tests.powerdns.com."
     end
-    addLuaAction("echoed.trailing.tests.powerdns.com.", reportTrailingData)
+    addAction("echoed.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))
 
     function replaceTrailingData(dq)
-        local success = dq:setTrailingData({65, 66, 67}) -- "ABC"
+        local success = dq:setTrailingData("ABC")
+        if not success then
+            return DNSAction.ServFail, ""
+        end
+        return DNSAction.None, ""
+    end
+    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
+    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))
+
+    function reportTrailingHex(dq)
+        local tail = dq:getTrailingData()
+        local hex = string.gsub(tail, ".", function(ch)
+            return string.sub(string.format("\\x2502X", string.byte(ch)), -2)
+        end)
+        return DNSAction.Spoof, "-0x" .. hex .. ".echoed-hex.trailing.tests.powerdns.com."
+    end
+    addAction("echoed-hex.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
+
+    function replaceTrailingData_unsafe(dq)
+        local success = dq:setTrailingData("\\xB0\\x00\\xDE\\xADB\\xF0\\x9F\\x91\\xBB\\xC3\\xBE")
         if not success then
             return DNSAction.ServFail, ""
         end
         return DNSAction.None, ""
     end
-    addLuaAction("replaced.trailing.tests.powerdns.com.", replaceTrailingData)
-    addLuaAction("replaced.trailing.tests.powerdns.com.", reportTrailingData)
+    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData_unsafe))
+    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
     """
 
     def testTrailingDropped(self):
@@ -240,8 +235,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             sender = getattr(self, method)
 
             # Verify that queries with no trailing data make it through.
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             (receivedQuery, receivedResponse) = sender(query, response)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -250,8 +243,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             self.assertEquals(response, receivedResponse)
 
             # Verify that queries with trailing data don't make it through.
-            # (_, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (_, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertEquals(receivedResponse, None)
 
@@ -275,8 +266,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
@@ -286,7 +275,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
     def testTrailingRead(self):
         """
-        Trailing data: Count
+        Trailing data: Echo
 
         """
         name = 'echoed.trailing.tests.powerdns.com.'
@@ -306,8 +295,6 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
@@ -335,8 +322,60 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedResponse)
+            expectedResponse.flags = receivedResponse.flags
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingReadUnsafe(self):
+        """
+        Trailing data: Echo as hex
+
+        """
+        name = 'echoed-hex.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    '-0x0000DEAD.echoed-hex.trailing.tests.powerdns.com.')
+        expectedResponse.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'\x00\x00\xDE\xAD'
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedResponse)
+            expectedResponse.flags = receivedResponse.flags
+            self.assertEquals(receivedResponse, expectedResponse)
+
+    def testTrailingReplacedUnsafe(self):
+        """
+        Trailing data: Replace with null and/or non-ASCII bytes
+
+        """
+        name = 'replaced-unsafe.trailing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.SERVFAIL)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    '-0xB000DEAD42F09F91BBC3BE.echoed-hex.trailing.tests.powerdns.com.')
+        expectedResponse.answer.append(rrset)
+
+        raw = query.to_wire()
+        raw = raw + b'TrailingData'
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
             (_, receivedResponse) = sender(raw, response, rawQuery=True)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags