]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/dnsdisttests.py
rec: ensure correct service user on debian
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
index 91fc086b49ccd146ebb41932c257865f721aafc6..83dc4b5a423566ad9fdd8bd7607184997eccc6cd 100644 (file)
@@ -46,7 +46,6 @@ class DNSDistTest(unittest.TestCase):
     _dnsdistStartupDelay = 2.0
     _dnsdist = None
     _responsesCounter = {}
-    _shutUp = True
     _config_template = """
     """
     _config_params = ['_testServerPort']
@@ -69,16 +68,16 @@ class DNSDistTest(unittest.TestCase):
         cls._TCPResponder.start()
 
     @classmethod
-    def startDNSDist(cls, shutUp=True):
+    def startDNSDist(cls):
         print("Launching dnsdist..")
-        conffile = 'dnsdist_test.conf'
+        confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
         params = tuple([getattr(cls, param) for param in cls._config_params])
         print(params)
-        with open(conffile, 'w') as conf:
+        with open(confFile, 'w') as conf:
             conf.write("-- Autogenerated by dnsdisttests.py\n")
             conf.write(cls._config_template % params)
 
-        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
+        dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
                       '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
         for acl in cls._acl:
             dnsdistcmd.extend(['--acl', acl])
@@ -86,15 +85,17 @@ class DNSDistTest(unittest.TestCase):
 
         # validate config with --check-config, which sets client=true, possibly exposing bugs.
         testcmd = dnsdistcmd + ['--check-config']
-        output = subprocess.check_output(testcmd, close_fds=True)
-        if output != b'Configuration \'dnsdist_test.conf\' OK!\n':
+        try:
+            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
+        except subprocess.CalledProcessError as exc:
+            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
+        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
+        if output != expectedOutput:
             raise AssertionError('dnsdist --check-config failed: %s' % output)
 
-        if shutUp:
-            with open(os.devnull, 'w') as fdDevNull:
-                cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
-        else:
-            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)
+        logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
+        with open(logFile, 'w') as fdLog:
+          cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)
 
         if 'DNSDIST_FAST_TESTS' in os.environ:
             delay = 0.5
@@ -118,7 +119,7 @@ class DNSDistTest(unittest.TestCase):
     def setUpClass(cls):
 
         cls.startResponders()
-        cls.startDNSDist(cls._shutUp)
+        cls.startDNSDist()
         cls.setUpSockets()
 
         print("Launching tests..")
@@ -145,7 +146,7 @@ class DNSDistTest(unittest.TestCase):
             cls._responsesCounter[threading.currentThread().name] = 1
 
     @classmethod
-    def _getResponse(cls, request, fromQueue, toQueue):
+    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
         response = None
         if len(request.question) != 1:
             print("Skipping query with question count %d" % (len(request.question)))
@@ -153,18 +154,21 @@ class DNSDistTest(unittest.TestCase):
         healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
         if healthCheck:
             cls._healthCheckCounter += 1
+            response = dns.message.make_response(request)
         else:
             cls._ResponderIncrementCounter()
             if not fromQueue.empty():
-                response = fromQueue.get(True, cls._queueTimeout)
-                if response:
-                    response = copy.copy(response)
-                    response.id = request.id
-                    toQueue.put(request, True, cls._queueTimeout)
+                toQueue.put(request, True, cls._queueTimeout)
+                if synthesize is None:
+                    response = fromQueue.get(True, cls._queueTimeout)
+                    if response:
+                        response = copy.copy(response)
+                        response.id = request.id
 
         if not response:
-            if healthCheck:
+            if synthesize is not None:
                 response = dns.message.make_response(request)
+                response.set_rcode(synthesize)
             elif cls._answerUnexpected:
                 response = dns.message.make_response(request)
                 response.set_rcode(dns.rcode.SERVFAIL)
@@ -172,15 +176,28 @@ class DNSDistTest(unittest.TestCase):
         return response
 
     @classmethod
-    def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
+    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
+        # trailingDataResponse=True means "ignore trailing data".
+        # Other values are either False (meaning "raise an exception")
+        # or are interpreted as a response RCODE for queries with trailing data.
+        ignoreTrailing = trailingDataResponse is True
+
         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         sock.bind(("127.0.0.1", port))
         while True:
             data, addr = sock.recvfrom(4096)
-            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
-            response = cls._getResponse(request, fromQueue, toQueue)
-
+            forceRcode = None
+            try:
+                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+            except dns.message.TrailingJunk as e:
+                if trailingDataResponse is False or forceRcode is True:
+                    raise
+                print("UDP query with trailing data, synthesizing response")
+                request = dns.message.from_wire(data, ignore_trailing=True)
+                forceRcode = trailingDataResponse
+
+            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
             if not response:
                 continue
 
@@ -190,8 +207,14 @@ class DNSDistTest(unittest.TestCase):
         sock.close()
 
     @classmethod
-    def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
+    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
+        # trailingDataResponse=True means "ignore trailing data".
+        # Other values are either False (meaning "raise an exception")
+        # or are interpreted as a response RCODE for queries with trailing data.
+        ignoreTrailing = trailingDataResponse is True
+
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         try:
             sock.bind(("127.0.0.1", port))
@@ -210,9 +233,17 @@ class DNSDistTest(unittest.TestCase):
 
             (datalen,) = struct.unpack("!H", data)
             data = conn.recv(datalen)
-            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
-            response = cls._getResponse(request, fromQueue, toQueue)
-
+            forceRcode = None
+            try:
+                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+            except dns.message.TrailingJunk as e:
+                if trailingDataResponse is False or forceRcode is True:
+                    raise
+                print("TCP query with trailing data, synthesizing response")
+                request = dns.message.from_wire(data, ignore_trailing=True)
+                forceRcode = trailingDataResponse
+
+            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
             if not response:
                 conn.close()
                 continue
@@ -274,6 +305,7 @@ class DNSDistTest(unittest.TestCase):
     @classmethod
     def openTCPConnection(cls, timeout=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -283,6 +315,7 @@ class DNSDistTest(unittest.TestCase):
     @classmethod
     def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -355,6 +388,7 @@ class DNSDistTest(unittest.TestCase):
             for response in responses:
                 cls._toResponderQueue.put(response, True, timeout)
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -445,6 +479,7 @@ class DNSDistTest(unittest.TestCase):
         ourNonce = libnacl.utils.rand_nonce()
         theirNonce = None
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
             sock.settimeout(timeout)
 
@@ -485,10 +520,12 @@ class DNSDistTest(unittest.TestCase):
     def checkMessageEDNSWithoutOptions(self, expected, received):
         self.assertEquals(expected, received)
         self.assertEquals(received.edns, 0)
+        self.assertEquals(expected.payload, received.payload)
 
     def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
         self.assertEquals(expected, received)
         self.assertEquals(received.edns, 0)
+        self.assertEquals(expected.payload, received.payload)
         self.assertEquals(len(received.options), withCookies)
         if withCookies:
             for option in received.options:
@@ -497,6 +534,7 @@ class DNSDistTest(unittest.TestCase):
     def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
         self.assertEquals(expected, received)
         self.assertEquals(received.edns, 0)
+        self.assertEquals(expected.payload, received.payload)
         self.assertEquals(len(received.options), 1 + additionalOptions)
         hasECS = False
         for option in received.options: