]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactor trailing data tests
authorRichard Gibson <richard.gibson@gmail.com>
Mon, 27 Aug 2018 01:42:16 +0000 (21:42 -0400)
committerRichard Gibson <richard.gibson@gmail.com>
Tue, 16 Oct 2018 21:42:35 +0000 (17:42 -0400)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_Trailing.py

index c02afb8437bbd1da93dfddc95d274e0d6a14740d..aa79d941e5d9b8084d6eb7b08c07c0be9ae9efa1 100644 (file)
@@ -142,7 +142,7 @@ class DNSDistTest(unittest.TestCase):
             cls._responsesCounter[threading.currentThread().name] = 1
 
     @classmethod
-    def _getResponse(cls, request, fromQueue, toQueue):
+    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
         response = None
         if len(request.question) != 1:
             print("Skipping query with question count %d" % (len(request.question)))
@@ -150,18 +150,21 @@ class DNSDistTest(unittest.TestCase):
         healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
         if healthCheck:
             cls._healthCheckCounter += 1
+            response = dns.message.make_response(request)
         else:
             cls._ResponderIncrementCounter()
             if not fromQueue.empty():
-                response = fromQueue.get(True, cls._queueTimeout)
-                if response:
-                    response = copy.copy(response)
-                    response.id = request.id
-                    toQueue.put(request, True, cls._queueTimeout)
+                toQueue.put(request, True, cls._queueTimeout)
+                if synthesize is None:
+                    response = fromQueue.get(True, cls._queueTimeout)
+                    if response:
+                        response = copy.copy(response)
+                        response.id = request.id
 
         if not response:
-            if healthCheck:
+            if synthesize is not None:
                 response = dns.message.make_response(request)
+                response.set_rcode(synthesize)
             elif cls._answerUnexpected:
                 response = dns.message.make_response(request)
                 response.set_rcode(dns.rcode.SERVFAIL)
@@ -169,15 +172,24 @@ class DNSDistTest(unittest.TestCase):
         return response
 
     @classmethod
-    def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
+    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
+        ignoreTrailing = trailingDataResponse is True
         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         sock.bind(("127.0.0.1", port))
         while True:
             data, addr = sock.recvfrom(4096)
-            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
-            response = cls._getResponse(request, fromQueue, toQueue)
-
+            forceRcode = None
+            try:
+                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+            except dns.message.TrailingJunk as e:
+                if trailingDataResponse is False:
+                    raise
+                print("UDP query with trailing data, synthesizing response")
+                request = dns.message.from_wire(data, ignore_trailing=True)
+                forceRcode = trailingDataResponse
+
+            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
             if not response:
                 continue
 
@@ -187,7 +199,8 @@ class DNSDistTest(unittest.TestCase):
         sock.close()
 
     @classmethod
-    def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
+    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
+        ignoreTrailing = trailingDataResponse is True
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         try:
@@ -207,9 +220,17 @@ class DNSDistTest(unittest.TestCase):
 
             (datalen,) = struct.unpack("!H", data)
             data = conn.recv(datalen)
-            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
-            response = cls._getResponse(request, fromQueue, toQueue)
-
+            forceRcode = None
+            try:
+                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+            except dns.message.TrailingJunk as e:
+                if trailingDataResponse is False:
+                    raise
+                print("TCP query with trailing data, synthesizing response")
+                request = dns.message.from_wire(data, ignore_trailing=True)
+                forceRcode = trailingDataResponse
+
+            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
             if not response:
                 conn.close()
                 continue
index 34b1e1f14a3511c59d85c4d84548c89eb83df88b..ed70869f31526aa9d42bb889229adece43cd0ba6 100644 (file)
@@ -3,7 +3,7 @@ import threading
 import dns
 from dnsdisttests import DNSDistTest
 
-class TestTrailing(DNSDistTest):
+class TestTrailingDataToBackend(DNSDistTest):
 
     # this test suite uses a different responder port
     # because, contrary to the other ones, its
@@ -12,26 +12,27 @@ class TestTrailing(DNSDistTest):
     _testServerPort = 5360
     _config_template = """
     newServer{address="127.0.0.1:%s"}
-    addAction(AndRule({QTypeRule(dnsdist.AAAA), TrailingDataRule()}), DropAction())
     """
     @classmethod
     def startResponders(cls):
         print("Launching responders..")
 
-        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True])
+        # Respond SERVFAIL to queries with trailing data.
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.SERVFAIL])
         cls._UDPResponder.setDaemon(True)
         cls._UDPResponder.start()
 
-        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, True])
+        # Respond SERVFAIL to queries with trailing data.
+        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.SERVFAIL])
         cls._TCPResponder.setDaemon(True)
         cls._TCPResponder.start()
 
-    def testTrailingAllowed(self):
+    def testTrailingPassthrough(self):
         """
-        Trailing: Allowed
+        Trailing data: Pass through
 
         """
-        name = 'allowed.trailing.tests.powerdns.com.'
+        name = 'passthrough.trailing.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
         rrset = dns.rrset.from_text(name,
@@ -40,35 +41,62 @@ class TestTrailing(DNSDistTest):
                                     dns.rdatatype.A,
                                     '127.0.0.1')
         response.answer.append(rrset)
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.SERVFAIL)
 
         raw = query.to_wire()
         raw = raw + b'A'* 20
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
 
-        (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            # (receivedQuery, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
+            # (receivedQuery, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
+            (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(receivedQuery, query)
+            self.assertEquals(receivedResponse, expectedResponse)
+
+class TestTrailingDataToDnsdist(DNSDistTest):
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+    addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())
+    """
 
     def testTrailingDropped(self):
         """
-        Trailing: dropped
+        Trailing data: Drop query
 
         """
         name = 'dropped.trailing.tests.powerdns.com.'
-        query = dns.message.make_query(name, 'AAAA', 'IN')
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
 
         raw = query.to_wire()
         raw = raw + b'A'* 20
 
-        (_, receivedResponse) = self.sendUDPQuery(raw, response=None, rawQuery=True)
-        self.assertEquals(receivedResponse, None)
-        (_, receivedResponse) = self.sendTCPQuery(raw, response=None, rawQuery=True)
-        self.assertEquals(receivedResponse, None)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+
+            # Verify that queries with no trailing data make it through.
+            # (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+            # (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEquals(query, receivedQuery)
+            self.assertEquals(response, receivedResponse)
+
+            # Verify that queries with trailing data don't make it through.
+            # (_, receivedResponse) = self.sendUDPQuery(raw, response, rawQuery=True)
+            # (_, receivedResponse) = self.sendTCPQuery(raw, response, rawQuery=True)
+            (_, receivedResponse) = sender(raw, response, rawQuery=True)
+            self.assertEquals(receivedResponse, None)