]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for DoH between dnsdist and the backend
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 27 Aug 2021 14:54:03 +0000 (16:54 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 13 Sep 2021 13:28:28 +0000 (15:28 +0200)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/requirements.txt
regression-tests.dnsdist/test_OutgoingDOH.py [new file with mode: 0644]

index 7284122ec7efdacd66165fbc13b20453eadfdbd2..4647b9279414f696263a8aecfebfa45d2fffdc04 100644 (file)
@@ -10,12 +10,19 @@ import sys
 import threading
 import time
 import unittest
+
 import clientsubnetoption
+
 import dns
 import dns.message
+
 import libnacl
 import libnacl.utils
 
+import h2.connection
+import h2.events
+import h2.config
+
 from eqdnsmessage import AssertEqualDNSMessageMixin
 
 # Python2/3 compatibility hacks
@@ -311,6 +318,98 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
         sock.close()
 
+    @classmethod
+    def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None):
+        # trailingDataResponse=True means "ignore trailing data".
+        # Other values are either False (meaning "raise an exception")
+        # or are interpreted as a response RCODE for queries with trailing data.
+        # callback is invoked for every -even healthcheck ones- query and should return a raw response
+        ignoreTrailing = trailingDataResponse is True
+
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        try:
+            sock.bind(("127.0.0.1", port))
+        except socket.error as e:
+            print("Error binding in the TCP responder: %s" % str(e))
+            sys.exit(1)
+
+        sock.listen(100)
+        if tlsContext:
+            sock = tlsContext.wrap_socket(sock, server_side=True)
+
+        config = h2.config.H2Configuration(client_side=False)
+
+        while True:
+            try:
+                (conn, _) = sock.accept()
+            except ssl.SSLError:
+                continue
+            except ConnectionResetError:
+              continue
+            conn.settimeout(5.0)
+            h2conn = h2.connection.H2Connection(config=config)
+            h2conn.initiate_connection()
+            conn.sendall(h2conn.data_to_send())
+            dnsData = {}
+
+            while True:
+                data = conn.recv(65535)
+                if not data:
+                    break
+
+                events = h2conn.receive_data(data)
+                for event in events:
+                    if isinstance(event, h2.events.DataReceived):
+                        h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
+                        if not event.stream_id in dnsData:
+                          dnsData[event.stream_id] = b''
+                        dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
+                        if event.stream_ended:
+                            forceRcode = None
+                            status = 200
+                            try:
+                                request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
+                            except dns.message.TrailingJunk as e:
+                                if trailingDataResponse is False or forceRcode is True:
+                                    raise
+                                print("DOH query with trailing data, synthesizing response")
+                                request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
+                                forceRcode = trailingDataResponse
+
+                            if callback:
+                                status, wire = callback(request)
+                            else:
+                                response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+                                if response:
+                                    wire = response.to_wire(max_size=65535)
+
+                            if not wire:
+                                conn.close()
+                                conn = None
+                                break
+
+                            headers = [
+                              (':status', str(status)),
+                              ('content-length', str(len(wire))),
+                              ('content-type', 'application/dns-message'),
+                            ]
+                            h2conn.send_headers(stream_id=event.stream_id, headers=headers)
+                            h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
+
+                    data_to_send = h2conn.data_to_send()
+                    if data_to_send:
+                        conn.sendall(data_to_send)
+
+                if conn is None:
+                    break
+
+            if conn is not None:
+                conn.close()
+
+        sock.close()
+
     @classmethod
     def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
         if useQueue and response is not None:
index afa9fa76ecd5ff7cd0ed164d6d83926cd9eb8008..fd27b154cb5470e689cfd0318ef396923f85ae0f 100644 (file)
@@ -8,3 +8,4 @@ future>=0.17.1
 pycurl>=7.43.0
 lmdb>=0.95
 cdbx==0.1.2
