]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Trailing.py
Merge pull request #13756 from rgacogne/ddist-xsk-doc-typos
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Trailing.py
index b87b519f7ad7b9fe11f42ee82ec84383ab497cb8..5bb97ff8b8b8006a740504fa8d3b4d5f3b1f73e8 100644 (file)
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 import threading
 import dns
-from dnsdisttests import DNSDistTest
+from dnsdisttests import DNSDistTest, pickAvailablePort
 
 class TestTrailingDataToBackend(DNSDistTest):
 
@@ -9,7 +9,8 @@ class TestTrailingDataToBackend(DNSDistTest):
     # because, contrary to the other ones, its
     # responders allow trailing data and we don't want
     # to mix things up.
-    _testServerPort = 5360
+    _testServerPort = pickAvailablePort()
+    _verboseMode = True
     _config_template = """
     newServer{address="127.0.0.1:%s"}
 
@@ -23,7 +24,10 @@ class TestTrailingDataToBackend(DNSDistTest):
     addAction("added.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
 
     function fillBuffer(dq)
-        local available = dq.size - dq.len
+        local available = 4096 - dq.len
+        if dq.tcp then
+            available = 65535 - dq.len
+        end
         local tail = string.rep("A", available)
         local success = dq:setTrailingData(tail)
         if not success then
@@ -34,7 +38,10 @@ class TestTrailingDataToBackend(DNSDistTest):
     addAction("max.trailing.tests.powerdns.com.", LuaAction(fillBuffer))
 
     function exceedBuffer(dq)
-        local available = dq.size - dq.len
+        local available = 4096 - dq.len
+        if dq.tcp then
+            available = 65535 - dq.len
+        end
         local tail = string.rep("A", available + 1)
         local success = dq:setTrailingData(tail)
         if not success then
@@ -50,12 +57,12 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         # Respond REFUSED to queries with trailing data.
         cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.REFUSED])
-        cls._UDPResponder.setDaemon(True)
+        cls._UDPResponder.daemon = True
         cls._UDPResponder.start()
 
         # Respond REFUSED to queries with trailing data.
         cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.REFUSED])
-        cls._TCPResponder.setDaemon(True)
+        cls._TCPResponder.daemon = True
         cls._TCPResponder.start()
 
     def testTrailingPassthrough(self):
@@ -84,8 +91,8 @@ class TestTrailingDataToBackend(DNSDistTest):
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
             receivedQuery.id = query.id
-            self.assertEquals(receivedQuery, query)
-            self.assertEquals(receivedResponse, expectedResponse)
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, expectedResponse)
 
     def testTrailingCapacity(self):
         """
@@ -110,8 +117,8 @@ class TestTrailingDataToBackend(DNSDistTest):
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
             receivedQuery.id = query.id
-            self.assertEquals(receivedQuery, query)
-            self.assertEquals(receivedResponse, expectedResponse)
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, expectedResponse)
 
     def testTrailingLimited(self):
         """
@@ -132,9 +139,9 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(query, response)
+            (_, receivedResponse) = sender(query, response, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEquals(receivedResponse, expectedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)
 
     def testTrailingAdded(self):
         """
@@ -159,10 +166,11 @@ class TestTrailingDataToBackend(DNSDistTest):
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
             receivedQuery.id = query.id
-            self.assertEquals(receivedQuery, query)
-            self.assertEquals(receivedResponse, expectedResponse)
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, expectedResponse)
 
 class TestTrailingDataToDnsdist(DNSDistTest):
+    _verboseMode = True
     _config_template = """
     newServer{address="127.0.0.1:%s"}
 
@@ -171,6 +179,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
     function removeTrailingData(dq)
         local success = dq:setTrailingData("")
         if not success then
+            print("Trailing removal failed")
             return DNSAction.ServFail, ""
         end
         return DNSAction.None, ""
@@ -239,12 +248,12 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
             receivedQuery.id = query.id
-            self.assertEquals(query, receivedQuery)
-            self.assertEquals(response, receivedResponse)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
 
             # Verify that queries with trailing data don't make it through.
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
-            self.assertEquals(receivedResponse, None)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True, useQueue=False)
+            self.assertEqual(receivedResponse, None)
 
     def testTrailingRemoved(self):
         """
@@ -270,8 +279,8 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             self.assertTrue(receivedQuery)
             self.assertTrue(receivedResponse)
             receivedQuery.id = query.id
-            self.assertEquals(receivedQuery, query)
-            self.assertEquals(receivedResponse, response)
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, response)
 
     def testTrailingRead(self):
         """
@@ -295,10 +304,10 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
-            self.assertEquals(receivedResponse, expectedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)
 
     def testTrailingReplaced(self):
         """
@@ -322,10 +331,10 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
-            self.assertEquals(receivedResponse, expectedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)
 
     def testTrailingReadUnsafe(self):
         """
@@ -349,10 +358,10 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
-            self.assertEquals(receivedResponse, expectedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)
 
     def testTrailingReplacedUnsafe(self):
         """
@@ -376,7 +385,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
-            self.assertEquals(receivedResponse, expectedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)