sock.close()
+ @classmethod
+ def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
+ ignoreTrailing = trailingDataResponse is True
+ h2conn = h2.connection.H2Connection(config=config)
+ h2conn.initiate_connection()
+ conn.sendall(h2conn.data_to_send())
+ dnsData = {}
+
+ if useProxyProtocol:
+ # try to read the entire Proxy Protocol header
+ proxy = ProxyProtocol()
+ header = conn.recv(proxy.HEADER_SIZE)
+ if not header:
+ print('unable to get header')
+ conn.close()
+ return
+
+ if not proxy.parseHeader(header):
+ print('unable to parse header')
+ print(header)
+ conn.close()
+ return
+
+ proxyContent = conn.recv(proxy.contentLen)
+ if not proxyContent:
+ print('unable to get content')
+ conn.close()
+ return
+
+ payload = header + proxyContent
+ toQueue.put(payload, True, cls._queueTimeout)
+
+ # be careful, HTTP/2 headers and data might be in different recv() results
+ requestHeaders = None
+ while True:
+ data = conn.recv(65535)
+ if not data:
+ break
+
+ events = h2conn.receive_data(data)
+ for event in events:
+ if isinstance(event, h2.events.RequestReceived):
+ requestHeaders = event.headers
+ if isinstance(event, h2.events.DataReceived):
+ h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
+ if not event.stream_id in dnsData:
+ dnsData[event.stream_id] = b''
+ dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
+ if event.stream_ended:
+ forceRcode = None
+ status = 200
+ try:
+ request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
+ except dns.message.TrailingJunk as e:
+ if trailingDataResponse is False or forceRcode is True:
+ raise
+ print("DOH query with trailing data, synthesizing response")
+ request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
+ forceRcode = trailingDataResponse
+
+ if callback:
+ status, wire = callback(request, requestHeaders, fromQueue, toQueue)
+ else:
+ response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+ if response:
+ wire = response.to_wire(max_size=65535)
+
+ if not wire:
+ conn.close()
+ conn = None
+ break
+
+ headers = [
+ (':status', str(status)),
+ ('content-length', str(len(wire))),
+ ('content-type', 'application/dns-message'),
+ ]
+ h2conn.send_headers(stream_id=event.stream_id, headers=headers)
+ h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
+
+ data_to_send = h2conn.data_to_send()
+ if data_to_send:
+ conn.sendall(data_to_send)
+
+ if conn is None:
+ break
+
+ if conn is not None:
+ conn.close()
+
@classmethod
def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
# trailingDataResponse=True means "ignore trailing data".
# Other values are either False (meaning "raise an exception")
# or are interpreted as a response RCODE for queries with trailing data.
# callback is invoked for every -even healthcheck ones- query and should return a raw response
- ignoreTrailing = trailingDataResponse is True
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
continue
conn.settimeout(5.0)
- h2conn = h2.connection.H2Connection(config=config)
- h2conn.initiate_connection()
- conn.sendall(h2conn.data_to_send())
- dnsData = {}
-
- if useProxyProtocol:
- # try to read the entire Proxy Protocol header
- proxy = ProxyProtocol()
- header = conn.recv(proxy.HEADER_SIZE)
- if not header:
- print('unable to get header')
- conn.close()
- continue
-
- if not proxy.parseHeader(header):
- print('unable to parse header')
- print(header)
- conn.close()
- continue
-
- proxyContent = conn.recv(proxy.contentLen)
- if not proxyContent:
- print('unable to get content')
- conn.close()
- continue
-
- payload = header + proxyContent
- toQueue.put(payload, True, cls._queueTimeout)
-
- while True:
- data = conn.recv(65535)
- if not data:
- break
-
- events = h2conn.receive_data(data)
- for event in events:
- if isinstance(event, h2.events.DataReceived):
- h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
- if not event.stream_id in dnsData:
- dnsData[event.stream_id] = b''
- dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
- if event.stream_ended:
- forceRcode = None
- status = 200
- try:
- request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
- except dns.message.TrailingJunk as e:
- if trailingDataResponse is False or forceRcode is True:
- raise
- print("DOH query with trailing data, synthesizing response")
- request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
- forceRcode = trailingDataResponse
-
- if callback:
- status, wire = callback(request)
- else:
- response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
- if response:
- wire = response.to_wire(max_size=65535)
-
- if not wire:
- conn.close()
- conn = None
- break
-
- headers = [
- (':status', str(status)),
- ('content-length', str(len(wire))),
- ('content-type', 'application/dns-message'),
- ]
- h2conn.send_headers(stream_id=event.stream_id, headers=headers)
- h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
-
- data_to_send = h2conn.data_to_send()
- if data_to_send:
- conn.sendall(data_to_send)
-
- if conn is None:
- break
-
- if conn is not None:
- conn.close()
+ thread = threading.Thread(name='DoH Connection Handler',
+ target=cls.handleDoHConnection,
+ args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
+ thread.setDaemon(True)
+ thread.start()
sock.close()
#!/usr/bin/env python
+import base64
+import copy
import dns
import requests
import ssl
self.assertNotIn('UDP Responder', self._responsesCounter)
self.assertNotIn('TCP Responder', self._responsesCounter)
self.assertNotIn('TLS Responder', self._responsesCounter)
- self.assertEqual(self._responsesCounter['DOH Responder'], numberOfDOHQueries)
+ self.assertEqual(self._responsesCounter['DoH Connection Handler'], numberOfDOHQueries)
def getServerStat(self, key):
headers = {'x-api-key': self._webServerAPIKey}
(_, receivedResponse) = self.sendTCPQuery(query, useQueue=False, response=None)
self.assertEqual(receivedResponse, expectedResponse)
+ def testZHealthChecks(self):
+ # this test has to run last, as it will mess up the TCP connection counter,
+ # hence the 'Z' in the name
+ self.sendConsoleCommand("getServer(0):setAuto()")
+ time.sleep(2)
+ status = self.sendConsoleCommand("if getServer(0):isUp() then return 'up' else return 'down' end").strip("\n")
+ self.assertEqual(status, 'up')
+
class BrokenOutgoingDOHTests(object):
_webTimeout = 2.0
class TestOutgoingDOHOpenSSL(DNSDistTest, OutgoingDOHTests):
_tlsBackendPort = 10543
- _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
+ _tlsProvider = 'openssl'
+ _consoleKey = DNSDistTest.generateConsoleKey()
+ _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+ _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
_config_template = """
+ setKey("%s")
+ controlSocket("127.0.0.1:%d")
setMaxTCPClientThreads(1)
- newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
+ newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
webserver("127.0.0.1:%s")
setWebserverConfig({password="%s", apiKey="%s"})
class TestOutgoingDOHGnuTLS(DNSDistTest, OutgoingDOHTests):
_tlsBackendPort = 10544
- _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
+ _tlsProvider = 'gnutls'
+ _consoleKey = DNSDistTest.generateConsoleKey()
+ _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+ _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
_config_template = """
+ setKey("%s")
+ controlSocket("127.0.0.1:%d")
setMaxTCPClientThreads(1)
- newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
+ newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
webserver("127.0.0.1:%s")
setWebserverConfig({password="%s", apiKey="%s"})
class TestOutgoingDOHOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests):
_tlsBackendPort = 10547
- _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
+ _tlsProvider = 'openssl'
+ _consoleKey = DNSDistTest.generateConsoleKey()
+ _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+ _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
_config_template = """
+ setKey("%s")
+ controlSocket("127.0.0.1:%d")
setMaxTCPClientThreads(1)
- newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
+ newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
webserver("127.0.0.1:%s")
setWebserverConfig({password="%s", apiKey="%s"})
class TestOutgoingDOHGnuTLSWrongCertNameButNoCheck(DNSDistTest, OutgoingDOHTests):
_tlsBackendPort = 10548
- _config_params = ['_tlsBackendPort', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
+ _tlsProvider = 'gnutls'
+ _consoleKey = DNSDistTest.generateConsoleKey()
+ _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
+ _config_params = ['_consoleKeyB64', '_consolePort', '_tlsBackendPort', '_tlsProvider', '_webServerPort', '_webServerBasicAuthPasswordHashed', '_webServerAPIKeyHashed']
_config_template = """
+ setKey("%s")
+ controlSocket("127.0.0.1:%d")
setMaxTCPClientThreads(1)
- newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
+ newServer{address="127.0.0.1:%s", tls='%s', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com', dohPath='/dns-query', pool={'', 'cache'}}:setUp()
webserver("127.0.0.1:%s")
setWebserverConfig({password="%s", apiKey="%s"})
addAction(SuffixMatchNodeRule(smn), PoolAction('cache'))
"""
- def callback(request):
+ def callback(request, headers, fromQueue, toQueue):
if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.':
print("returning 500")
"""
_verboseMode = True
- def callback(request):
+ def callback(request, headers, fromQueue, toQueue):
if str(request.question[0].name) == '500-status.broken-responses.outgoing-doh.test.powerdns.com.':
print("returning 500")
self.assertEqual(query, receivedQuery)
self.assertEqual(receivedResponse, expectedResponse)
self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True)
+
+class TestOutgoingDOHXForwarded(DNSDistTest):
+ _tlsBackendPort = 10560
+ _config_params = ['_tlsBackendPort']
+ _config_template = """
+ setMaxTCPClientThreads(1)
+ newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com', dohPath='/dns-query', addXForwardedHeaders=true}
+ """
+ _verboseMode = True
+
+ def callback(request, headersList, fromQueue, toQueue):
+
+ if str(request.question[0].name) == 'a.root-servers.net.':
+ # do not check headers on health-check queries
+ return 200, dns.message.make_response(request).to_wire()
+
+ headers = {}
+ if headersList:
+ for k,v in headersList:
+ headers[k] = v
+
+ if not b'x-forwarded-for' in headers:
+ print("missing X-Forwarded-For")
+ return 406, b'Missing X-Forwarded-For header'
+ if not b'x-forwarded-port' in headers:
+ print("missing X-Forwarded-Port")
+ return 406, b'Missing X-Forwarded-Port header'
+ if not b'x-forwarded-proto' in headers:
+ print("missing X-Forwarded-Proto")
+ return 406, b'Missing X-Forwarded-Proto header'
+
+ toQueue.put(request, True, 1.0)
+ response = fromQueue.get(True, 1.0)
+ if response:
+ response = copy.copy(response)
+ response.id = request.id
+
+ return 200, response.to_wire()
+
+ @classmethod
+ def startResponders(cls):
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.set_alpn_protocols(["h2"])
+ tlsContext.load_cert_chain('server.chain', 'server.key')
+
+ print("Launching DOH responder..")
+ cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.callback, tlsContext])
+ cls._DOHResponder.setDaemon(True)
+ cls._DOHResponder.start()
+
+ def testXForwarded(self):
+ """
+ Outgoing DOH: X-Forwarded
+ """
+ name = 'x-forwarded-for.outgoing-doh.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ expectedResponse.answer.append(rrset)
+
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse)
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(receivedResponse, expectedResponse)
+
+ (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(receivedResponse, expectedResponse)