]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/dnsdisttests.py
dnsdist: Add a regression test for DoQ certs/keys reloading
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
index 1cd53af708bd02b3835ae9efa58b01dfe3f2a99c..c854cfab53005f3e13296381a3a402a4677e6fab 100644 (file)
@@ -28,6 +28,9 @@ import h2.config
 import pycurl
 from io import BytesIO
 
+from doqclient import quic_query
+from doh3client import doh3_query
+
 from eqdnsmessage import AssertEqualDNSMessageMixin
 from proxyprotocol import ProxyProtocol
 
@@ -42,6 +45,23 @@ try:
 except NameError:
   pass
 
+def getWorkerID():
+    if not 'PYTEST_XDIST_WORKER' in os.environ:
+      return 0
+    workerName = os.environ['PYTEST_XDIST_WORKER']
+    return int(workerName[2:])
+
+workerPorts = {}
+
+def pickAvailablePort():
+    global workerPorts
+    workerID = getWorkerID()
+    if workerID in workerPorts:
+      port = workerPorts[workerID] + 1
+    else:
+      port = 11000 + (workerID * 1000)
+    workerPorts[workerID] = port
+    return port
 
 class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     """
@@ -52,9 +72,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     from dnsdist on a separate queue, allowing the tests to check
     that the queries sent from dnsdist were as expected.
     """
-    _dnsDistPort = 5340
     _dnsDistListeningAddr = "127.0.0.1"
-    _testServerPort = 5350
     _toResponderQueue = Queue()
     _fromResponderQueue = Queue()
     _queueTimeout = 1
