From 4aa08b62adacd66a5d8fce23d6cc2b073d934dc5 Mon Sep 17 00:00:00 2001 From: Richard Gibson Date: Sun, 26 Aug 2018 21:42:16 -0400 Subject: [PATCH] dnsdist: Refactor trailing data tests --- regression-tests.dnsdist/dnsdisttests.py | 51 ++++++++++----- regression-tests.dnsdist/test_Trailing.py | 78 +++++++++++++++-------- 2 files changed, 89 insertions(+), 40 deletions(-) diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index c02afb8437..aa79d941e5 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -142,7 +142,7 @@ class DNSDistTest(unittest.TestCase): cls._responsesCounter[threading.currentThread().name] = 1 @classmethod - def _getResponse(cls, request, fromQueue, toQueue): + def _getResponse(cls, request, fromQueue, toQueue, synthesize=None): response = None if len(request.question) != 1: print("Skipping query with question count %d" % (len(request.question))) @@ -150,18 +150,21 @@ class DNSDistTest(unittest.TestCase): healthCheck = str(request.question[0].name).endswith(cls._healthCheckName) if healthCheck: cls._healthCheckCounter += 1 + response = dns.message.make_response(request) else: cls._ResponderIncrementCounter() if not fromQueue.empty(): - response = fromQueue.get(True, cls._queueTimeout) - if response: - response = copy.copy(response) - response.id = request.id - toQueue.put(request, True, cls._queueTimeout) + toQueue.put(request, True, cls._queueTimeout) + if synthesize is None: + response = fromQueue.get(True, cls._queueTimeout) + if response: + response = copy.copy(response) + response.id = request.id if not response: - if healthCheck: + if synthesize is not None: response = dns.message.make_response(request) + response.set_rcode(synthesize) elif cls._answerUnexpected: response = dns.message.make_response(request) response.set_rcode(dns.rcode.SERVFAIL) @@ -169,15 +172,24 @@ class DNSDistTest(unittest.TestCase): return response @classmethod - def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False): + def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False): + ignoreTrailing = trailingDataResponse is True sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) sock.bind(("127.0.0.1", port)) while True: data, addr = sock.recvfrom(4096) - request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) - response = cls._getResponse(request, fromQueue, toQueue) - + forceRcode = None + try: + request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) + except dns.message.TrailingJunk as e: + if trailingDataResponse is False: + raise + print("UDP query with trailing data, synthesizing response") + request = dns.message.from_wire(data, ignore_trailing=True) + forceRcode = trailingDataResponse + + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) if not response: continue @@ -187,7 +199,8 @@ class DNSDistTest(unittest.TestCase): sock.close() @classmethod - def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False): + def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False): + ignoreTrailing = trailingDataResponse is True sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) try: @@ -207,9 +220,17 @@ class DNSDistTest(unittest.TestCase): (datalen,) = struct.unpack("!H", data) data = conn.recv(datalen) - request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) - response = cls._getResponse(request, fromQueue, toQueue) - + forceRcode = None + try: + request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) + except dns.message.TrailingJunk as e: + if trailingDataResponse is False: + raise + print("TCP query with trailing data, synthesizing response") + request = dns.message.from_wire(data, ignore_trailing=True) + forceRcode = trailingDataResponse + + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) if not response: conn.close() continue diff --git a/regression-tests.dnsdist/test_Trailing.py b/regression-tests.dnsdist/test_Trailing.py index 34b1e1f14a..ed70869f31 100644 --- a/regression-tests.dnsdist/test_Trailing.py +++ b/regression-tests.dnsdist/test_Trailing.py @@ -3,7 +3,7 @@ import threading import dns from dnsdisttests import DNSDistTest -class TestTrailing(DNSDistTest): +class TestTrailingDataToBackend(DNSDistTest): # this test suite uses a different responder port # because, contrary to the other ones, its @@ -12,26 +12,27 @@ class TestTrailing(DNSDistTest): _testServerPort = 5360 _config_template = """ newServer{address="127.0.0.1:%s"} - addAction(AndRule({QTypeRule(dnsdist.AAAA), TrailingDataRule()}), DropAction()) """ @classmethod def startResponders(cls): print("Launching responders..") - cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True]) + # Respond SERVFAIL to queries with trailing data. + cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.SERVFAIL]) cls._UDPResponder.setDaemon(True) cls._UDPResponder.start() - cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True]) + # Respond SERVFAIL to queries with trailing data. + cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.SERVFAIL]) cls._TCPResponder.setDaemon(True) cls._TCPResponder.start() - def testTrailingAllowed(self): + def testTrailingPassthrough(self): """ - Trailing: Allowed + Trailing data: Pass through """ - name = 'allowed.trailing.tests.powerdns.com.' + name = 'passthrough.trailing.tests.powerdns.com.' query = dns.message.make_query(name, 'A', 'IN') response = dns.message.make_response(query) rrset = dns.rrset.from_text(name, @@ -40,35 +41,62 @@ class TestTrailing(DNSDistTest): dns.rdatatype.A, '127.0.0.1') response.answer.append(rrset) + expectedResponse = dns.message.make_response(query) + expectedResponse.set_rcode(dns.rcode.SERVFAIL) raw = query.to_wire() raw = raw + b'A'* 20 - (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = query.id - self.assertEquals(query, receivedQuery) - self.assertEquals(response, receivedResponse) - (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True) - self.assertTrue(receivedQuery) - self.assertTrue(receivedResponse) - receivedQuery.id = query.id - self.assertEquals(query, receivedQuery) - self.assertEquals(response, receivedResponse) + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True) + # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True) + (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, expectedResponse) + +class TestTrailingDataToDnsdist(DNSDistTest): + _config_template = """ + newServer{address="127.0.0.1:%s"} + addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction()) + """ def testTrailingDropped(self): """ - Trailing: dropped + Trailing data: Drop query """ name = 'dropped.trailing.tests.powerdns.com.' - query = dns.message.make_query(name, 'AAAA', 'IN') + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) raw = query.to_wire() raw = raw + b'A'* 20 - (_, receivedResponse) = self.sendUDPQuery(raw, response=None, rawQuery=True) - self.assertEquals(receivedResponse, None) - (_, receivedResponse) = self.sendTCPQuery(raw, response=None, rawQuery=True) - self.assertEquals(receivedResponse, None) + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + + # Verify that queries with no trailing data make it through. + # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + (receivedQuery, receivedResponse) = sender(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # Verify that queries with trailing data don't make it through. + # (_, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True) + # (_, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True) + (_, receivedResponse) = sender(raw, response, rawQuery=True) + self.assertEquals(receivedResponse, None) -- 2.47.2