From: Remi Gacogne Date: Wed, 2 Dec 2020 09:57:06 +0000 (+0100) Subject: dnsdist: Clean up the internal queues use for self-answered and trailing test responses X-Git-Tag: rec-4.5.0-alpha1~91^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F9801%2Fhead;p=thirdparty%2Fpdns.git dnsdist: Clean up the internal queues use for self-answered and trailing test responses --- diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 4bea3f9a11..f84569ffbe 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -174,17 +174,17 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): cls._ResponderIncrementCounter() if not fromQueue.empty(): toQueue.put(request, True, cls._queueTimeout) - if synthesize is None: - response = fromQueue.get(True, cls._queueTimeout) - if response: - response = copy.copy(response) - response.id = request.id + response = fromQueue.get(True, cls._queueTimeout) + if response: + response = copy.copy(response) + response.id = request.id + + if synthesize is not None: + response = dns.message.make_response(request) + response.set_rcode(synthesize) if not response: - if synthesize is not None: - response = dns.message.make_response(request) - response.set_rcode(synthesize) - elif cls._answerUnexpected: + if cls._answerUnexpected: response = dns.message.make_response(request) response.set_rcode(dns.rcode.SERVFAIL) @@ -307,7 +307,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): @classmethod def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): - if useQueue: + if useQueue and response is not None: cls._toResponderQueue.put(response, True, timeout) if timeout: diff --git a/regression-tests.dnsdist/test_Advanced.py b/regression-tests.dnsdist/test_Advanced.py index 8fe66ac9f2..089b8551d5 100644 --- a/regression-tests.dnsdist/test_Advanced.py +++ b/regression-tests.dnsdist/test_Advanced.py @@ -647,7 +647,7 @@ class TestAdvancedDNSSEC(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(doquery, response) + (_, receivedResponse) = sender(doquery, response=None, useQueue=False) self.assertEquals(receivedResponse, None) class TestAdvancedQClass(DNSDistTest): @@ -666,7 +666,7 @@ class TestAdvancedQClass(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(query, response=None) + (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertEquals(receivedResponse, None) def testAdvancedQClassINAllow(self): @@ -710,7 +710,7 @@ class TestAdvancedOpcode(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(query, response=None) + (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertEquals(receivedResponse, None) def testAdvancedOpcodeUpdateINAllow(self): @@ -1569,7 +1569,7 @@ class TestAdvancedEDNSOptionRule(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(query, response=None) + (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertEquals(receivedResponse, None) def testReplied(self): @@ -1688,7 +1688,7 @@ class TestAdvancedEDNSVersionRule(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(query, response=None) + (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertEquals(receivedResponse, expectedResponse) def testNoEDNS0Pass(self): diff --git a/regression-tests.dnsdist/test_Trailing.py b/regression-tests.dnsdist/test_Trailing.py index 2de5580fb7..fdf009a827 100644 --- a/regression-tests.dnsdist/test_Trailing.py +++ b/regression-tests.dnsdist/test_Trailing.py @@ -133,7 +133,7 @@ class TestTrailingDataToBackend(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(query, response) + (_, receivedResponse) = sender(query, response, useQueue=False) self.assertTrue(receivedResponse) self.assertEquals(receivedResponse, expectedResponse) @@ -246,7 +246,7 @@ class TestTrailingDataToDnsdist(DNSDistTest): self.assertEquals(response, receivedResponse) # Verify that queries with trailing data don't make it through. - (_, receivedResponse) = sender(raw, response, rawQuery=True) + (_, receivedResponse) = sender(raw, response, rawQuery=True, useQueue=False) self.assertEquals(receivedResponse, None) def testTrailingRemoved(self): @@ -298,7 +298,7 @@ class TestTrailingDataToDnsdist(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(raw, response, rawQuery=True) + (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False) self.assertTrue(receivedResponse) expectedResponse.flags = receivedResponse.flags self.assertEquals(receivedResponse, expectedResponse) @@ -325,7 +325,7 @@ class TestTrailingDataToDnsdist(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(raw, response, rawQuery=True) + (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False) self.assertTrue(receivedResponse) expectedResponse.flags = receivedResponse.flags self.assertEquals(receivedResponse, expectedResponse) @@ -352,7 +352,7 @@ class TestTrailingDataToDnsdist(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(raw, response, rawQuery=True) + (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False) self.assertTrue(receivedResponse) expectedResponse.flags = receivedResponse.flags self.assertEquals(receivedResponse, expectedResponse) @@ -379,7 +379,7 @@ class TestTrailingDataToDnsdist(DNSDistTest): for method in ("sendUDPQuery", "sendTCPQuery"): sender = getattr(self, method) - (_, receivedResponse) = sender(raw, response, rawQuery=True) + (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False) self.assertTrue(receivedResponse) expectedResponse.flags = receivedResponse.flags self.assertEquals(receivedResponse, expectedResponse)