]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/dnsdisttests.py
dnsdist: Add a regression test for DoQ certs/keys reloading
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
index c82f2b9ac9ddd08563762c21a92c0cdd720ae197..c854cfab53005f3e13296381a3a402a4677e6fab 100644 (file)
@@ -28,6 +28,9 @@ import h2.config
 import pycurl
 from io import BytesIO
 
+from doqclient import quic_query
+from doh3client import doh3_query
+
 from eqdnsmessage import AssertEqualDNSMessageMixin
 from proxyprotocol import ProxyProtocol
 
@@ -905,16 +908,17 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     def checkResponseNoEDNS(self, expected, received):
         self.checkMessageNoEDNS(expected, received)
 
-    def generateNewCertificateAndKey(self):
+    @staticmethod
+    def generateNewCertificateAndKey(filePrefix):
         # generate and sign a new cert
-        cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', 'server.key', '-out', 'server.csr', '-config', 'configServer.conf']
+        cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf']
         output = None
         try:
             process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
             output = process.communicate(input='')
         except subprocess.CalledProcessError as exc:
             raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output))
-        cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', 'server.csr', '-out', 'server.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
+        cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
         output = None
         try:
             process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
@@ -922,12 +926,12 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         except subprocess.CalledProcessError as exc:
             raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output))
 
-        with open('server.chain', 'w') as outFile:
-            for inFileName in ['server.pem', 'ca.pem']:
+        with open(filePrefix + '.chain', 'w') as outFile:
+            for inFileName in [filePrefix + '.pem', 'ca.pem']:
                 with open(inFileName) as inFile:
                     outFile.write(inFile.read())
 
-        cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', 'server.pem', '-CAfile', 'ca.pem', '-inkey', 'server.key', '-out', 'server.p12']
+        cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12']
         output = None
         try:
             process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
@@ -1088,3 +1092,63 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
     def sendDOTQueryWrapper(self, query, response, useQueue=True):
         return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue)
+
+    def sendDOQQueryWrapper(self, query, response, useQueue=True):
+        return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
+
+    def sendDOH3QueryWrapper(self, query, response, useQueue=True):
+        return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
+    @classmethod
+    def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
+
+        manager = dns.quic.SyncQuicManager(
+            verify_mode=caFile
+        )
+
+        return manager.connect('127.0.0.1', port, source, source_port)
+
+    @classmethod
+    def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
+
+        if response:
+            if toQueue:
+                toQueue.put(response, True, timeout)
+            else:
+                cls._toResponderQueue.put(response, True, timeout)
+
+        (message, _) = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)
+
+        receivedQuery = None
+
+        if useQueue:
+            if fromQueue:
+                if not fromQueue.empty():
+                    receivedQuery = fromQueue.get(True, timeout)
+            else:
+                if not cls._fromResponderQueue.empty():
+                    receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+        return (receivedQuery, message)
+
+    @classmethod
+    def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
+
+        if response:
+            if toQueue:
+                toQueue.put(response, True, timeout)
+            else:
+                cls._toResponderQueue.put(response, True, timeout)
+
+        message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)
+
+        receivedQuery = None
+
+        if useQueue:
+            if fromQueue:
+                if not fromQueue.empty():
+                    receivedQuery = fromQueue.get(True, timeout)
+            else:
+                if not cls._fromResponderQueue.empty():
+                    receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+        return (receivedQuery, message)