From: Remi Gacogne Date: Thu, 16 Dec 2021 16:00:40 +0000 (+0100) Subject: dnsdist: Add regression tests for asynchronous processing X-Git-Tag: dnsdist-1.8.0-rc1~86^2~15 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fda32c1ca6c3d262bf054894e7343f85610cc383;p=thirdparty%2Fpdns.git dnsdist: Add regression tests for asynchronous processing --- diff --git a/regression-tests.dnsdist/dnsdistdohtests.py b/regression-tests.dnsdist/dnsdistdohtests.py index cd19d095a2..57f0c54879 100644 --- a/regression-tests.dnsdist/dnsdistdohtests.py +++ b/regression-tests.dnsdist/dnsdistdohtests.py @@ -12,110 +12,6 @@ from io import BytesIO @unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled') class DNSDistDOHTest(DNSDistTest): - @classmethod - def getDOHGetURL(cls, baseurl, query, rawQuery=False): - if rawQuery: - wire = query - else: - wire = query.to_wire() - param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') - return baseurl + "?dns=" + param - - @classmethod - def openDOHConnection(cls, port, caFile, timeout=2.0): - conn = pycurl.Curl() - conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2) - - conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message", - "Accept: application/dns-message"]) - return conn - - @classmethod - def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None): - url = cls.getDOHGetURL(baseurl, query, rawQuery) - conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) - response_headers = BytesIO() - #conn.setopt(pycurl.VERBOSE, True) - conn.setopt(pycurl.URL, url) - conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) - if useHTTPS: - conn.setopt(pycurl.SSL_VERIFYPEER, 1) - conn.setopt(pycurl.SSL_VERIFYHOST, 2) - if caFile: - conn.setopt(pycurl.CAINFO, caFile) - - conn.setopt(pycurl.HTTPHEADER, customHeaders) - conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) - - if response: - if toQueue: - toQueue.put(response, True, timeout) - else: - cls._toResponderQueue.put(response, True, timeout) - - receivedQuery = None - message = None - cls._response_headers = '' - data = conn.perform_rb() - cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) - if cls._rcode == 200 and not rawResponse: - message = dns.message.from_wire(data) - elif rawResponse: - message = data - - if useQueue: - if fromQueue: - if not fromQueue.empty(): - receivedQuery = fromQueue.get(True, timeout) - else: - if not cls._fromResponderQueue.empty(): - receivedQuery = cls._fromResponderQueue.get(True, timeout) - - cls._response_headers = response_headers.getvalue() - return (receivedQuery, message) - - @classmethod - def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True): - url = baseurl - conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) - response_headers = BytesIO() - #conn.setopt(pycurl.VERBOSE, True) - conn.setopt(pycurl.URL, url) - conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) - if useHTTPS: - conn.setopt(pycurl.SSL_VERIFYPEER, 1) - conn.setopt(pycurl.SSL_VERIFYHOST, 2) - if caFile: - conn.setopt(pycurl.CAINFO, caFile) - - conn.setopt(pycurl.HTTPHEADER, customHeaders) - conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) - conn.setopt(pycurl.POST, True) - data = query - if not rawQuery: - data = data.to_wire() - - conn.setopt(pycurl.POSTFIELDS, data) - - if response: - cls._toResponderQueue.put(response, True, timeout) - - receivedQuery = None - message = None - cls._response_headers = '' - data = conn.perform_rb() - cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) - if cls._rcode == 200 and not rawResponse: - message = dns.message.from_wire(data) - elif rawResponse: - message = data - - if useQueue and not cls._fromResponderQueue.empty(): - receivedQuery = cls._fromResponderQueue.get(True, timeout) - - cls._response_headers = response_headers.getvalue() - return (receivedQuery, message) - def getHeaderValue(self, name): for header in self._response_headers.decode().splitlines(False): values = header.split(':') diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index c622769b87..015369a862 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -1,5 +1,6 @@ #!/usr/bin/env python2 +import base64 import copy import errno import os @@ -24,6 +25,9 @@ import h2.connection import h2.events import h2.config +import pycurl +from io import BytesIO + from eqdnsmessage import AssertEqualDNSMessageMixin from proxyprotocol import ProxyProtocol @@ -629,6 +633,14 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): print("queue empty") return message + @classmethod + def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True): + conn = cls.openTLSConnection(port, serverName, caFile) + cls.sendTCPQueryOverConnection(conn, query, response=response) + if useQueue: + return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue) + return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue) + @classmethod def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): message = None @@ -910,3 +922,107 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): proxy.values.sort() values.sort() self.assertEqual(proxy.values, values) + + @classmethod + def getDOHGetURL(cls, baseurl, query, rawQuery=False): + if rawQuery: + wire = query + else: + wire = query.to_wire() + param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') + return baseurl + "?dns=" + param + + @classmethod + def openDOHConnection(cls, port, caFile, timeout=2.0): + conn = pycurl.Curl() + conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2) + + conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message", + "Accept: application/dns-message"]) + return conn + + @classmethod + def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None): + url = cls.getDOHGetURL(baseurl, query, rawQuery) + conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) + response_headers = BytesIO() + #conn.setopt(pycurl.VERBOSE, True) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) + if useHTTPS: + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + if caFile: + conn.setopt(pycurl.CAINFO, caFile) + + conn.setopt(pycurl.HTTPHEADER, customHeaders) + conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) + + if response: + if toQueue: + toQueue.put(response, True, timeout) + else: + cls._toResponderQueue.put(response, True, timeout) + + receivedQuery = None + message = None + cls._response_headers = '' + data = conn.perform_rb() + cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) + if cls._rcode == 200 and not rawResponse: + message = dns.message.from_wire(data) + elif rawResponse: + message = data + + if useQueue: + if fromQueue: + if not fromQueue.empty(): + receivedQuery = fromQueue.get(True, timeout) + else: + if not cls._fromResponderQueue.empty(): + receivedQuery = cls._fromResponderQueue.get(True, timeout) + + cls._response_headers = response_headers.getvalue() + return (receivedQuery, message) + + @classmethod + def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True): + url = baseurl + conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) + response_headers = BytesIO() + #conn.setopt(pycurl.VERBOSE, True) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) + if useHTTPS: + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + if caFile: + conn.setopt(pycurl.CAINFO, caFile) + + conn.setopt(pycurl.HTTPHEADER, customHeaders) + conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) + conn.setopt(pycurl.POST, True) + data = query + if not rawQuery: + data = data.to_wire() + + conn.setopt(pycurl.POSTFIELDS, data) + + if response: + cls._toResponderQueue.put(response, True, timeout) + + receivedQuery = None + message = None + cls._response_headers = '' + data = conn.perform_rb() + cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) + if cls._rcode == 200 and not rawResponse: + message = dns.message.from_wire(data) + elif rawResponse: + message = data + + if useQueue and not cls._fromResponderQueue.empty(): + receivedQuery = cls._fromResponderQueue.get(True, timeout) + + cls._response_headers = response_headers.getvalue() + return (receivedQuery, message) diff --git a/regression-tests.dnsdist/test_Async.py b/regression-tests.dnsdist/test_Async.py new file mode 100644 index 0000000000..646d950a1d --- /dev/null +++ b/regression-tests.dnsdist/test_Async.py @@ -0,0 +1,548 @@ +#!/usr/bin/env python + +import os +import socket +import threading +import unittest +import dns +from dnsdisttests import DNSDistTest + +def AsyncResponder(listenPath, responsePath): + # Make sure the socket does not already exist + try: + os.unlink(listenPath) + except OSError: + if os.path.exists(listenPath): + raise + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + try: + sock.bind(listenPath) + except socket.error as e: + print("Error binding in the Asynchronous responder: %s" % str(e)) + sys.exit(1) + + while True: + data, addr = sock.recvfrom(65535) + print("Got message [%d] '%s' from %s" % (len(data), data, addr)) + if not data: + break + + request = dns.message.from_wire(data) + reply = str(request.id) + ' ' + if str(request.question[0].name).startswith('accept-then-refuse'): + if request.flags & dns.flags.QR: + reply = reply + 'refuse' + else: + reply = reply + 'accept' + elif str(request.question[0].name).startswith('accept-then-drop'): + if request.flags & dns.flags.QR: + reply = reply + 'drop' + else: + reply = reply + 'accept' + elif str(request.question[0].name).startswith('accept-then-custom'): + if request.flags & dns.flags.QR: + reply = reply + 'custom' + else: + reply = reply + 'accept' + elif str(request.question[0].name).startswith('timeout-then-accept'): + if request.flags & dns.flags.QR: + reply = reply + 'accept' + else: + # no response + continue + elif str(request.question[0].name).startswith('accept-then-timeout'): + if request.flags & dns.flags.QR: + # no response + continue + else: + reply = reply + 'accept' + elif str(request.question[0].name).startswith('accept'): + reply = reply + 'accept' + elif str(request.question[0].name).startswith('refuse'): + reply = reply + 'refuse' + elif str(request.question[0].name).startswith('drop'): + reply = reply + 'drop' + elif str(request.question[0].name).startswith('custom'): + reply = reply + 'custom' + elif str(request.question[0].name).startswith('timeout'): + # no response + continue + else: + reply = reply + 'invalid' + + remote = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + remote.connect(responsePath) + remote.send(reply.encode()) + print("Sent [%d] '%s' to %s" % (len(reply), reply, responsePath)) + + sock.close() + +asyncResponderSocketPath = '/tmp/async-responder.sock' +dnsdistSocketPath = '/tmp/dnsdist.sock' +asyncResponder = threading.Thread(name='Asynchronous Responder', target=AsyncResponder, args=[asyncResponderSocketPath, dnsdistSocketPath]) +asyncResponder.setDaemon(True) +asyncResponder.start() + +@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled') +class TestAsync(DNSDistTest): + + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _tlsServerPort = 8453 + _dohServerPort = 8443 + _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort)) + + _config_template = """ + newServer{address="127.0.0.1:%s", pool={'', 'cache'}} + newServer{address="127.0.0.1:%s", pool="tcp-only", tcpOnly=true } + + addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" }) + addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/"}) + + local ffi = require("ffi") + local C = ffi.C + + local filteringTagName = 'filtering' + local filteringTagValue = 'pass' + local asyncID = 0 + + pc = newPacketCache(100) + getPool('cache'):setCache(pc) + + function gotAsyncResponse(endpointID, message, from) + + print('Got async response '..message) + local parts = {} + for part in message:gmatch("%%S+") do table.insert(parts, part) end + if #parts ~= 2 then + print('Invalid message') + return + end + local queryID = tonumber(parts[1]) + if parts[2] == 'accept' then + print('accepting') + C.dnsdist_ffi_resume_from_async(asyncID, queryID, filteringTagName, #filteringTagName, filteringTagValue, #filteringTagValue, true) + return + end + if parts[2] == 'refuse' then + print('refusing') + C.dnsdist_ffi_set_rcode_from_async(asyncID, queryID, DNSRCode.REFUSED, true) + return + end + if parts[2] == 'drop' then + print('dropping') + C.dnsdist_ffi_drop_from_async(asyncID, queryID) + return + end + if parts[2] == 'custom' then + print('sending a custom response') + local raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\006custom\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000' + C.dnsdist_ffi_set_answer_from_async(asyncID, queryID, raw, #raw) + return + end + end + + local asyncResponderEndpoint = newNetworkEndpoint('%s') + local listener = newNetworkListener() + listener:addUnixListeningEndpoint('%s', 0, gotAsyncResponse) + listener:start() + + function passQueryToAsyncFilter(dq) + print('in passQueryToAsyncFilter') + local timeout = 500 -- 500 ms + + local queryPtr = C.dnsdist_ffi_dnsquestion_get_header(dq) + local querySize = C.dnsdist_ffi_dnsquestion_get_len(dq) + + -- we need to take a copy, as we can no longer touch that data after calling set_async + local buffer = ffi.string(queryPtr, querySize) + + print(C.dnsdist_ffi_dnsquestion_set_async(dq, asyncID, C.dnsdist_ffi_dnsquestion_get_id(dq), timeout)) + asyncResponderEndpoint:send(buffer) + + return DNSAction.Allow + end + + function passResponseToAsyncFilter(dr) + print('in passResponseToAsyncFilter') + local timeout = 500 -- 500 ms + + local responsePtr = C.dnsdist_ffi_dnsquestion_get_header(dr) + local responseSize = C.dnsdist_ffi_dnsquestion_get_len(dr) + + -- we need to take a copy, as we can no longer touch that data after calling set_async + local buffer = ffi.string(responsePtr, responseSize) + + print(C.dnsdist_ffi_dnsresponse_set_async(dr, asyncID, C.dnsdist_ffi_dnsquestion_get_id(dr), timeout)) + asyncResponderEndpoint:send(buffer) + + return DNSResponseAction.Allow + end + + -- this only matters for tests actually reaching the backend + addAction('tcp-only.async.tests.powerdns.com', PoolAction('tcp-only', false)) + addAction('cache.async.tests.powerdns.com', PoolAction('cache', false)) + addAction(AllRule(), LuaFFIAction(passQueryToAsyncFilter)) + addCacheHitResponseAction(AllRule(), LuaFFIResponseAction(passResponseToAsyncFilter)) + addResponseAction(AllRule(), LuaFFIResponseAction(passResponseToAsyncFilter)) + """ + _asyncResponderSocketPath = asyncResponderSocketPath + _dnsdistSocketPath = dnsdistSocketPath + _config_params = ['_testServerPort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohServerPort', '_serverCert', '_serverKey', '_asyncResponderSocketPath', '_dnsdistSocketPath'] + _verboseMode = True + + def testPass(self): + """ + Async: Accept + """ + for name in ['accept.async.tests.powerdns.com.', 'accept.tcp-only.async.tests.powerdns.com.']: + query = dns.message.make_query(name, 'A', 'IN') + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + def testPassCached(self): + """ + Async: Accept (cached) + """ + name = 'accept.cache.async.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + # first time to fill the cache + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + # second time from the cache + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertEqual(response, receivedResponse) + + (_, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=None, useQueue=False, caFile=self._caCert) + self.assertEqual(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=None, useQueue=False, caFile=self._caCert) + print(response) + print(receivedResponse) + self.assertEqual(response, receivedResponse) + + def testTimeoutThenAccept(self): + """ + Async: Timeout then accept + """ + for name in ['timeout-then-accept.async.tests.powerdns.com.', 'timeout-then-accept.tcp-only.async.tests.powerdns.com.']: + query = dns.message.make_query(name, 'A', 'IN') + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + def testAcceptThenTimeout(self): + """ + Async: Accept then timeout + """ + for name in ['accept-then-timeout.async.tests.powerdns.com.', 'accept-then-timeout.tcp-only.async.tests.powerdns.com.']: + query = dns.message.make_query(name, 'A', 'IN') + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse) + + def testAcceptThenRefuse(self): + """ + Async: Accept then refuse + """ + for name in ['accept-then-refuse.async.tests.powerdns.com.', 'accept-then-refuse.tcp-only.async.tests.powerdns.com.']: + query = dns.message.make_query(name, 'A', 'IN') + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + expectedResponse = dns.message.make_response(query) + expectedResponse.flags |= dns.flags.RA + expectedResponse.set_rcode(dns.rcode.REFUSED) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(expectedResponse, receivedResponse) + + def testAcceptThenCustom(self): + """ + Async: Accept then custom + """ + for name in ['accept-then-custom.async.tests.powerdns.com.', 'accept-then-custom.tcp-only.async.tests.powerdns.com.']: + query = dns.message.make_query(name, 'A', 'IN') + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + # easier to get the same custom response to everyone, sorry! + expectedQuery = dns.message.make_query('custom.async.tests.powerdns.com.', 'A', 'IN') + expectedQuery.id = query.id + expectedResponse = dns.message.make_response(expectedQuery) + expectedResponse.flags |= dns.flags.RA + expectedResponse.set_rcode(dns.rcode.FORMERR) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(expectedResponse, receivedResponse) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(expectedResponse, receivedResponse) + + def testAcceptThenDrop(self): + """ + Async: Accept then drop + """ + for name in ['accept-then-drop.async.tests.powerdns.com.', 'accept-then-drop.tcp-only.async.tests.powerdns.com.']: + query = dns.message.make_query(name, 'A', 'IN') + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, None) + + (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, None) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, None) + + def testRefused(self): + """ + Async: Refused + """ + name = 'refused.async.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + expectedResponse = dns.message.make_response(query) + expectedResponse.flags |= dns.flags.RA + expectedResponse.set_rcode(dns.rcode.REFUSED) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEqual(expectedResponse, receivedResponse) + + (_, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=None, caFile=self._caCert, useQueue=False) + self.assertEqual(expectedResponse, receivedResponse) + + (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=None, caFile=self._caCert, useQueue=False) + self.assertEqual(expectedResponse, receivedResponse) + + def testDrop(self): + """ + Async: Drop + """ + name = 'drop.async.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, None) + + (_, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=None, caFile=self._caCert, useQueue=False) + self.assertEqual(receivedResponse, None) + + (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=None, caFile=self._caCert, useQueue=False) + self.assertEqual(receivedResponse, None) + + def testCustom(self): + """ + Async: Custom answer + """ + name = 'custom.async.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + + expectedResponse = dns.message.make_response(query) + expectedResponse.flags |= dns.flags.RA + expectedResponse.set_rcode(dns.rcode.FORMERR) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + print(expectedResponse) + print(receivedResponse) + self.assertEqual(expectedResponse, receivedResponse) + + (_, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=None, caFile=self._caCert, useQueue=False) + self.assertEqual(expectedResponse, receivedResponse) + + (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=None, caFile=self._caCert, useQueue=False) + self.assertEqual(expectedResponse, receivedResponse) + + def testTruncation(self): + """ + Async: DoH query, timeout then truncated answer over UDP, then valid over TCP and accept + """ + # the query is first forwarded over UDP, leading to a TC=1 answer from the + # backend, then over TCP + name = 'timeout-then-accept.tc.async.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.id = 42 + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.id = 42 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + # first response is a TC=1 + tcResponse = dns.message.make_response(query) + tcResponse.flags |= dns.flags.TC + self._toResponderQueue.put(tcResponse, True, 2.0) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response) + # first query, received by the responder over UDP + self.assertTrue(receivedQuery) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery) + + # check the response + self.assertTrue(receivedResponse) + print(response) + print(receivedResponse) + self.assertEqual(response, receivedResponse) + + # check the second query, received by the responder over TCP + receivedQuery = self._fromResponderQueue.get(True, 2.0) + self.assertTrue(receivedQuery) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery) diff --git a/regression-tests.dnsdist/test_DOH.py b/regression-tests.dnsdist/test_DOH.py index e3198ecf54..cd12c96133 100644 --- a/regression-tests.dnsdist/test_DOH.py +++ b/regression-tests.dnsdist/test_DOH.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -import base64 + import dns import os import time @@ -889,7 +889,7 @@ class TestDOHWithCache(DNSDistDOHTest): self._toResponderQueue.put(tcResponse, True, 2.0) (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response) - # first query, received by dnsdist over UDP + # first query, received by the responder over UDP self.assertTrue(receivedQuery) receivedQuery.id = expectedQuery.id self.assertEqual(expectedQuery, receivedQuery) @@ -899,7 +899,7 @@ class TestDOHWithCache(DNSDistDOHTest): self.assertTrue(receivedResponse) self.assertEqual(response, receivedResponse) - # check the second query, received by dnsdist over TCP + # check the second query, received by the responder over TCP receivedQuery = self._fromResponderQueue.get(True, 2.0) self.assertTrue(receivedQuery) receivedQuery.id = expectedQuery.id @@ -1294,3 +1294,38 @@ class TestDOHWithPCKS12Cert(DNSDistDOHTest): self.assertTrue(receivedResponse) receivedQuery.id = expectedQuery.id self.assertEqual(expectedQuery, receivedQuery) + +class TestDOHForwardedToTCPOnly(DNSDistDOHTest): + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _dohServerPort = 8443 + _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort)) + _config_template = """ + newServer{address="127.0.0.1:%s", tcpOnly=true} + addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" }) + """ + _config_params = ['_testServerPort', '_dohServerPort', '_serverCert', '_serverKey'] + + def testDOHTCPOnly(self): + """ + DoH: Test a DoH query forwarded to a TCP-only server + """ + name = 'tcponly.doh.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.id = 42 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEqual(receivedQuery, query) + self.assertEqual(receivedResponse, response)