]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for outgoing DoH health-checks and X-Forwarded-* headers
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 30 Sep 2021 14:52:59 +0000 (16:52 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 13 Oct 2021 13:20:54 +0000 (15:20 +0200)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_OutgoingDOH.py

index 745ef72001e49cd89a4123e9a53d1b26b21fc937..b4ea7be4b9a58bc507b1c35d813893f9e1dfcf4c 100644 (file)
@@ -323,13 +323,102 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
         sock.close()
 
+    @classmethod
+    def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
+        ignoreTrailing = trailingDataResponse is True
+        h2conn = h2.connection.H2Connection(config=config)
+        h2conn.initiate_connection()
+        conn.sendall(h2conn.data_to_send())
+        dnsData = {}
+
+        if useProxyProtocol:
+            # try to read the entire Proxy Protocol header
+            proxy = ProxyProtocol()
+            header = conn.recv(proxy.HEADER_SIZE)
+            if not header:
+                print('unable to get header')
+                conn.close()
+                return
+
+            if not proxy.parseHeader(header):
+                print('unable to parse header')
+                print(header)
+                conn.close()
+                return
+
+            proxyContent = conn.recv(proxy.contentLen)
+            if not proxyContent:
+                print('unable to get content')
+                conn.close()
+                return
+
+            payload = header + proxyContent
+            toQueue.put(payload, True, cls._queueTimeout)
+
+        # be careful, HTTP/2 headers and data might be in different recv() results
+        requestHeaders = None
+        while True:
+            data = conn.recv(65535)
+            if not data:
+                break
+
+            events = h2conn.receive_data(data)
+            for event in events:
+                if isinstance(event, h2.events.RequestReceived):
+                    requestHeaders = event.headers
+                if isinstance(event, h2.events.DataReceived):
+                    h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
+                    if not event.stream_id in dnsData:
+                      dnsData[event.stream_id] = b''
+                    dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
+                    if event.stream_ended:
+                        forceRcode = None
+                        status = 200
+                        try:
+                            request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
+                        except dns.message.TrailingJunk as e:
+                            if trailingDataResponse is False or forceRcode is True:
+                                raise
+                            print("DOH query with trailing data, synthesizing response")
+                            request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
+                            forceRcode = trailingDataResponse
+
+                        if callback:
+                            status, wire = callback(request, requestHeaders, fromQueue, toQueue)
+                        else:
+                            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+                            if response:
+                                wire = response.to_wire(max_size=65535)
+
+                        if not wire:
+                            conn.close()
+                            conn = None
+                            break
+
+                        headers = [
+                          (':status', str(status)),
+                          ('content-length', str(len(wire))),
+                          ('content-type', 'application/dns-message'),
+                        ]
+                        h2conn.send_headers(stream_id=event.stream_id, headers=headers)
+                        h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
+
+                data_to_send = h2conn.data_to_send()
+                if data_to_send:
+                    conn.sendall(data_to_send)
+
+            if conn is None:
+                break
+
+        if conn is not None:
+            conn.close()
+
     @classmethod
     def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
         # callback is invoked for every -even healthcheck ones- query and should return a raw response
-        ignoreTrailing = trailingDataResponse is True
 
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@@ -355,88 +444,11 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
               continue
 
             conn.settimeout(5.0)
-            h2conn = h2.connection.H2Connection(config=config)
-            h2conn.initiate_connection()
-            conn.sendall(h2conn.data_to_send())
-            dnsData = {}
-
-            if useProxyProtocol:
-                # try to read the entire Proxy Protocol header
-                proxy = ProxyProtocol()
-                header = conn.recv(proxy.HEADER_SIZE)
-                if not header:
-                    print('unable to get header')
-                    conn.close()
-                    continue
-
-                if not proxy.parseHeader(header):
-                    print('unable to parse header')
-                    print(header)
-                    conn.close()
-                    continue
-
-                proxyContent = conn.recv(proxy.contentLen)
-                if not proxyContent:
-                    print('unable to get content')
-                    conn.close()
-                    continue
-
-                payload = header + proxyContent
-                toQueue.put(payload, True, cls._queueTimeout)
-
-            while True:
-                data = conn.recv(65535)
-                if not data:
-                    break
-
-                events = h2conn.receive_data(data)
-                for event in events:
-                    if isinstance(event, h2.events.DataReceived):
-                        h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
-                        if not event.stream_id in dnsData:
-                          dnsData[event.stream_id] = b''
-                        dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
-                        if event.stream_ended:
-                            forceRcode = None
-                            status = 200
-                            try:
-                                request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
-                            except dns.message.TrailingJunk as e:
-                                if trailingDataResponse is False or forceRcode is True:
-                                    raise
-                                print("DOH query with trailing data, synthesizing response")
-                                request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
-                                forceRcode = trailingDataResponse
-
-                            if callback:
-                                status, wire = callback(request)
-                            else:
-                                response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
-                                if response:
-                                    wire = response.to_wire(max_size=65535)
-
-                            if not wire:
-                                conn.close()
-                                conn = None
-                                break
-
-                            headers = [
-                              (':status', str(status)),
-                              ('content-length', str(len(wire))),
-                              ('content-type', 'application/dns-message'),
-                            ]
-                            h2conn.send_headers(stream_id=event.stream_id, headers=headers)
-                            h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
-
-                    data_to_send = h2conn.data_to_send()
-                    if data_to_send:
-                        conn.sendall(data_to_send)
-
-                if conn is None:
-                    break
-
-            if conn is not None:
-                conn.close()
+            thread = threading.Thread(name='DoH Connection Handler',
+                                      target=cls.handleDoHConnection,
+                                      args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
+            thread.setDaemon(True)
+            thread.start()
 
         sock.close()
 
index 9f0821f07707f02a8d79fbed594a1d9437156f8b..217c885fef0e9d478b73f78fa4d2ecb74bf9eee1 100644 (file)
@@ -1,4 +1,6 @@
 #!/usr/bin/env python
+import base64
+import copy
 import dns
 import requests
 import ssl
@@ -20,7 +22,7 @@ class OutgoingDOHTests(object):
         self.assertNotIn('UDP Responder', self._responsesCounter)
         self.assertNotIn('TCP Responder', self._responsesCounter)
         self.assertNotIn('TLS Responder', self._responsesCounter)
-        self.assertEqual(self._responsesCounter['DOH Responder'], numberOfDOHQueries)
+        self.assertEqual(self._responsesCounter['DoH Connection Handler'], numberOfDOHQueries)
 
     def getServerStat(self, key):
         headers = {'x-api-key': self._webServerAPIKey}
@@ -135,6 +137,14 @@ class OutgoingDOHTests(object):
             (_, receivedResponse) = self.sendTCPQuery(query, useQueue=False, response=None)
             self.assertEqual(receivedResponse, expectedResponse)
 
+    def testZHealthChecks(self):
+        # this test has to run last, as it will mess up the TCP connection counter,
+        # hence the 'Z' in the name
+        self.sendConsoleCommand("getServer(0):setAuto()")
+        time.sleep(2)
+        status = self.sendConsoleCommand("if getServer(0):isUp() then return 'up' else return 'down' end").strip("\n")
+        self.assertEqual(status, 'up')
+
 class BrokenOutgoingDOHTests(object):
 
     _webTimeout = 2.0
@@ -254,10 +264,15 @@ class OutgoingDOHBrokenResponsesTests(object):
 
 class TestOutgoingDOHOpenSSL(DNSDistTest, OutgoingDOHTests):
     _tlsBackendPort = 10543
-    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
+    _tlsProvider = 'openssl'
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+    _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
     _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
     setMaxTCPClientThreads(1)
-    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
+    newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
     webserver("127.0.0.1:%s")
     setWebserverConfig({password="%s", apiKey="%s"})
 
@@ -281,10 +296,15 @@ class TestOutgoingDOHOpenSSL(DNSDistTest, OutgoingDOHTests):
 
 class TestOutgoingDOHGnuTLS(DNSDistTest, OutgoingDOHTests):
     _tlsBackendPort = 10544
-    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
+    _tlsProvider = 'gnutls'
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+    _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
     _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
     setMaxTCPClientThreads(1)
-    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
+    newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
     webserver("127.0.0.1:%s")
     setWebserverConfig({password="%s", apiKey="%s"})
 
@@ -348,10 +368,15 @@ class TestOutgoingDOHGnuTLSWrongCertName(DNSDistTest, BrokenOutgoingDOHTests):
 
 class TestOutgoingDOHOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests):
     _tlsBackendPort = 10547
-    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
+    _tlsProvider = 'openssl'
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+    _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
     _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
     setMaxTCPClientThreads(1)
-    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
+    newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
     webserver("127.0.0.1:%s")
     setWebserverConfig({password="%s", apiKey="%s"})
 
@@ -374,10 +399,15 @@ class TestOutgoingDOHOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTest
 
 class TestOutgoingDOHGnuTLSWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests):
     _tlsBackendPort = 10548
