import threading
import time
import unittest
+
import clientsubnetoption
+
import dns
import dns.message
+
import libnacl
import libnacl.utils
+import h2.connection
+import h2.events
+import h2.config
+
from eqdnsmessage import AssertEqualDNSMessageMixin
# Python2/3 compatibility hacks
sock.close()
+ @classmethod
+ def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None):
+ # trailingDataResponse=True means "ignore trailing data".
+ # Other values are either False (meaning "raise an exception")
+ # or are interpreted as a response RCODE for queries with trailing data.
+ # callback is invoked for every -even healthcheck ones- query and should return a raw response
+ ignoreTrailing = trailingDataResponse is True
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ try:
+ sock.bind(("127.0.0.1", port))
+ except socket.error as e:
+ print("Error binding in the TCP responder: %s" % str(e))
+ sys.exit(1)
+
+ sock.listen(100)
+ if tlsContext:
+ sock = tlsContext.wrap_socket(sock, server_side=True)
+
+ config = h2.config.H2Configuration(client_side=False)
+
+ while True:
+ try:
+ (conn, _) = sock.accept()
+ except ssl.SSLError:
+ continue
+ except ConnectionResetError:
+ continue
+ conn.settimeout(5.0)
+ h2conn = h2.connection.H2Connection(config=config)
+ h2conn.initiate_connection()
+ conn.sendall(h2conn.data_to_send())
+ dnsData = {}
+
+ while True:
+ data = conn.recv(65535)
+ if not data:
+ break
+
+ events = h2conn.receive_data(data)
+ for event in events:
+ if isinstance(event, h2.events.DataReceived):
+ h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
+ if not event.stream_id in dnsData:
+ dnsData[event.stream_id] = b''
+ dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
+ if event.stream_ended:
+ forceRcode = None
+ status = 200
+ try:
+ request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
+ except dns.message.TrailingJunk as e:
+ if trailingDataResponse is False or forceRcode is True:
+ raise
+ print("DOH query with trailing data, synthesizing response")
+ request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
+ forceRcode = trailingDataResponse
+
+ if callback:
+ status, wire = callback(request)
+ else:
+ response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+ if response:
+ wire = response.to_wire(max_size=65535)
+
+ if not wire:
+ conn.close()
+ conn = None
+ break
+
+ headers = [
+ (':status', str(status)),
+ ('content-length', str(len(wire))),
+ ('content-type', 'application/dns-message'),
+ ]
+ h2conn.send_headers(stream_id=event.stream_id, headers=headers)
+ h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
+
+ data_to_send = h2conn.data_to_send()
+ if data_to_send:
+ conn.sendall(data_to_send)
+
+ if conn is None:
+ break
+
+ if conn is not None:
+ conn.close()
+
+ sock.close()
+
@classmethod
def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
if useQueue and response is not None:
--- /dev/null
+#!/usr/bin/env python
+import dns
+import requests
+import ssl
+import threading
+import time
+
+from dnsdisttests import DNSDistTest
+
+class OutgoingDOHTests(object):
+
+ _webTimeout = 2.0
+ _webServerPort = 8083
+ _webServerBasicAuthPassword = 'secret'
+ _webServerAPIKey = 'apisecret'
+
+ def checkOnlyDOHResponderHit(self, numberOfDOHQueries=1):
+ self.assertNotIn('UDP Responder', self._responsesCounter)
+ self.assertNotIn('TCP Responder', self._responsesCounter)
+ self.assertNotIn('TLS Responder', self._responsesCounter)
+ self.assertEqual(self._responsesCounter['DOH Responder'], numberOfDOHQueries)
+
+ def getServerStat(self, key):
+ headers = {'x-api-key': self._webServerAPIKey}
+ url = 'http://127.0.0.1:' + str(self._webServerPort) + '/api/v1/servers/localhost'
+ r = requests.get(url, headers=headers, timeout=self._webTimeout)
+ self.assertTrue(r)
+ self.assertEqual(r.status_code, 200)
+ self.assertTrue(r.json())
+ content = r.json()
+ self.assertTrue(len(content['servers']), 1)
+ server = content['servers'][0]
+ self.assertIn(key, server)
+ return server[key]
+
+ def testUDP(self):
+ """
+ Outgoing DOH: UDP query is sent via DOH
+ """
+ name = 'udp.outgoing-doh.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ expectedResponse.answer.append(rrset)
+
+ numberOfUDPQueries = 10
+ for _ in range(numberOfUDPQueries):
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse)
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(receivedResponse, expectedResponse)
+
+ # there was one TCP query in testTCP (below, but before in alphabetical order)
+ numberOfQueries = numberOfUDPQueries + 1
+ self.checkOnlyDOHResponderHit(numberOfUDPQueries)
+
+ self.assertEqual(self.getServerStat('tcpNewConnections'), 1)
+ self.assertEqual(self.getServerStat('tcpReusedConnections'), numberOfQueries - 1)
+ self.assertEqual(self.getServerStat('tlsResumptions'), 0)
+
+ def testTCP(self):
+ """
+ Outgoing DOH: TCP query is sent via DOH
+ """
+ name = 'tcp.outgoing-doh.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ expectedResponse.answer.append(rrset)
+
+ (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(receivedResponse, expectedResponse)
+ self.checkOnlyDOHResponderHit()
+ self.assertEqual(self.getServerStat('tcpNewConnections'), 1)
+ self.assertEqual(self.getServerStat('tcpReusedConnections'), 0)
+ self.assertEqual(self.getServerStat('tlsResumptions'), 0)
+
+class BrokenOutgoingDOHTests(object):
+
+ _webTimeout = 2.0
+ _webServerPort = 8083
+ _webServerBasicAuthPassword = 'secret'
+ _webServerAPIKey = 'apisecret'
+
+ def checkNoResponderHit(self):
+ self.assertNotIn('UDP Responder', self._responsesCounter)
+ self.assertNotIn('TCP Responder', self._responsesCounter)
+ self.assertNotIn('TLS Responder', self._responsesCounter)
+ self.assertNotIn('DOH Responder', self._responsesCounter)
+
+ def testUDP(self):
+ """
+ Outgoing DOH (broken): UDP query is sent via DOH
+ """
+ name = 'udp.broken-outgoing-doh.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ self.assertEqual(receivedResponse, None)
+ self.checkNoResponderHit()
+
+ def testTCP(self):
+ """
+ Outgoing DOH (broken): TCP query is sent via DOH
+ """
+ name = 'tcp.broken-outgoing-doh.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ expectedResponse.answer.append(rrset)
+
+ (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+ self.assertEqual(receivedResponse, None)
+ self.checkNoResponderHit()
+
+class OutgoingDOHBrokenResponsesTests(object):
+
+ _webTimeout = 2.0
+ _webServerPort = 8083
+ _webServerBasicAuthPassword = 'secret'
+ _webServerAPIKey = 'apisecret'
+
+ def testUDP(self):
+ """
+ Outgoing DOH (broken responses): UDP query is sent via DOH
+ """
+ name = '500-status.broken-responses.outgoing-doh.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ print(receivedResponse)
+ self.assertEqual(receivedResponse, None)
+
+ name = 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ self.assertEqual(receivedResponse, None)
+
+ name = 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ self.assertEqual(receivedResponse, None)
+
+class TestOutgoingDOHOpenSSL(DNSDistTest, OutgoingDOHTests):
+ _tlsBackendPort = 10543
+ _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp()
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s"})
+ """
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.set_alpn_protocols(["h2"])
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()
+
+class TestOutgoingDOHGnuTLS(DNSDistTest, OutgoingDOHTests):
+ _tlsBackendPort = 10544
+ _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp()
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s"})
+ """
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+ tlsContext.keylog_filename = "/tmp/keys"
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()
+
+class TestOutgoingDOHOpenSSLWrongCertName(DNSDistTest, BrokenOutgoingDOHTests):
+ _tlsBackendPort = 10545
+ _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp()
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s"})
+ """
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()
+
+class TestOutgoingDOHGnuTLSWrongCertName(DNSDistTest, BrokenOutgoingDOHTests):
+ _tlsBackendPort = 10546
+ _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp()
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s"})
+ """
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()
+
+class TestOutgoingDOHOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests):
+ _tlsBackendPort = 10547
+ _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp()
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s"})
+ """
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()
+
+class TestOutgoingDOHGnuTLSWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests):
+ _tlsBackendPort = 10548
+ _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query'}:setUp()
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s"})
+ """
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()
+
+class TestOutgoingDOHBrokenResponsesOpenSSL(DNSDistTest, OutgoingDOHBrokenResponsesTests):
+ _tlsBackendPort = 10549
+ _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp()
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s"})
+ """
+
+ def callback(request):
+
+ if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.':
+ print("returning 500")
+ return 500, b'Server error'
+
+ if str(request.question[0].name) == 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.':
+ return 200, b'not DNS'
+
+ if str(request.question[0].name) == 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.':
+ return 200, None
+
+ print("Returning default for %s" % (request.question[0].name))
+ return 200, dns.message.make_response(request).to_wire()
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.set_alpn_protocols(["h2"])
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()
+
+class TestOutgoingDOHBrokenResponsesGnuTLS(DNSDistTest, OutgoingDOHBrokenResponsesTests):
+ _tlsBackendPort = 10550
+ _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPassword', '_webServerAPIKey']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query'}:setUp()
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({password="%s", apiKey="%s"})
+ """
+
+ def callback(request):
+
+ if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.':
+ print("returning 500")
+ return 500, b'Server error'
+
+ if str(request.question[0].name) == 'invalid-dns-payload.broken-responses.outgoing-doh.test.powerdns.com.':
+ return 200, b'not DNS'
+
+ if str(request.question[0].name) == 'closing-connection-id.broken-responses.outgoing-doh.test.powerdns.com.':
+ return 200, None
+
+ print("Returning default for %s" % (request.question[0].name))
+ return 200, dns.message.make_response(request).to_wire()
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.set_alpn_protocols(["h2"])
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()