From: Remi Gacogne Date: Thu, 30 Sep 2021 14:52:59 +0000 (+0200) Subject: dnsdist: Add regression tests for outgoing DoH health-checks and X-Forwarded-* headers X-Git-Tag: dnsdist-1.7.0-alpha2~10^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c4c72a2ca96cd1dd4f4c41db496447cea2570128;p=thirdparty%2Fpdns.git dnsdist: Add regression tests for outgoing DoH health-checks and X-Forwarded-* headers --- diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 745ef72001..b4ea7be4b9 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -323,13 +323,102 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): sock.close() + @classmethod + def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol): + ignoreTrailing = trailingDataResponse is True + h2conn = h2.connection.H2Connection(config=config) + h2conn.initiate_connection() + conn.sendall(h2conn.data_to_send()) + dnsData = {} + + if useProxyProtocol: + # try to read the entire Proxy Protocol header + proxy = ProxyProtocol() + header = conn.recv(proxy.HEADER_SIZE) + if not header: + print('unable to get header') + conn.close() + return + + if not proxy.parseHeader(header): + print('unable to parse header') + print(header) + conn.close() + return + + proxyContent = conn.recv(proxy.contentLen) + if not proxyContent: + print('unable to get content') + conn.close() + return + + payload = header + proxyContent + toQueue.put(payload, True, cls._queueTimeout) + + # be careful, HTTP/2 headers and data might be in different recv() results + requestHeaders = None + while True: + data = conn.recv(65535) + if not data: + break + + events = h2conn.receive_data(data) + for event in events: + if isinstance(event, h2.events.RequestReceived): + requestHeaders = event.headers + if isinstance(event, h2.events.DataReceived): + h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) + if not event.stream_id in dnsData: + dnsData[event.stream_id] = b'' + dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data) + if event.stream_ended: + forceRcode = None + status = 200 + try: + request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing) + except dns.message.TrailingJunk as e: + if trailingDataResponse is False or forceRcode is True: + raise + print("DOH query with trailing data, synthesizing response") + request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True) + forceRcode = trailingDataResponse + + if callback: + status, wire = callback(request, requestHeaders, fromQueue, toQueue) + else: + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) + if response: + wire = response.to_wire(max_size=65535) + + if not wire: + conn.close() + conn = None + break + + headers = [ + (':status', str(status)), + ('content-length', str(len(wire))), + ('content-type', 'application/dns-message'), + ] + h2conn.send_headers(stream_id=event.stream_id, headers=headers) + h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True) + + data_to_send = h2conn.data_to_send() + if data_to_send: + conn.sendall(data_to_send) + + if conn is None: + break + + if conn is not None: + conn.close() + @classmethod def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False): # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. # callback is invoked for every -even healthcheck ones- query and should return a raw response - ignoreTrailing = trailingDataResponse is True sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -355,88 +444,11 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): continue conn.settimeout(5.0) - h2conn = h2.connection.H2Connection(config=config) - h2conn.initiate_connection() - conn.sendall(h2conn.data_to_send()) - dnsData = {} - - if useProxyProtocol: - # try to read the entire Proxy Protocol header - proxy = ProxyProtocol() - header = conn.recv(proxy.HEADER_SIZE) - if not header: - print('unable to get header') - conn.close() - continue - - if not proxy.parseHeader(header): - print('unable to parse header') - print(header) - conn.close() - continue - - proxyContent = conn.recv(proxy.contentLen) - if not proxyContent: - print('unable to get content') - conn.close() - continue - - payload = header + proxyContent - toQueue.put(payload, True, cls._queueTimeout) - - while True: - data = conn.recv(65535) - if not data: - break - - events = h2conn.receive_data(data) - for event in events: - if isinstance(event, h2.events.DataReceived): - h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) - if not event.stream_id in dnsData: - dnsData[event.stream_id] = b'' - dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data) - if event.stream_ended: - forceRcode = None - status = 200 - try: - request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing) - except dns.message.TrailingJunk as e: - if trailingDataResponse is False or forceRcode is True: - raise - print("DOH query with trailing data, synthesizing response") - request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True) - forceRcode = trailingDataResponse - - if callback: - status, wire = callback(request) - else: - response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) - if response: - wire = response.to_wire(max_size=65535) - - if not wire: - conn.close() - conn = None - break - - headers = [ - (':status', str(status)), - ('content-length', str(len(wire))), - ('content-type', 'application/dns-message'), - ] - h2conn.send_headers(stream_id=event.stream_id, headers=headers) - h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True) - - data_to_send = h2conn.data_to_send() - if data_to_send: - conn.sendall(data_to_send) - - if conn is None: - break - - if conn is not None: - conn.close() + thread = threading.Thread(name='DoH Connection Handler', + target=cls.handleDoHConnection, + args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol]) + thread.setDaemon(True) + thread.start() sock.close() diff --git a/regression-tests.dnsdist/test_OutgoingDOH.py b/regression-tests.dnsdist/test_OutgoingDOH.py index 9f0821f077..217c885fef 100644 --- a/regression-tests.dnsdist/test_OutgoingDOH.py +++ b/regression-tests.dnsdist/test_OutgoingDOH.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +import base64 +import copy import dns import requests import ssl @@ -20,7 +22,7 @@ class OutgoingDOHTests(object): self.assertNotIn('UDP Responder', self._responsesCounter) self.assertNotIn('TCP Responder', self._responsesCounter) self.assertNotIn('TLS Responder', self._responsesCounter) - self.assertEqual(self._responsesCounter['DOH Responder'], numberOfDOHQueries) + self.assertEqual(self._responsesCounter['DoH Connection Handler'], numberOfDOHQueries) def getServerStat(self, key): headers = {'x-api-key': self._webServerAPIKey} @@ -135,6 +137,14 @@ class OutgoingDOHTests(object): (_, receivedResponse) = self.sendTCPQuery(query, useQueue=False, response=None) self.assertEqual(receivedResponse, expectedResponse) + def testZHealthChecks(self): + # this test has to run last, as it will mess up the TCP connection counter, + # hence the 'Z' in the name + self.sendConsoleCommand("getServer(0):setAuto()") + time.sleep(2) + status = self.sendConsoleCommand("if getServer(0):isUp() then return 'up' else return 'down' end").strip("\n") + self.assertEqual(status, 'up') + class BrokenOutgoingDOHTests(object): _webTimeout = 2.0 @@ -254,10 +264,15 @@ class OutgoingDOHBrokenResponsesTests(object): class TestOutgoingDOHOpenSSL(DNSDistTest, OutgoingDOHTests): _tlsBackendPort = 10543 - _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed'] + _tlsProvider = 'openssl' + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed'] _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") setMaxTCPClientThreads(1) - newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp() + newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp() webserver("127.0.0.1:%s") setWebserverConfig({password="%s", apiKey="%s"}) @@ -281,10 +296,15 @@ class TestOutgoingDOHOpenSSL(DNSDistTest, OutgoingDOHTests): class TestOutgoingDOHGnuTLS(DNSDistTest, OutgoingDOHTests): _tlsBackendPort = 10544 - _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed'] + _tlsProvider = 'gnutls' + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed'] _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") setMaxTCPClientThreads(1) - newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp() + newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp() webserver("127.0.0.1:%s") setWebserverConfig({password="%s", apiKey="%s"}) @@ -348,10 +368,15 @@ class TestOutgoingDOHGnuTLSWrongCertName(DNSDistTest, BrokenOutgoingDOHTests): class TestOutgoingDOHOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests): _tlsBackendPort = 10547 - _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed'] + _tlsProvider = 'openssl' + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed'] _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") setMaxTCPClientThreads(1) - newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp() + newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp() webserver("127.0.0.1:%s") setWebserverConfig({password="%s", apiKey="%s"}) @@ -374,10 +399,15 @@ class TestOutgoingDOHOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTest class TestOutgoingDOHGnuTLSWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests): _tlsBackendPort = 10548 - _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed'] + _tlsProvider = 'gnutls' + _consoleKey = DNSDistTest.generateConsoleKey() + _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii') + _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed'] _config_template = """ + setKey("%s") + controlSocket("127.0.0.1:%d") setMaxTCPClientThreads(1) - newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp() + newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp() webserver("127.0.0.1:%s") setWebserverConfig({password="%s", apiKey="%s"}) @@ -414,7 +444,7 @@ class TestOutgoingDOHBrokenResponsesOpenSSL(DNSDistTest, OutgoingDOHBrokenRespon addAction(SuffixMatchNodeRule(smn), PoolAction('cache')) """ - def callback(request): + def callback(request, headers, fromQueue, toQueue): if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.': print("returning 500") @@ -451,7 +481,7 @@ class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenRespons """ _verboseMode = True - def callback(request): + def callback(request, headers, fromQueue, toQueue): if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.': print("returning 500") @@ -523,3 +553,74 @@ class TestOutgoingDOHProxyProtocol(DNSDistTest): self.assertEqual(query, receivedQuery) self.assertEqual(receivedResponse, expectedResponse) self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True) + +class TestOutgoingDOHXForwarded(DNSDistTest): + _tlsBackendPort = 10560 + _config_params = ['_tlsBackendPort'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', addXForwardedHeaders=true} + """ + _verboseMode = True + + def callback(request, headersList, fromQueue, toQueue): + + if str(request.question[0].name) == 'a.root-servers.net.': + # do not check headers on health-check queries + return 200, dns.message.make_response(request).to_wire() + + headers = {} + if headersList: + for k,v in headersList: + headers[k] = v + + if not b'x-forwarded-for' in headers: + print("missing X-Forwarded-For") + return 406, b'Missing X-Forwarded-For header' + if not b'x-forwarded-port' in headers: + print("missing X-Forwarded-Port") + return 406, b'Missing X-Forwarded-Port header' + if not b'x-forwarded-proto' in headers: + print("missing X-Forwarded-Proto") + return 406, b'Missing X-Forwarded-Proto header' + + toQueue.put(request, True, 1.0) + response = fromQueue.get(True, 1.0) + if response: + response = copy.copy(response) + response.id = request.id + + return 200, response.to_wire() + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.set_alpn_protocols(["h2"]) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + + def testXForwarded(self): + """ + Outgoing DOH: X-Forwarded + """ + name = 'x-forwarded-for.outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse) + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, expectedResponse) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse) + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, expectedResponse)