-    _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
+    _tlsProvider = 'gnutls'
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+    _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
     _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
     setMaxTCPClientThreads(1)
-    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
+    newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
     webserver("127.0.0.1:%s")
     setWebserverConfig({password="%s", apiKey="%s"})
 
@@ -414,7 +444,7 @@ class TestOutgoingDOHBrokenResponsesOpenSSL(DNSDistTest, OutgoingDOHBrokenRespon
     addAction(SuffixMatchNodeRule(smn), PoolAction('cache'))
     """
 
-    def callback(request):
+    def callback(request, headers, fromQueue, toQueue):
 
         if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.':
             print("returning 500")
@@ -451,7 +481,7 @@ class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenRespons
     """
     _verboseMode = True
 
-    def callback(request):
+    def callback(request, headers, fromQueue, toQueue):
 
         if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.':
             print("returning 500")
@@ -523,3 +553,74 @@ class TestOutgoingDOHProxyProtocol(DNSDistTest):
         self.assertEqual(query, receivedQuery)
         self.assertEqual(receivedResponse, expectedResponse)
         self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True)
+
+class TestOutgoingDOHXForwarded(DNSDistTest):
+    _tlsBackendPort = 10560
+    _config_params = ['_tlsBackendPort']
+    _config_template = """
+    setMaxTCPClientThreads(1)
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', addXForwardedHeaders=true}
+    """
+    _verboseMode = True
+
+    def callback(request, headersList, fromQueue, toQueue):
+
+        if str(request.question[0].name) == 'a.root-servers.net.':
+            # do not check headers on health-check queries
+            return 200, dns.message.make_response(request).to_wire()
+
+        headers = {}
+        if headersList:
+            for k,v in headersList:
+                headers[k] = v
+
+        if not b'x-forwarded-for' in headers:
+            print("missing X-Forwarded-For")
+            return 406, b'Missing X-Forwarded-For header'
+        if not b'x-forwarded-port' in headers:
+            print("missing X-Forwarded-Port")
+            return 406, b'Missing X-Forwarded-Port header'
+        if not b'x-forwarded-proto' in headers:
+            print("missing X-Forwarded-Proto")
+            return 406, b'Missing X-Forwarded-Proto header'
+
+        toQueue.put(request, True, 1.0)
+        response = fromQueue.get(True, 1.0)
+        if response:
+            response = copy.copy(response)
+            response.id = request.id
+
+        return 200, response.to_wire()
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.set_alpn_protocols(["h2"])
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching DOH responder..")
+        cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext])
+        cls._DOHResponder.setDaemon(True)
+        cls._DOHResponder.start()
+
+    def testXForwarded(self):
+        """
+        Outgoing DOH: X-Forwarded
+        """
+        name = 'x-forwarded-for.outgoing-doh.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(receivedResponse, expectedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(receivedResponse, expectedResponse)