]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for asynchronous processing
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 16 Dec 2021 16:00:40 +0000 (17:00 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 13 Jan 2023 15:57:49 +0000 (16:57 +0100)
regression-tests.dnsdist/dnsdistdohtests.py
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_Async.py [new file with mode: 0644]
regression-tests.dnsdist/test_DOH.py

index cd19d095a2f5b2a3abc931809a8421b59eff2224..57f0c54879b847001e56bb9bf5a3d7a7006bd0a3 100644 (file)
@@ -12,110 +12,6 @@ from io import BytesIO
 @unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
 class DNSDistDOHTest(DNSDistTest):
 
-    @classmethod
-    def getDOHGetURL(cls, baseurl, query, rawQuery=False):
-        if rawQuery:
-            wire = query
-        else:
-            wire = query.to_wire()
-        param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
-        return baseurl + "?dns=" + param
-
-    @classmethod
-    def openDOHConnection(cls, port, caFile, timeout=2.0):
-        conn = pycurl.Curl()
-        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
-
-        conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
-                                         "Accept: application/dns-message"])
-        return conn
-
-    @classmethod
-    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
-        url = cls.getDOHGetURL(baseurl, query, rawQuery)
-        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
-        response_headers = BytesIO()
-        #conn.setopt(pycurl.VERBOSE, True)
-        conn.setopt(pycurl.URL, url)
-        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
-        if useHTTPS:
-            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
-            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
-            if caFile:
-                conn.setopt(pycurl.CAINFO, caFile)
-
-        conn.setopt(pycurl.HTTPHEADER, customHeaders)
-        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
-
-        if response:
-            if toQueue:
-                toQueue.put(response, True, timeout)
-            else:
-                cls._toResponderQueue.put(response, True, timeout)
-
-        receivedQuery = None
-        message = None
-        cls._response_headers = ''
-        data = conn.perform_rb()
-        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
-        if cls._rcode == 200 and not rawResponse:
-            message = dns.message.from_wire(data)
-        elif rawResponse:
-            message = data
-
-        if useQueue:
-            if fromQueue:
-                if not fromQueue.empty():
-                    receivedQuery = fromQueue.get(True, timeout)
-            else:
-                if not cls._fromResponderQueue.empty():
-                    receivedQuery = cls._fromResponderQueue.get(True, timeout)
-
-        cls._response_headers = response_headers.getvalue()
-        return (receivedQuery, message)
-
-    @classmethod
-    def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
-        url = baseurl
-        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
-        response_headers = BytesIO()
-        #conn.setopt(pycurl.VERBOSE, True)
-        conn.setopt(pycurl.URL, url)
-        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
-        if useHTTPS:
-            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
-            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
-            if caFile:
-                conn.setopt(pycurl.CAINFO, caFile)
-
-        conn.setopt(pycurl.HTTPHEADER, customHeaders)
-        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
-        conn.setopt(pycurl.POST, True)
-        data = query
-        if not rawQuery:
-            data = data.to_wire()
-
-        conn.setopt(pycurl.POSTFIELDS, data)
-
-        if response:
-            cls._toResponderQueue.put(response, True, timeout)
-
-        receivedQuery = None
-        message = None
-        cls._response_headers = ''
-        data = conn.perform_rb()
-        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
-        if cls._rcode == 200 and not rawResponse:
-            message = dns.message.from_wire(data)
-        elif rawResponse:
-            message = data
-
-        if useQueue and not cls._fromResponderQueue.empty():
-            receivedQuery = cls._fromResponderQueue.get(True, timeout)
-
-        cls._response_headers = response_headers.getvalue()
-        return (receivedQuery, message)
-
     def getHeaderValue(self, name):
         for header in self._response_headers.decode().splitlines(False):
             values = header.split(':')
index c622769b8764d7d47b8d94fb9e9aa05c070d873a..015369a862bece0cc95ab1bd1e6d4f6b448db6c3 100644 (file)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python2
 
+import base64
 import copy
 import errno
 import os
@@ -24,6 +25,9 @@ import h2.connection
 import h2.events
 import h2.config
 
+import pycurl
+from io import BytesIO
+
 from eqdnsmessage import AssertEqualDNSMessageMixin
 from proxyprotocol import ProxyProtocol
 
