]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Clean up the internal queues use for self-answered and trailing test responses 9801/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 2 Dec 2020 09:57:06 +0000 (10:57 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 2 Dec 2020 09:57:06 +0000 (10:57 +0100)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_Advanced.py
regression-tests.dnsdist/test_Trailing.py

index 4bea3f9a11d0281792cf1e78c9af387b3294244d..f84569ffbe439e7d1dff06fbdf26dcffa4f8f4be 100644 (file)
@@ -174,17 +174,17 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             cls._ResponderIncrementCounter()
             if not fromQueue.empty():
                 toQueue.put(request, True, cls._queueTimeout)
-                if synthesize is None:
-                    response = fromQueue.get(True, cls._queueTimeout)
-                    if response:
-                        response = copy.copy(response)
-                        response.id = request.id
+                response = fromQueue.get(True, cls._queueTimeout)
+                if response:
+                  response = copy.copy(response)
+                  response.id = request.id
+
+        if synthesize is not None:
+          response = dns.message.make_response(request)
+          response.set_rcode(synthesize)
 
         if not response:
-            if synthesize is not None:
-                response = dns.message.make_response(request)
-                response.set_rcode(synthesize)
-            elif cls._answerUnexpected:
+            if cls._answerUnexpected:
                 response = dns.message.make_response(request)
                 response.set_rcode(dns.rcode.SERVFAIL)
 
@@ -307,7 +307,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
     @classmethod
     def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
-        if useQueue:
+        if useQueue and response is not None:
             cls._toResponderQueue.put(response, True, timeout)
 
         if timeout:
index 8fe66ac9f26ef9fab31fb512b1cf68b61e8494fc..089b8551d5d9b8802bfb586ccf8b08e3595a86cd 100644 (file)
@@ -647,7 +647,7 @@ class TestAdvancedDNSSEC(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(doquery, response)
+            (_, receivedResponse) = sender(doquery, response=None, useQueue=False)
             self.assertEquals(receivedResponse, None)
 
 class TestAdvancedQClass(DNSDistTest):
@@ -666,7 +666,7 @@ class TestAdvancedQClass(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(query, response=None)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertEquals(receivedResponse, None)
 
     def testAdvancedQClassINAllow(self):
@@ -710,7 +710,7 @@ class TestAdvancedOpcode(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(query, response=None)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertEquals(receivedResponse, None)
 
     def testAdvancedOpcodeUpdateINAllow(self):
@@ -1569,7 +1569,7 @@ class TestAdvancedEDNSOptionRule(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(query, response=None)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertEquals(receivedResponse, None)
 
     def testReplied(self):
@@ -1688,7 +1688,7 @@ class TestAdvancedEDNSVersionRule(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(query, response=None)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertEquals(receivedResponse, expectedResponse)
 
     def testNoEDNS0Pass(self):
index 2de5580fb74c3db13698cd072bbaf91de0eb96d4..fdf009a82735cae1dd9e84ccee0a7d6791ef4951 100644 (file)
@@ -133,7 +133,7 @@ class TestTrailingDataToBackend(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(query, response)
+            (_, receivedResponse) = sender(query, response, useQueue=False)
             self.assertTrue(receivedResponse)
             self.assertEquals(receivedResponse, expectedResponse)
 
@@ -246,7 +246,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
             self.assertEquals(response, receivedResponse)
 
             # Verify that queries with trailing data don't make it through.
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True, useQueue=False)
             self.assertEquals(receivedResponse, None)
 
     def testTrailingRemoved(self):
@@ -298,7 +298,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
             self.assertEquals(receivedResponse, expectedResponse)
@@ -325,7 +325,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
             self.assertEquals(receivedResponse, expectedResponse)
@@ -352,7 +352,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
             self.assertEquals(receivedResponse, expectedResponse)
@@ -379,7 +379,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
 
         for method in ("sendUDPQuery", "sendTCPQuery"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
             self.assertTrue(receivedResponse)
             expectedResponse.flags = receivedResponse.flags
             self.assertEquals(receivedResponse, expectedResponse)