@@ -64,7 +82,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     """
     _config_params = ['_testServerPort']
     _acl = ['127.0.0.1/32']
-    _consolePort = 5199
     _consoleKey = None
     _healthCheckName = 'a.root-servers.net.'
     _healthCheckCounter = 0
@@ -78,6 +95,9 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     _UDPResponder = None
     _TCPResponder = None
     _extraStartupSleep = 0
+    _dnsDistPort = pickAvailablePort()
+    _consolePort = pickAvailablePort()
+    _testServerPort = pickAvailablePort()
 
     @classmethod
     def waitForTCPSocket(cls, ipaddress, port):
@@ -97,23 +117,28 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     @classmethod
     def startResponders(cls):
         print("Launching responders..")
+        cls._testServerPort = pickAvailablePort()
 
         cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
-        cls._UDPResponder.setDaemon(True)
+        cls._UDPResponder.daemon = True
         cls._UDPResponder.start()
         cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
-        cls._TCPResponder.setDaemon(True)
+        cls._TCPResponder.daemon = True
         cls._TCPResponder.start()
         cls.waitForTCPSocket("127.0.0.1", cls._testServerPort);
 
     @classmethod
     def startDNSDist(cls):
+        cls._dnsDistPort = pickAvailablePort()
+        cls._consolePort = pickAvailablePort()
+
         print("Launching dnsdist..")
         confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
         params = tuple([getattr(cls, param) for param in cls._config_params])
         print(params)
         with open(confFile, 'w') as conf:
             conf.write("-- Autogenerated by dnsdisttests.py\n")
+            conf.write(f"-- dnsdist will listen on {cls._dnsDistPort}")
             conf.write(cls._config_template % params)
             conf.write("setSecurityPollSuffix('')")
 
@@ -141,7 +166,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         else:
           expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
         if not cls._verboseMode and output != expectedOutput:
-            raise AssertionError('dnsdist --check-config failed: %s' % output)
+            raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output, expectedOutput))
 
         logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
         with open(logFile, 'w') as fdLog:
@@ -209,10 +234,10 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
     @classmethod
     def _ResponderIncrementCounter(cls):
-        if threading.currentThread().name in cls._responsesCounter:
-            cls._responsesCounter[threading.currentThread().name] += 1
+        if threading.current_thread().name in cls._responsesCounter:
+            cls._responsesCounter[threading.current_thread().name] += 1
         else:
-            cls._responsesCounter[threading.currentThread().name] = 1
+            cls._responsesCounter[threading.current_thread().name] = 1
 
     @classmethod
     def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
@@ -298,7 +323,11 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     @classmethod
     def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, partialWrite=False):
       ignoreTrailing = trailingDataResponse is True
-      data = conn.recv(2)
+      try:
+        data = conn.recv(2)
+      except Exception as err:
+        data = None
+        print(f'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
       if not data:
         conn.close()
         return
@@ -401,7 +430,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
               thread = threading.Thread(name='TCP Connection Handler',
                                         target=cls.handleTCPConnection,
                                         args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite])
-              thread.setDaemon(True)
+              thread.daemon = True
               thread.start()
             else:
               cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite)
@@ -418,6 +447,9 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         except ssl.SSLEOFError as e:
           print("Unexpected EOF: %s" % (e))
           return
+        except Exception as err:
+          print(f'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
+          return
 
         dnsData = {}
 
@@ -448,7 +480,11 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         # be careful, HTTP/2 headers and data might be in different recv() results
         requestHeaders = None
         while True:
-            data = conn.recv(65535)
+            try:
+              data = conn.recv(65535)
+            except Exception as err:
+              data = None
+              print(f'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
             if not data:
                 break
 
@@ -545,7 +581,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             thread = threading.Thread(name='DoH Connection Handler',
                                       target=cls.handleDoHConnection,
                                       args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
-            thread.setDaemon(True)
+            thread.daemon = True
             thread.start()
 
         sock.close()
@@ -591,7 +627,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return sock
 
     @classmethod
-    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
+    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
@@ -600,6 +636,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         # 2.7.9+
         if hasattr(ssl, 'create_default_context'):
             sslctx = ssl.create_default_context(cafile=caCert)
+            if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
+              sslctx.set_alpn_protocols(alpn)
             sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
         else:
             sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
@@ -622,7 +660,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
     @classmethod
     def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
-        print("reading data")
         message = None
         data = sock.recv(2)
         if data:
@@ -636,7 +673,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         print(useQueue)
         if useQueue and not cls._fromResponderQueue.empty():
             receivedQuery = cls._fromResponderQueue.get(True, timeout)
-            print("Got from queue")
             print(receivedQuery)
             return (receivedQuery, message)
         else:
@@ -672,7 +708,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         receivedQuery = None
         print(useQueue)
         if useQueue and not cls._fromResponderQueue.empty():
-            print("Got from queue")
             print(receivedQuery)
             receivedQuery = cls._fromResponderQueue.get(True, timeout)
         else:
@@ -873,16 +908,17 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     def checkResponseNoEDNS(self, expected, received):
         self.checkMessageNoEDNS(expected, received)
 
-    def generateNewCertificateAndKey(self):
+    @staticmethod
+    def generateNewCertificateAndKey(filePrefix):
         # generate and sign a new cert
-        cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', 'server.key', '-out', 'server.csr', '-config', 'configServer.conf']
+        cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf']
         output = None
         try:
             process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
             output = process.communicate(input='')
         except subprocess.CalledProcessError as exc:
             raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output))
-        cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', 'server.csr', '-out', 'server.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
+        cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
         output = None
         try:
             process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
@@ -890,12 +926,12 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         except subprocess.CalledProcessError as exc:
             raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output))
 
-        with open('server.chain', 'w') as outFile:
-            for inFileName in ['server.pem', 'ca.pem']:
+        with open(filePrefix + '.chain', 'w') as outFile:
+            for inFileName in [filePrefix + '.pem', 'ca.pem']:
                 with open(inFileName) as inFile:
                     outFile.write(inFile.read())
 
-        cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', 'server.pem', '-CAfile', 'ca.pem', '-inkey', 'server.key', '-out', 'server.p12']
+        cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12']
         output = None
         try:
             process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
@@ -952,19 +988,25 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return conn
 
     @classmethod
-    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
+    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None):
         url = cls.getDOHGetURL(baseurl, query, rawQuery)
-        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
-        response_headers = BytesIO()
-        #conn.setopt(pycurl.VERBOSE, True)
-        conn.setopt(pycurl.URL, url)
-        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+
+        if not conn:
+            conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+            # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+            conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
+
         if useHTTPS:
             conn.setopt(pycurl.SSL_VERIFYPEER, 1)
             conn.setopt(pycurl.SSL_VERIFYHOST, 2)
             if caFile:
                 conn.setopt(pycurl.CAINFO, caFile)
 
+        response_headers = BytesIO()
+        #conn.setopt(pycurl.VERBOSE, True)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+
         conn.setopt(pycurl.HTTPHEADER, customHeaders)
         conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
 
@@ -1003,6 +1045,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         #conn.setopt(pycurl.VERBOSE, True)
         conn.setopt(pycurl.URL, url)
         conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
         if useHTTPS:
             conn.setopt(pycurl.SSL_VERIFYPEER, 1)
             conn.setopt(pycurl.SSL_VERIFYHOST, 2)
@@ -1040,5 +1084,71 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     def sendDOHQueryWrapper(self, query, response, useQueue=True):
         return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
 
+    def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
+        return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
+
+    def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
+        return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
+
     def sendDOTQueryWrapper(self, query, response, useQueue=True):
         return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue)
+
+    def sendDOQQueryWrapper(self, query, response, useQueue=True):
+        return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
+
+    def sendDOH3QueryWrapper(self, query, response, useQueue=True):
+        return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
+    @classmethod
+    def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
+
+        manager = dns.quic.SyncQuicManager(
+            verify_mode=caFile
+        )
+
+        return manager.connect('127.0.0.1', port, source, source_port)
+
+    @classmethod
+    def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
+
+        if response:
+            if toQueue:
+                toQueue.put(response, True, timeout)
+            else:
+                cls._toResponderQueue.put(response, True, timeout)
+
+        (message, _) = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)
+
+        receivedQuery = None
+
+        if useQueue:
+            if fromQueue:
+                if not fromQueue.empty():
+                    receivedQuery = fromQueue.get(True, timeout)
+            else:
+                if not cls._fromResponderQueue.empty():
+                    receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+        return (receivedQuery, message)
+
+    @classmethod
+    def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
+
+        if response:
+            if toQueue:
+                toQueue.put(response, True, timeout)
+            else:
+                cls._toResponderQueue.put(response, True, timeout)
+
+        message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)
+
+        receivedQuery = None
+
+        if useQueue:
+            if fromQueue:
+                if not fromQueue.empty():
+                    receivedQuery = fromQueue.get(True, timeout)
+            else:
+                if not cls._fromResponderQueue.empty():
+                    receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+        return (receivedQuery, message)