@@ -629,6 +633,14 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             print("queue empty")
             return message
 
+    @classmethod
+    def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True):
+        conn = cls.openTLSConnection(port, serverName, caFile)
+        cls.sendTCPQueryOverConnection(conn, query, response=response)
+        if useQueue:
+          return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
+        return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
+
     @classmethod
     def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
         message = None
@@ -910,3 +922,107 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         proxy.values.sort()
         values.sort()
         self.assertEqual(proxy.values, values)
+
+    @classmethod
+    def getDOHGetURL(cls, baseurl, query, rawQuery=False):
+        if rawQuery:
+            wire = query
+        else:
+            wire = query.to_wire()
+        param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+        return baseurl + "?dns=" + param
+
+    @classmethod
+    def openDOHConnection(cls, port, caFile, timeout=2.0):
+        conn = pycurl.Curl()
+        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
+
+        conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
+                                         "Accept: application/dns-message"])
+        return conn
+
+    @classmethod
+    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
+        url = cls.getDOHGetURL(baseurl, query, rawQuery)
+        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+        response_headers = BytesIO()
+        #conn.setopt(pycurl.VERBOSE, True)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        if useHTTPS:
+            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+            if caFile:
+                conn.setopt(pycurl.CAINFO, caFile)
+
+        conn.setopt(pycurl.HTTPHEADER, customHeaders)
+        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
+
+        if response:
+            if toQueue:
+                toQueue.put(response, True, timeout)
+            else:
+                cls._toResponderQueue.put(response, True, timeout)
+
+        receivedQuery = None
+        message = None
+        cls._response_headers = ''
+        data = conn.perform_rb()
+        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        if cls._rcode == 200 and not rawResponse:
+            message = dns.message.from_wire(data)
+        elif rawResponse:
+            message = data
+
+        if useQueue:
+            if fromQueue:
+                if not fromQueue.empty():
+                    receivedQuery = fromQueue.get(True, timeout)
+            else:
+                if not cls._fromResponderQueue.empty():
+                    receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+        cls._response_headers = response_headers.getvalue()
+        return (receivedQuery, message)
+
+    @classmethod
+    def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
+        url = baseurl
+        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+        response_headers = BytesIO()
+        #conn.setopt(pycurl.VERBOSE, True)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        if useHTTPS:
+            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+            if caFile:
+                conn.setopt(pycurl.CAINFO, caFile)
+
+        conn.setopt(pycurl.HTTPHEADER, customHeaders)
+        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
+        conn.setopt(pycurl.POST, True)
+        data = query
+        if not rawQuery:
+            data = data.to_wire()
+
+        conn.setopt(pycurl.POSTFIELDS, data)
+
+        if response:
+            cls._toResponderQueue.put(response, True, timeout)
+
+        receivedQuery = None
+        message = None
+        cls._response_headers = ''
+        data = conn.perform_rb()
+        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        if cls._rcode == 200 and not rawResponse:
+            message = dns.message.from_wire(data)
+        elif rawResponse:
+            message = data
+
+        if useQueue and not cls._fromResponderQueue.empty():
+            receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+        cls._response_headers = response_headers.getvalue()
+        return (receivedQuery, message)
diff --git a/regression-tests.dnsdist/test_Async.py b/regression-tests.dnsdist/test_Async.py
new file mode 100644 (file)
index 0000000..646d950
--- /dev/null
@@ -0,0 +1,548 @@
+#!/usr/bin/env python
+
+import os
+import socket
+import threading
+import unittest
+import dns
+from dnsdisttests import DNSDistTest
+
+def AsyncResponder(listenPath, responsePath):
+    # Make sure the socket does not already exist
+    try:
+        os.unlink(listenPath)
+    except OSError:
+        if os.path.exists(listenPath):
+            raise
+
+    sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    try:
+        sock.bind(listenPath)
+    except socket.error as e:
+        print("Error binding in the Asynchronous responder: %s" % str(e))
+        sys.exit(1)
+
+    while True:
+        data, addr = sock.recvfrom(65535)
+        print("Got message [%d] '%s' from %s" % (len(data), data, addr))
+        if not data:
+            break
+
+        request = dns.message.from_wire(data)
+        reply = str(request.id) + ' '
+        if str(request.question[0].name).startswith('accept-then-refuse'):
+            if request.flags & dns.flags.QR:
+                reply = reply + 'refuse'
+            else:
+                reply = reply + 'accept'
+        elif str(request.question[0].name).startswith('accept-then-drop'):
+            if request.flags & dns.flags.QR:
+                reply = reply + 'drop'
+            else:
+                reply = reply + 'accept'
+        elif str(request.question[0].name).startswith('accept-then-custom'):
+            if request.flags & dns.flags.QR:
+                reply = reply + 'custom'
+            else:
+                reply = reply + 'accept'
+        elif str(request.question[0].name).startswith('timeout-then-accept'):
+            if request.flags & dns.flags.QR:
+                reply = reply + 'accept'
+            else:
+                # no response
+                continue
+        elif str(request.question[0].name).startswith('accept-then-timeout'):
+            if request.flags & dns.flags.QR:
+                # no response
+                continue
+            else:
+                reply = reply + 'accept'
+        elif str(request.question[0].name).startswith('accept'):
+            reply = reply + 'accept'
+        elif str(request.question[0].name).startswith('refuse'):
+            reply = reply + 'refuse'
+        elif str(request.question[0].name).startswith('drop'):
+            reply = reply + 'drop'
+        elif str(request.question[0].name).startswith('custom'):
+            reply = reply + 'custom'
+        elif str(request.question[0].name).startswith('timeout'):
+            # no response
+            continue
+        else:
+            reply = reply + 'invalid'
+
+        remote = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+        remote.connect(responsePath)
+        remote.send(reply.encode())
+        print("Sent [%d] '%s' to %s" % (len(reply), reply, responsePath))
+
+    sock.close()
+
+asyncResponderSocketPath = '/tmp/async-responder.sock'
+dnsdistSocketPath = '/tmp/dnsdist.sock'
+asyncResponder = threading.Thread(name='Asynchronous Responder', target=AsyncResponder, args=[asyncResponderSocketPath, dnsdistSocketPath])
+asyncResponder.setDaemon(True)
+asyncResponder.start()
+
+@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
+class TestAsync(DNSDistTest):
+
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _tlsServerPort = 8453
+    _dohServerPort = 8443
+    _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort))
+
+    _config_template = """
+    newServer{address="127.0.0.1:%s", pool={'', 'cache'}}
+    newServer{address="127.0.0.1:%s", pool="tcp-only", tcpOnly=true }
+
+    addTLSLocal("127.0.0.1:%s", "%s", "%s", { provider="openssl" })
+    addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/"})
+
+    local ffi = require("ffi")
+    local C = ffi.C
+
+    local filteringTagName = 'filtering'
+    local filteringTagValue = 'pass'
+    local asyncID = 0
+
+    pc = newPacketCache(100)
+    getPool('cache'):setCache(pc)
+
+    function gotAsyncResponse(endpointID, message, from)
+
+      print('Got async response '..message)
+      local parts = {}
+      for part in message:gmatch("%%S+") do table.insert(parts, part) end
+      if #parts ~= 2 then
+        print('Invalid message')
+        return
+      end
+      local queryID = tonumber(parts[1])
+      if parts[2] == 'accept' then
+        print('accepting')
+        C.dnsdist_ffi_resume_from_async(asyncID, queryID, filteringTagName, #filteringTagName, filteringTagValue, #filteringTagValue, true)
+        return
+      end
+      if parts[2] == 'refuse' then
+        print('refusing')
+        C.dnsdist_ffi_set_rcode_from_async(asyncID, queryID, DNSRCode.REFUSED, true)
+        return
+      end
+      if parts[2] == 'drop' then
+        print('dropping')
+        C.dnsdist_ffi_drop_from_async(asyncID, queryID)
+        return
+      end
+      if parts[2] == 'custom' then
+        print('sending a custom response')
+        local raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\006custom\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
+        C.dnsdist_ffi_set_answer_from_async(asyncID, queryID, raw, #raw)
+        return
+      end
+    end
+
+    local asyncResponderEndpoint = newNetworkEndpoint('%s')
+    local listener = newNetworkListener()
+    listener:addUnixListeningEndpoint('%s', 0, gotAsyncResponse)
+    listener:start()
+
+    function passQueryToAsyncFilter(dq)
+      print('in passQueryToAsyncFilter')
+      local timeout = 500 -- 500 ms
+
+      local queryPtr = C.dnsdist_ffi_dnsquestion_get_header(dq)
+      local querySize = C.dnsdist_ffi_dnsquestion_get_len(dq)
+
+      -- we need to take a copy, as we can no longer touch that data after calling set_async
+      local buffer = ffi.string(queryPtr, querySize)
+
+      print(C.dnsdist_ffi_dnsquestion_set_async(dq, asyncID, C.dnsdist_ffi_dnsquestion_get_id(dq), timeout))
+      asyncResponderEndpoint:send(buffer)
+
+      return DNSAction.Allow
+    end
+
+    function passResponseToAsyncFilter(dr)
+      print('in passResponseToAsyncFilter')
+      local timeout = 500 -- 500 ms
+
+      local responsePtr = C.dnsdist_ffi_dnsquestion_get_header(dr)
+      local responseSize = C.dnsdist_ffi_dnsquestion_get_len(dr)
+
+      -- we need to take a copy, as we can no longer touch that data after calling set_async
+      local buffer = ffi.string(responsePtr, responseSize)
+
+      print(C.dnsdist_ffi_dnsresponse_set_async(dr, asyncID, C.dnsdist_ffi_dnsquestion_get_id(dr), timeout))
+      asyncResponderEndpoint:send(buffer)
+
+      return DNSResponseAction.Allow
+    end
+
+    -- this only matters for tests actually reaching the backend
+    addAction('tcp-only.async.tests.powerdns.com', PoolAction('tcp-only', false))
+    addAction('cache.async.tests.powerdns.com', PoolAction('cache', false))
+    addAction(AllRule(), LuaFFIAction(passQueryToAsyncFilter))
+    addCacheHitResponseAction(AllRule(), LuaFFIResponseAction(passResponseToAsyncFilter))
+    addResponseAction(AllRule(), LuaFFIResponseAction(passResponseToAsyncFilter))
+    """
+    _asyncResponderSocketPath = asyncResponderSocketPath
+    _dnsdistSocketPath = dnsdistSocketPath
+    _config_params = ['_testServerPort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohServerPort', '_serverCert', '_serverKey', '_asyncResponderSocketPath', '_dnsdistSocketPath']
+    _verboseMode = True
+
+    def testPass(self):
+        """
+        Async: Accept
+        """
+        for name in ['accept.async.tests.powerdns.com.', 'accept.tcp-only.async.tests.powerdns.com.']:
+            query = dns.message.make_query(name, 'A', 'IN')
+
+            response = dns.message.make_response(query)
+            rrset = dns.rrset.from_text(name,
+                                        60,
+                                        dns.rdataclass.IN,
+                                        dns.rdatatype.A,
+                                        '192.0.2.1')
+            response.answer.append(rrset)
+
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                receivedQuery.id = query.id
+                self.assertEqual(query, receivedQuery)
+                self.assertEqual(response, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+    def testPassCached(self):
+        """
+        Async: Accept (cached)
+        """
+        name = 'accept.cache.async.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            # first time to fill the cache
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+            # second time from the cache
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEqual(response, receivedResponse)
+
+        (_, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=None, useQueue=False, caFile=self._caCert)
+        self.assertEqual(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert)
+        receivedQuery.id = query.id
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(response, receivedResponse)
+
+        (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=None, useQueue=False, caFile=self._caCert)
+        print(response)
+        print(receivedResponse)
+        self.assertEqual(response, receivedResponse)
+
+    def testTimeoutThenAccept(self):
+        """
+        Async: Timeout then accept
+        """
+        for name in ['timeout-then-accept.async.tests.powerdns.com.', 'timeout-then-accept.tcp-only.async.tests.powerdns.com.']:
+            query = dns.message.make_query(name, 'A', 'IN')
+
+            response = dns.message.make_response(query)
+            rrset = dns.rrset.from_text(name,
+                                        60,
+                                        dns.rdataclass.IN,
+                                        dns.rdatatype.A,
+                                        '192.0.2.1')
+            response.answer.append(rrset)
+
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                receivedQuery.id = query.id
+                self.assertEqual(query, receivedQuery)
+                self.assertEqual(response, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+    def testAcceptThenTimeout(self):
+        """
+        Async: Accept then timeout
+        """
+        for name in ['accept-then-timeout.async.tests.powerdns.com.', 'accept-then-timeout.tcp-only.async.tests.powerdns.com.']:
+            query = dns.message.make_query(name, 'A', 'IN')
+
+            response = dns.message.make_response(query)
+            rrset = dns.rrset.from_text(name,
+                                        60,
+                                        dns.rdataclass.IN,
+                                        dns.rdatatype.A,
+                                        '192.0.2.1')
+            response.answer.append(rrset)
+
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                receivedQuery.id = query.id
+                self.assertEqual(query, receivedQuery)
+                self.assertEqual(response, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+    def testAcceptThenRefuse(self):
+        """
+        Async: Accept then refuse
+        """
+        for name in ['accept-then-refuse.async.tests.powerdns.com.', 'accept-then-refuse.tcp-only.async.tests.powerdns.com.']:
+            query = dns.message.make_query(name, 'A', 'IN')
+
+            response = dns.message.make_response(query)
+            rrset = dns.rrset.from_text(name,
+                                        60,
+                                        dns.rdataclass.IN,
+                                        dns.rdatatype.A,
+                                        '192.0.2.1')
+            response.answer.append(rrset)
+
+            expectedResponse = dns.message.make_response(query)
+            expectedResponse.flags |= dns.flags.RA
+            expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                receivedQuery.id = query.id
+                self.assertEqual(query, receivedQuery)
+                self.assertEqual(expectedResponse, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+    def testAcceptThenCustom(self):
+        """
+        Async: Accept then custom
+        """
+        for name in ['accept-then-custom.async.tests.powerdns.com.', 'accept-then-custom.tcp-only.async.tests.powerdns.com.']:
+            query = dns.message.make_query(name, 'A', 'IN')
+
+            response = dns.message.make_response(query)
+            rrset = dns.rrset.from_text(name,
+                                        60,
+                                        dns.rdataclass.IN,
+                                        dns.rdatatype.A,
+                                        '192.0.2.1')
+            response.answer.append(rrset)
+
+            # easier to get the same custom response to everyone, sorry!
+            expectedQuery = dns.message.make_query('custom.async.tests.powerdns.com.', 'A', 'IN')
+            expectedQuery.id = query.id
+            expectedResponse = dns.message.make_response(expectedQuery)
+            expectedResponse.flags |= dns.flags.RA
+            expectedResponse.set_rcode(dns.rcode.FORMERR)
+
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                receivedQuery.id = query.id
+                self.assertEqual(query, receivedQuery)
+                self.assertEqual(expectedResponse, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+            (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+    def testAcceptThenDrop(self):
+        """
+        Async: Accept then drop
+        """
+        for name in ['accept-then-drop.async.tests.powerdns.com.', 'accept-then-drop.tcp-only.async.tests.powerdns.com.']:
+            query = dns.message.make_query(name, 'A', 'IN')
+
+            response = dns.message.make_response(query)
+            rrset = dns.rrset.from_text(name,
+                                        60,
+                                        dns.rdataclass.IN,
+                                        dns.rdatatype.A,
+                                        '192.0.2.1')
+            response.answer.append(rrset)
+
+            for method in ("sendUDPQuery", "sendTCPQuery"):
+                sender = getattr(self, method)
+                (receivedQuery, receivedResponse) = sender(query, response)
+                receivedQuery.id = query.id
+                self.assertEqual(query, receivedQuery)
+                self.assertEqual(receivedResponse, None)
+
+            (receivedQuery, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(receivedResponse, None)
+
+            (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(receivedResponse, None)
+
+    def testRefused(self):
+        """
+        Async: Refused
+        """
+        name = 'refused.async.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.RA
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=None, caFile=self._caCert, useQueue=False)
+        self.assertEqual(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=None, caFile=self._caCert, useQueue=False)
+        self.assertEqual(expectedResponse, receivedResponse)
+
+    def testDrop(self):
+        """
+        Async: Drop
+        """
+        name = 'drop.async.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEqual(receivedResponse, None)
+
+        (_, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=None, caFile=self._caCert, useQueue=False)
+        self.assertEqual(receivedResponse, None)
+
+        (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=None, caFile=self._caCert, useQueue=False)
+        self.assertEqual(receivedResponse, None)
+
+    def testCustom(self):
+        """
+        Async: Custom answer
+        """
+        name = 'custom.async.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags |= dns.flags.RA
+        expectedResponse.set_rcode(dns.rcode.FORMERR)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            print(expectedResponse)
+            print(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response=None, caFile=self._caCert, useQueue=False)
+        self.assertEqual(expectedResponse, receivedResponse)
+
+        (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=None, caFile=self._caCert, useQueue=False)
+        self.assertEqual(expectedResponse, receivedResponse)
+
+    def testTruncation(self):
+        """
+        Async: DoH query, timeout then truncated answer over UDP, then valid over TCP and accept
+        """
+        # the query is first forwarded over UDP, leading to a TC=1 answer from the
+        # backend, then over TCP
+        name = 'timeout-then-accept.tc.async.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.id = 42
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery.id = 42
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        # first response is a TC=1
+        tcResponse = dns.message.make_response(query)
+        tcResponse.flags |= dns.flags.TC
+        self._toResponderQueue.put(tcResponse, True, 2.0)
+
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response)
+        # first query, received by the responder over UDP
+        self.assertTrue(receivedQuery)
+        receivedQuery.id = expectedQuery.id
+        self.assertEqual(expectedQuery, receivedQuery)
+        self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+
+        # check the response
+        self.assertTrue(receivedResponse)
+        print(response)
+        print(receivedResponse)
+        self.assertEqual(response, receivedResponse)
+
+        # check the second query, received by the responder over TCP
+        receivedQuery = self._fromResponderQueue.get(True, 2.0)
+        self.assertTrue(receivedQuery)
+        receivedQuery.id = expectedQuery.id
+        self.assertEqual(expectedQuery, receivedQuery)
+        self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
index e3198ecf5403d227e0b528aaaaf8b1dbb20205cf..cd12c96133d215b4825cea91a9760fa8d04cff0c 100644 (file)
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-import base64
+
 import dns
 import os
 import time
@@ -889,7 +889,7 @@ class TestDOHWithCache(DNSDistDOHTest):
         self._toResponderQueue.put(tcResponse, True, 2.0)
 
         (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response)
-        # first query, received by dnsdist over UDP
+        # first query, received by the responder over UDP
         self.assertTrue(receivedQuery)
         receivedQuery.id = expectedQuery.id
         self.assertEqual(expectedQuery, receivedQuery)
@@ -899,7 +899,7 @@ class TestDOHWithCache(DNSDistDOHTest):
         self.assertTrue(receivedResponse)
         self.assertEqual(response, receivedResponse)
 
-        # check the second query, received by dnsdist over TCP
+        # check the second query, received by the responder over TCP
         receivedQuery = self._fromResponderQueue.get(True, 2.0)
         self.assertTrue(receivedQuery)
         receivedQuery.id = expectedQuery.id
@@ -1294,3 +1294,38 @@ class TestDOHWithPCKS12Cert(DNSDistDOHTest):
         self.assertTrue(receivedResponse)
         receivedQuery.id = expectedQuery.id
         self.assertEqual(expectedQuery, receivedQuery)
+
+class TestDOHForwardedToTCPOnly(DNSDistDOHTest):
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _dohServerPort = 8443
+    _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort))
+    _config_template = """
+    newServer{address="127.0.0.1:%s", tcpOnly=true}
+    addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" })
+    """
+    _config_params = ['_testServerPort', '_dohServerPort', '_serverCert', '_serverKey']
+
+    def testDOHTCPOnly(self):
+        """
+        DoH: Test a DoH query forwarded to a TCP-only server
+        """
+        name = 'tcponly.doh.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.id = 42
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)