From: Remi Gacogne Date: Fri, 27 Aug 2021 14:54:03 +0000 (+0200) Subject: dnsdist: Add regression tests for DoH between dnsdist and the backend X-Git-Tag: dnsdist-1.7.0-alpha1~23^2~22 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=9d71a0cf24fc632fc3a1bf641c0c9609ee982824;p=thirdparty%2Fpdns.git dnsdist: Add regression tests for DoH between dnsdist and the backend --- diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 7284122ec7..4647b92794 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -10,12 +10,19 @@ import sys import threading import time import unittest + import clientsubnetoption + import dns import dns.message + import libnacl import libnacl.utils +import h2.connection +import h2.events +import h2.config + from eqdnsmessage import AssertEqualDNSMessageMixin # Python2/3 compatibility hacks @@ -311,6 +318,98 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): sock.close() + @classmethod + def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None): + # trailingDataResponse=True means "ignore trailing data". + # Other values are either False (meaning "raise an exception") + # or are interpreted as a response RCODE for queries with trailing data. + # callback is invoked for every -even healthcheck ones- query and should return a raw response + ignoreTrailing = trailingDataResponse is True + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + try: + sock.bind(("127.0.0.1", port)) + except socket.error as e: + print("Error binding in the TCP responder: %s" % str(e)) + sys.exit(1) + + sock.listen(100) + if tlsContext: + sock = tlsContext.wrap_socket(sock, server_side=True) + + config = h2.config.H2Configuration(client_side=False) + + while True: + try: + (conn, _) = sock.accept() + except ssl.SSLError: + continue + except ConnectionResetError: + continue + conn.settimeout(5.0) + h2conn = h2.connection.H2Connection(config=config) + h2conn.initiate_connection() + conn.sendall(h2conn.data_to_send()) + dnsData = {} + + while True: + data = conn.recv(65535) + if not data: + break + + events = h2conn.receive_data(data) + for event in events: + if isinstance(event, h2.events.DataReceived): + h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) + if not event.stream_id in dnsData: + dnsData[event.stream_id] = b'' + dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data) + if event.stream_ended: + forceRcode = None + status = 200 + try: + request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing) + except dns.message.TrailingJunk as e: + if trailingDataResponse is False or forceRcode is True: + raise + print("DOH query with trailing data, synthesizing response") + request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True) + forceRcode = trailingDataResponse + + if callback: + status, wire = callback(request) + else: + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) + if response: + wire = response.to_wire(max_size=65535) + + if not wire: + conn.close() + conn = None + break + + headers = [ + (':status', str(status)), + ('content-length', str(len(wire))), + ('content-type', 'application/dns-message'), + ] + h2conn.send_headers(stream_id=event.stream_id, headers=headers) + h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True) + + data_to_send = h2conn.data_to_send() + if data_to_send: + conn.sendall(data_to_send) + + if conn is None: + break + + if conn is not None: + conn.close() + + sock.close() + @classmethod def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): if useQueue and response is not None: diff --git a/regression-tests.dnsdist/requirements.txt b/regression-tests.dnsdist/requirements.txt index afa9fa76ec..fd27b154cb 100644 --- a/regression-tests.dnsdist/requirements.txt +++ b/regression-tests.dnsdist/requirements.txt @@ -8,3 +8,4 @@ future>=0.17.1 pycurl>=7.43.0 lmdb>=0.95 cdbx==0.1.2 +h2>=4.0.0 diff --git a/regression-tests.dnsdist/test_OutgoingDOH.py b/regression-tests.dnsdist/test_OutgoingDOH.py new file mode 100644 index 0000000000..28a579c27f --- /dev/null +++ b/regression-tests.dnsdist/test_OutgoingDOH.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python +import dns +import requests +import ssl +import threading +import time + +from dnsdisttests import DNSDistTest + +class OutgoingDOHTests(object): + + _webTimeout = 2.0 + _webServerPort = 8083 + _webServerBasicAuthPassword = 'secret' + _webServerAPIKey = 'apisecret' + + def checkOnlyDOHResponderHit(self, numberOfDOHQueries=1): + self.assertNotIn('UDP Responder', self._responsesCounter) + self.assertNotIn('TCP Responder', self._responsesCounter) + self.assertNotIn('TLS Responder', self._responsesCounter) + self.assertEqual(self._responsesCounter['DOH Responder'], numberOfDOHQueries) + + def getServerStat(self, key): + headers = {'x-api-key': self._webServerAPIKey} + url = 'http://127.0.0.1:' + str(self._webServerPort) + '/api/v1/servers/localhost' + r = requests.get(url, headers=headers, timeout=self._webTimeout) + self.assertTrue(r) + self.assertEqual(r.status_code, 200) + self.assertTrue(r.json()) + content = r.json() + self.assertTrue(len(content['servers']), 1) + server = content['servers'][0] + self.assertIn(key, server) + return server[key] + + def testUDP(self): + """ + Outgoing DOH: UDP query is sent via DOH + """ + name = 'udp.outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse.answer.append(rrset) + + numberOfUDPQueries = 10 + for _ in range(numberOfUDPQueries): + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse) + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, expectedResponse) + + # there was one TCP query in testTCP (below, but before in alphabetical order) + numberOfQueries = numberOfUDPQueries + 1 + self.checkOnlyDOHResponderHit(numberOfUDPQueries) + + self.assertEqual(self.getServerStat('tcpNewConnections'), 1) + self.assertEqual(self.getServerStat('tcpReusedConnections'), numberOfQueries - 1) + self.assertEqual(self.getServerStat('tlsResumptions'), 0) + + def testTCP(self): + """ + Outgoing DOH: TCP query is sent via DOH + """ + name = 'tcp.outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse) + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, expectedResponse) + self.checkOnlyDOHResponderHit() + self.assertEqual(self.getServerStat('tcpNewConnections'), 1) + self.assertEqual(self.getServerStat('tcpReusedConnections'), 0) + self.assertEqual(self.getServerStat('tlsResumptions'), 0) + +class BrokenOutgoingDOHTests(object): + + _webTimeout = 2.0 + _webServerPort = 8083 + _webServerBasicAuthPassword = 'secret' + _webServerAPIKey = 'apisecret' + + def checkNoResponderHit(self): + self.assertNotIn('UDP Responder', self._responsesCounter) + self.assertNotIn('TCP Responder', self._responsesCounter) + self.assertNotIn('TLS Responder', self._responsesCounter) + self.assertNotIn('DOH Responder', self._responsesCounter) + + def testUDP(self): + """ + Outgoing DOH (broken): UDP query is sent via DOH + """ + name = 'udp.broken-outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, None) + self.checkNoResponderHit() + + def testTCP(self): + """ + Outgoing DOH (broken): TCP query is sent via DOH + """ + name = 'tcp.broken-outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse.answer.append(rrset) + + (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, None) + self.checkNoResponderHit() + +class OutgoingDOHBrokenResponsesTests(object): + + _webTimeout = 2.0 + _webServerPort = 8083 + _webServerBasicAuthPassword = 'secret' + _webServerAPIKey = 'apisecret' + + def testUDP(self): + """ + Outgoing DOH (broken responses): UDP query is sent via DOH + """ + name = '500-status.broken-responses.outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + print(receivedResponse) + self.assertEqual(receivedResponse, None) + + name = 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, None) + + name = 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, None) + +class TestOutgoingDOHOpenSSL(DNSDistTest, OutgoingDOHTests): + _tlsBackendPort = 10543 + _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp() + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s"}) + """ + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.set_alpn_protocols(["h2"]) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + +class TestOutgoingDOHGnuTLS(DNSDistTest, OutgoingDOHTests): + _tlsBackendPort = 10544 + _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp() + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s"}) + """ + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.load_cert_chain('server.chain', 'server.key') + tlsContext.keylog_filename = "/tmp/keys" + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + +class TestOutgoingDOHOpenSSLWrongCertName(DNSDistTest, BrokenOutgoingDOHTests): + _tlsBackendPort = 10545 + _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp() + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s"}) + """ + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + +class TestOutgoingDOHGnuTLSWrongCertName(DNSDistTest, BrokenOutgoingDOHTests): + _tlsBackendPort = 10546 + _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp() + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s"}) + """ + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + +class TestOutgoingDOHOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests): + _tlsBackendPort = 10547 + _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp() + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s"}) + """ + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + +class TestOutgoingDOHGnuTLSWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests): + _tlsBackendPort = 10548 + _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp() + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s"}) + """ + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + +class TestOutgoingDOHBrokenResponsesOpenSSL(DNSDistTest, OutgoingDOHBrokenResponsesTests): + _tlsBackendPort = 10549 + _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp() + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s"}) + """ + + def callback(request): + + if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.': + print("returning 500") + return 500, b'Server error' + + if str(request.question[0].name) == 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.': + return 200, b'not DNS' + + if str(request.question[0].name) == 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.': + return 200, None + + print("Returning default for %s" % (request.question[0].name)) + return 200, dns.message.make_response(request).to_wire() + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.set_alpn_protocols(["h2"]) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start() + +class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenResponsesTests): + _tlsBackendPort = 10550 + _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey'] + _config_template = """ + setMaxTCPClientThreads(1) + newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp() + webserver("127.0.0.1:%s") + setWebserverConfig({password="%s", apiKey="%s"}) + """ + + def callback(request): + + if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.': + print("returning 500") + return 500, b'Server error' + + if str(request.question[0].name) == 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.': + return 200, b'not DNS' + + if str(request.question[0].name) == 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.': + return 200, None + + print("Returning default for %s" % (request.question[0].name)) + return 200, dns.message.make_response(request).to_wire() + + @classmethod + def startResponders(cls): + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.set_alpn_protocols(["h2"]) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext]) + cls._DOHResponder.setDaemon(True) + cls._DOHResponder.start()