+h2>=4.0.0
diff --git a/regression-tests.dnsdist/test_OutgoingDOH.py b/regression-tests.dnsdist/test_OutgoingDOH.py
new file mode 100644 (file)
index 0000000..28a579c
--- /dev/null
@@ -0,0 +1,350 @@
+#!/usr/bin/env python
+import dns
+import requests
+import ssl
+import threading
+import time
+
+from dnsdisttests import DNSDistTest
+
+class OutgoingDOHTests(object):
+
+    _webTimeout = 2.0
+    _webServerPort = 8083
+    _webServerBasicAuthPassword = 'secret'
+    _webServerAPIKey = 'apisecret'
+
+    def checkOnlyDOHResponderHit(self, numberOfDOHQueries=1):
+        self.assertNotIn('UDP Responder', self._responsesCounter)
+        self.assertNotIn('TCP Responder', self._responsesCounter)
+        self.assertNotIn('TLS Responder', self._responsesCounter)
+        self.assertEqual(self._responsesCounter['DOH Responder'], numberOfDOHQueries)
+
+    def getServerStat(self, key):
+        headers = {'x-api-key': self._webServerAPIKey}
+        url = 'http://127.0.0.1:' + str(self._webServerPort) + '/api/v1/servers/localhost'
+        r = requests.get(url, headers=headers, timeout=self._webTimeout)
+        self.assertTrue(r)
+        self.assertEqual(r.status_code, 200)
+        self.assertTrue(r.json())
+        content = r.json()
+        self.assertTrue(len(content['servers']), 1)
+        server = content['servers'][0]
+        self.assertIn(key, server)
+        return server[key]
+
+    def testUDP(self):
+        """
+        Outgoing DOH: UDP query is sent via DOH
+        """
+        name = 'udp.outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse.answer.append(rrset)
+
+        numberOfUDPQueries = 10
+        for _ in range(numberOfUDPQueries):
+            (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse)
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(receivedResponse, expectedResponse)
+
+        # there was one TCP query in testTCP (below, but before in alphabetical order)
+        numberOfQueries = numberOfUDPQueries + 1
+        self.checkOnlyDOHResponderHit(numberOfUDPQueries)
+
+        self.assertEqual(self.getServerStat('tcpNewConnections'), 1)
+        self.assertEqual(self.getServerStat('tcpReusedConnections'), numberOfQueries - 1)
+        self.assertEqual(self.getServerStat('tlsResumptions'), 0)
+
+    def testTCP(self):
+        """
+        Outgoing DOH: TCP query is sent via DOH
+        """
+        name = 'tcp.outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(receivedResponse, expectedResponse)
+        self.checkOnlyDOHResponderHit()
+        self.assertEqual(self.getServerStat('tcpNewConnections'), 1)
+        self.assertEqual(self.getServerStat('tcpReusedConnections'), 0)
+        self.assertEqual(self.getServerStat('tlsResumptions'), 0)
+
+class BrokenOutgoingDOHTests(object):
+
+    _webTimeout = 2.0
+    _webServerPort = 8083
+    _webServerBasicAuthPassword = 'secret'
+    _webServerAPIKey = 'apisecret'
+
+    def checkNoResponderHit(self):
+        self.assertNotIn('UDP Responder', self._responsesCounter)
+        self.assertNotIn('TCP Responder', self._responsesCounter)
+        self.assertNotIn('TLS Responder', self._responsesCounter)
+        self.assertNotIn('DOH Responder', self._responsesCounter)
+
+    def testUDP(self):
+        """
+        Outgoing DOH (broken): UDP query is sent via DOH
+        """
+        name = 'udp.broken-outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertEqual(receivedResponse, None)
+        self.checkNoResponderHit()
+
+    def testTCP(self):
+        """
+        Outgoing DOH (broken): TCP query is sent via DOH
+        """
+        name = 'tcp.broken-outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse.answer.append(rrset)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertEqual(receivedResponse, None)
+        self.checkNoResponderHit()
+
+class OutgoingDOHBrokenResponsesTests(object):
+
+    _webTimeout = 2.0
+    _webServerPort = 8083
+    _webServerBasicAuthPassword = 'secret'
+    _webServerAPIKey = 'apisecret'
+
+    def testUDP(self):
+        """
+        Outgoing DOH (broken responses): UDP query is sent via DOH
+        """
+        name = '500-status.broken-responses.outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        print(receivedResponse)
+        self.assertEqual(receivedResponse, None)
+
+        name = 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertEqual(receivedResponse, None)
+
+        name = 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertEqual(receivedResponse, None)
+
+class TestOutgoingDOHOpenSSL(DNSDistTest, OutgoingDOHTests):
+    _tlsBackendPort = 10543
+    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp()
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s"})
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.set_alpn_protocols(["h2"])
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+class TestOutgoingDOHGnuTLS(DNSDistTest, OutgoingDOHTests):
+    _tlsBackendPort = 10544
+    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp()
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s"})
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+        tlsContext.keylog_filename = "/tmp/keys"
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+class TestOutgoingDOHOpenSSLWrongCertName(DNSDistTest, BrokenOutgoingDOHTests):
+    _tlsBackendPort = 10545
+    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp()
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s"})
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+class TestOutgoingDOHGnuTLSWrongCertName(DNSDistTest, BrokenOutgoingDOHTests):
+    _tlsBackendPort = 10546
+    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp()
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s"})
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+class TestOutgoingDOHOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests):
+    _tlsBackendPort = 10547
+    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp()
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s"})
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+class TestOutgoingDOHGnuTLSWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests):
+    _tlsBackendPort = 10548
+    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp()
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s"})
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+class TestOutgoingDOHBrokenResponsesOpenSSL(DNSDistTest, OutgoingDOHBrokenResponsesTests):
+    _tlsBackendPort = 10549
+    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp()
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s"})
+    """
+
+    def callback(request):
+
+        if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.':
+            print("returning 500")
+            return 500, b'Server error'
+
+        if str(request.question[0].name) == 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.':
+            return 200, b'not DNS'
+
+        if str(request.question[0].name) == 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.':
+            return 200, None
+
+        print("Returning default for %s" % (request.question[0].name))
+        return 200, dns.message.make_response(request).to_wire()
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.set_alpn_protocols(["h2"])
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenResponsesTests):
+    _tlsBackendPort = 10550
+    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp()
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({password="%s", apiKey="%s"})
+    """
+
+    def callback(request):
+
+        if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.':
+            print("returning 500")
+            return 500, b'Server error'
+
+        if str(request.question[0].name) == 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.':
+            return 200, b'not DNS'
+
+        if str(request.question[0].name) == 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.':
+            return 200, None
+
+        print("Returning default for %s" % (request.question[0].name))
+        return 200, dns.message.make_response(request).to_wire()
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.set_alpn_protocols(["h2"])
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()