]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #7496 from rgacogne/auth-catch-invalid-slave-soa
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
index 86789f04614b08ceb13d4c0a8afd9777fe239935..e1df6cfa7df44a5631334af0776505a88560a587 100644 (file)
@@ -46,7 +46,6 @@ class DNSDistTest(unittest.TestCase):
     _dnsdistStartupDelay = 2.0
     _dnsdist = None
     _responsesCounter = {}
-    _shutUp = True
     _config_template = """
     """
     _config_params = ['_testServerPort']
@@ -69,16 +68,16 @@ class DNSDistTest(unittest.TestCase):
         cls._TCPResponder.start()
 
     @classmethod
-    def startDNSDist(cls, shutUp=True):
+    def startDNSDist(cls):
         print("Launching dnsdist..")
-        conffile = 'dnsdist_test.conf'
+        confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
         params = tuple([getattr(cls, param) for param in cls._config_params])
         print(params)
-        with open(conffile, 'w') as conf:
+        with open(confFile, 'w') as conf:
             conf.write("-- Autogenerated by dnsdisttests.py\n")
             conf.write(cls._config_template % params)
 
-        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
+        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', confFile,
                       '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
         for acl in cls._acl:
             dnsdistcmd.extend(['--acl', acl])
@@ -90,14 +89,13 @@ class DNSDistTest(unittest.TestCase):
             output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
         except subprocess.CalledProcessError as exc:
             raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
-        if output != b'Configuration \'dnsdist_test.conf\' OK!\n':
+        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
+        if output != expectedOutput:
             raise AssertionError('dnsdist --check-config failed: %s' % output)
 
-        if shutUp:
-            with open(os.devnull, 'w') as fdDevNull:
-                cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
-        else:
-            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)
+        logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
+        with open(logFile, 'w') as fdLog:
+          cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)
 
         if 'DNSDIST_FAST_TESTS' in os.environ:
             delay = 0.5
@@ -121,7 +119,7 @@ class DNSDistTest(unittest.TestCase):
     def setUpClass(cls):
 
         cls.startResponders()
-        cls.startDNSDist(cls._shutUp)
+        cls.startDNSDist()
         cls.setUpSockets()
 
         print("Launching tests..")
@@ -148,7 +146,7 @@ class DNSDistTest(unittest.TestCase):
             cls._responsesCounter[threading.currentThread().name] = 1
 
     @classmethod
-    def _getResponse(cls, request, fromQueue, toQueue):
+    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
         response = None
         if len(request.question) != 1:
             print("Skipping query with question count %d" % (len(request.question)))
@@ -156,18 +154,21 @@ class DNSDistTest(unittest.TestCase):
         healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
         if healthCheck:
             cls._healthCheckCounter += 1
+            response = dns.message.make_response(request)
         else:
             cls._ResponderIncrementCounter()
             if not fromQueue.empty():
-                response = fromQueue.get(True, cls._queueTimeout)
-                if response:
-                    response = copy.copy(response)
-                    response.id = request.id
-                    toQueue.put(request, True, cls._queueTimeout)
+                toQueue.put(request, True, cls._queueTimeout)
+                if synthesize is None:
+                    response = fromQueue.get(True, cls._queueTimeout)
+                    if response:
+                        response = copy.copy(response)
+                        response.id = request.id
 
         if not response:
-            if healthCheck:
+            if synthesize is not None:
                 response = dns.message.make_response(request)
+                response.set_rcode(synthesize)
             elif cls._answerUnexpected:
                 response = dns.message.make_response(request)
                 response.set_rcode(dns.rcode.SERVFAIL)
@@ -175,15 +176,28 @@ class DNSDistTest(unittest.TestCase):
         return response
 
     @classmethod
-    def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
+    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
+        # trailingDataResponse=True means "ignore trailing data".
+        # Other values are either False (meaning "raise an exception")
+        # or are interpreted as a response RCODE for queries with trailing data.
+        ignoreTrailing = trailingDataResponse is True
+
         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         sock.bind(("127.0.0.1", port))
         while True:
             data, addr = sock.recvfrom(4096)
-            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
-            response = cls._getResponse(request, fromQueue, toQueue)
-
+            forceRcode = None
+            try:
+                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+            except dns.message.TrailingJunk as e:
+                if trailingDataResponse is False or forceRcode is True:
+                    raise
+                print("UDP query with trailing data, synthesizing response")
+                request = dns.message.from_wire(data, ignore_trailing=True)
+                forceRcode = trailingDataResponse
+
+            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
             if not response:
                 continue
 
@@ -193,7 +207,12 @@ class DNSDistTest(unittest.TestCase):
         sock.close()
 
     @classmethod
-    def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
+    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
+        # trailingDataResponse=True means "ignore trailing data".
+        # Other values are either False (meaning "raise an exception")
+        # or are interpreted as a response RCODE for queries with trailing data.
+        ignoreTrailing = trailingDataResponse is True
+
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         try:
@@ -213,9 +232,17 @@ class DNSDistTest(unittest.TestCase):
 
             (datalen,) = struct.unpack("!H", data)
             data = conn.recv(datalen)
-            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
-            response = cls._getResponse(request, fromQueue, toQueue)
-
+            forceRcode = None
+            try:
+                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+            except dns.message.TrailingJunk as e:
+                if trailingDataResponse is False or forceRcode is True:
+                    raise
+                print("TCP query with trailing data, synthesizing response")
+                request = dns.message.from_wire(data, ignore_trailing=True)
+                forceRcode = trailingDataResponse
+
+            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
             if not response:
                 conn.close()
                 continue