]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Async.py
Merge pull request #12769 from neilcook/patch-1
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Async.py
index 8856df30f59aa3ef34ab7cdfbba172df98515f7f..eda4a5903c33ea1c959e2585abd0bea903bcb018 100644 (file)
@@ -6,6 +6,9 @@ import sys
 import threading
 import unittest
 import dns
+import dns.message
+import doqclient
+
 from dnsdisttests import DNSDistTest, pickAvailablePort
 
 def AsyncResponder(listenPath, responsePath):
@@ -95,6 +98,7 @@ class AsyncTests(object):
     _dohWithH2OServerPort = pickAvailablePort()
     _dohWithNGHTTP2BaseURL = ("https://%s:%d/" % (_serverName, _dohWithNGHTTP2ServerPort))
     _dohWithH2OBaseURL = ("https://%s:%d/" % (_serverName, _dohWithH2OServerPort))
+    _doqServerPort = pickAvailablePort()
 
     def testPass(self):
         """
@@ -111,11 +115,14 @@ class AsyncTests(object):
                                         '192.0.2.1')
             response.answer.append(rrset)
 
-            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
                 sender = getattr(self, method)
                 (receivedQuery, receivedResponse) = sender(query, response)
                 receivedQuery.id = query.id
                 self.assertEqual(query, receivedQuery)
+                if method == 'sendDOQQueryWrapper':
+                    # dnspython sets the ID to 0
+                    receivedResponse.id = response.id
                 self.assertEqual(response, receivedResponse)
 
     def testPassCached(self):
@@ -133,9 +140,9 @@ class AsyncTests(object):
                                     '192.0.2.1')
         response.answer.append(rrset)
 
-        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
             sender = getattr(self, method)
-            if method != 'sendDOTQueryWrapper' and method != 'sendDOHWithH2OQueryWrapper':
+            if method != 'sendDOTQueryWrapper' and method != 'sendDOHWithH2OQueryWrapper' and method != 'sendDOQQueryWrapper':
                 # first time to fill the cache
                 # disabled for DoT since it was already filled via TCP
                 (receivedQuery, receivedResponse) = sender(query, response)
@@ -146,6 +153,9 @@ class AsyncTests(object):
             # second time from the cache
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            if method == 'sendDOQQueryWrapper':
+                # dnspython sets the ID to 0
+                receivedResponse.id = response.id
             self.assertEqual(response, receivedResponse)
 
     def testTimeoutThenAccept(self):
@@ -163,11 +173,14 @@ class AsyncTests(object):
                                         '192.0.2.1')
             response.answer.append(rrset)
 
-            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
                 sender = getattr(self, method)
                 (receivedQuery, receivedResponse) = sender(query, response)
                 receivedQuery.id = query.id
                 self.assertEqual(query, receivedQuery)
+                if method == 'sendDOQQueryWrapper':
+                    # dnspython sets the ID to 0
+                    receivedResponse.id = response.id
                 self.assertEqual(response, receivedResponse)
 
     def testAcceptThenTimeout(self):
@@ -185,11 +198,14 @@ class AsyncTests(object):
                                         '192.0.2.1')
             response.answer.append(rrset)
 
-            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
                 sender = getattr(self, method)
                 (receivedQuery, receivedResponse) = sender(query, response)
                 receivedQuery.id = query.id
                 self.assertEqual(query, receivedQuery)
+                if method == 'sendDOQQueryWrapper':
+                    # dnspython sets the ID to 0
+                    receivedResponse.id = response.id
                 self.assertEqual(response, receivedResponse)
 
     def testAcceptThenRefuse(self):
@@ -211,11 +227,14 @@ class AsyncTests(object):
             expectedResponse.flags |= dns.flags.RA
             expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
                 sender = getattr(self, method)
                 (receivedQuery, receivedResponse) = sender(query, response)
                 receivedQuery.id = query.id
                 self.assertEqual(query, receivedQuery)
+                if method == 'sendDOQQueryWrapper':
+                    # dnspython sets the ID to 0
+                    receivedResponse.id = expectedResponse.id
                 self.assertEqual(expectedResponse, receivedResponse)
 
     def testAcceptThenCustom(self):
@@ -233,18 +252,20 @@ class AsyncTests(object):
                                         '192.0.2.1')
             response.answer.append(rrset)
 
-            # easier to get the same custom response to everyone, sorry!
-            expectedQuery = dns.message.make_query('custom.async.tests.powerdns.com.', 'A', 'IN')
+            expectedQuery = dns.message.make_query(name, 'A', 'IN')
             expectedQuery.id = query.id
             expectedResponse = dns.message.make_response(expectedQuery)
             expectedResponse.flags |= dns.flags.RA
             expectedResponse.set_rcode(dns.rcode.FORMERR)
 
-            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
                 sender = getattr(self, method)
                 (receivedQuery, receivedResponse) = sender(query, response)
                 receivedQuery.id = query.id
                 self.assertEqual(query, receivedQuery)
+                if method == 'sendDOQQueryWrapper':
+                    # dnspython sets the ID to 0
+                    receivedResponse.id = expectedResponse.id
                 self.assertEqual(expectedResponse, receivedResponse)
 
     def testAcceptThenDrop(self):
@@ -262,9 +283,14 @@ class AsyncTests(object):
                                         '192.0.2.1')
             response.answer.append(rrset)
 
-            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+            for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
                 sender = getattr(self, method)
-                (receivedQuery, receivedResponse) = sender(query, response)
+                try:
+                    (receivedQuery, receivedResponse) = sender(query, response)
+                except doqclient.StreamResetError:
+                    if not self._fromResponderQueue.empty():
+                        receivedQuery = self._fromResponderQueue.get(True, 1.0)
+                    receivedResponse = None
                 receivedQuery.id = query.id
                 self.assertEqual(query, receivedQuery)
                 self.assertEqual(receivedResponse, None)
@@ -280,10 +306,13 @@ class AsyncTests(object):
         expectedResponse.flags |= dns.flags.RA
         expectedResponse.set_rcode(dns.rcode.REFUSED)
 
-        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
+            if method == 'sendDOQQueryWrapper':
+                # dnspython sets the ID to 0
+                receivedResponse.id = expectedResponse.id
             self.assertEqual(expectedResponse, receivedResponse)
 
     def testDrop(self):
@@ -293,9 +322,12 @@ class AsyncTests(object):
         name = 'drop.async.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
 
-        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
             sender = getattr(self, method)
-            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            try:
+                (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            except doqclient.StreamResetError:
+                receivedResponse = None
             self.assertEqual(receivedResponse, None)
 
     def testCustom(self):
@@ -309,10 +341,13 @@ class AsyncTests(object):
         expectedResponse.flags |= dns.flags.RA
         expectedResponse.set_rcode(dns.rcode.FORMERR)
 
-        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper"):
+        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper", "sendDOHWithH2OQueryWrapper", "sendDOQQueryWrapper"):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
+            if method == 'sendDOQQueryWrapper':
+                # dnspython sets the ID to 0
+                receivedResponse.id = expectedResponse.id
             self.assertEqual(expectedResponse, receivedResponse)
 
     def testTruncation(self):
@@ -369,6 +404,7 @@ class TestAsyncFFI(DNSDistTest, AsyncTests):
     addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" })
     addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="h2o"})
     addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"})
+    addDOQLocal("127.0.0.1:%d", "%s", "%s")
 
     local ffi = require("ffi")
     local C = ffi.C
@@ -380,6 +416,8 @@ class TestAsyncFFI(DNSDistTest, AsyncTests):
     pc = newPacketCache(100)
     getPool('cache'):setCache(pc)
 
+    local asyncObjectsMap = {}
+
     function gotAsyncResponse(endpointID, message, from)
 
       print('Got async response '..message)
@@ -390,6 +428,7 @@ class TestAsyncFFI(DNSDistTest, AsyncTests):
         return
       end
       local queryID = tonumber(parts[1])
+      local qname = asyncObjectsMap[queryID]
       if parts[2] == 'accept' then
         print('accepting')
         C.dnsdist_ffi_resume_from_async(asyncID, queryID, filteringTagName, #filteringTagName, filteringTagValue, #filteringTagValue, true)
@@ -407,17 +446,34 @@ class TestAsyncFFI(DNSDistTest, AsyncTests):
       end
       if parts[2] == 'custom' then
         print('sending a custom response')
-        local raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\006custom\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
+        local raw = nil
+        if qname == string.char(6)..'custom'..string.char(5)..'async'..string.char(5)..'tests'..string.char(8)..'powerdns'..string.char(3)..'com' then
+          raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\006custom\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
+        elseif qname == string.char(18)..'accept-then-custom'..string.char(5)..'async'..string.char(5)..'tests'..string.char(8)..'powerdns'..string.char(3)..'com' then
+          raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\018accept-then-custom\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
+        elseif qname == string.char(18)..'accept-then-custom'..string.char(8)..'tcp-only'..string.char(5)..'async'..string.char(5)..'tests'..string.char(8)..'powerdns'..string.char(3)..'com' then
+          raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\018accept-then-custom\\008tcp-only\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
+        end
+
         C.dnsdist_ffi_set_answer_from_async(asyncID, queryID, raw, #raw)
         return
       end
     end
 
-    local asyncResponderEndpoint = newNetworkEndpoint('%s')
-    local listener = newNetworkListener()
+    asyncResponderEndpoint = newNetworkEndpoint('%s')
+    listener = newNetworkListener()
     listener:addUnixListeningEndpoint('%s', 0, gotAsyncResponse)
     listener:start()
 
+    function getQNameRaw(dq)
+      local ret_ptr = ffi.new("char *[1]")
+      local ret_ptr_param = ffi.cast("const char **", ret_ptr)
+      local ret_size = ffi.new("size_t[1]")
+      local ret_size_param = ffi.cast("size_t*", ret_size)
+      C.dnsdist_ffi_dnsquestion_get_qname_raw(dq, ret_ptr_param, ret_size_param)
+      return ffi.string(ret_ptr[0])
+    end
+
     function passQueryToAsyncFilter(dq)
       print('in passQueryToAsyncFilter')
       local timeout = 500 -- 500 ms
@@ -428,7 +484,9 @@ class TestAsyncFFI(DNSDistTest, AsyncTests):
       -- we need to take a copy, as we can no longer touch that data after calling set_async
       local buffer = ffi.string(queryPtr, querySize)
 
-      print(C.dnsdist_ffi_dnsquestion_set_async(dq, asyncID, C.dnsdist_ffi_dnsquestion_get_id(dq), timeout))
+      asyncObjectsMap[C.dnsdist_ffi_dnsquestion_get_id(dq)] = getQNameRaw(dq)
+
+      C.dnsdist_ffi_dnsquestion_set_async(dq, asyncID, C.dnsdist_ffi_dnsquestion_get_id(dq), timeout)
       asyncResponderEndpoint:send(buffer)
 
       return DNSAction.Allow
@@ -444,7 +502,9 @@ class TestAsyncFFI(DNSDistTest, AsyncTests):
       -- we need to take a copy, as we can no longer touch that data after calling set_async
       local buffer = ffi.string(responsePtr, responseSize)
 
-      print(C.dnsdist_ffi_dnsresponse_set_async(dr, asyncID, C.dnsdist_ffi_dnsquestion_get_id(dr), timeout))
+      asyncObjectsMap[C.dnsdist_ffi_dnsquestion_get_id(dr)] = getQNameRaw(dr)
+
+      C.dnsdist_ffi_dnsresponse_set_async(dr, asyncID, C.dnsdist_ffi_dnsquestion_get_id(dr), timeout)
       asyncResponderEndpoint:send(buffer)
 
       return DNSResponseAction.Allow
@@ -459,7 +519,7 @@ class TestAsyncFFI(DNSDistTest, AsyncTests):
     """
     _asyncResponderSocketPath = asyncResponderSocketPath
     _dnsdistSocketPath = dnsdistSocketPath
-    _config_params = ['_testServerPort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_asyncResponderSocketPath', '_dnsdistSocketPath']
+    _config_params = ['_testServerPort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_asyncResponderSocketPath', '_dnsdistSocketPath']
     _verboseMode = True
 
 @unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
@@ -471,6 +531,7 @@ class TestAsyncLua(DNSDistTest, AsyncTests):
     addTLSLocal("127.0.0.1:%d", "%s", "%s", { provider="openssl" })
     addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="h2o"})
     addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library="nghttp2"})
+    addDOQLocal("127.0.0.1:%d", "%s", "%s")
 
     local filteringTagName = 'filtering'
     local filteringTagValue = 'pass'
@@ -513,16 +574,23 @@ class TestAsyncLua(DNSDistTest, AsyncTests):
       end
       if parts[2] == 'custom' then
         print('sending a custom response')
-        local raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\006custom\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
         local dq = asyncObject:getDQ()
+        local raw
+        if tostring(dq.qname) == 'custom.async.tests.powerdns.com.' then
+          raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\006custom\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
+        elseif tostring(dq.qname) == 'accept-then-custom.async.tests.powerdns.com.' then
+          raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\018accept-then-custom\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
+        elseif tostring(dq.qname) == 'accept-then-custom.tcp-only.async.tests.powerdns.com.' then
+          raw = '\\000\\000\\128\\129\\000\\001\\000\\000\\000\\000\\000\\001\\018accept-then-custom\\008tcp-only\\005async\\005tests\\008powerdns\\003com\\000\\000\\001\\000\\001\\000\\000\\041\\002\\000\\000\\000\\128\\000\\000\\000'
+        end
         dq:setContent(raw)
         asyncObject:resume()
         return
       end
     end
 
-    local asyncResponderEndpoint = newNetworkEndpoint('%s')
-    local listener = newNetworkListener()
+    asyncResponderEndpoint = newNetworkEndpoint('%s')
+    listener = newNetworkListener()
     listener:addUnixListeningEndpoint('%s', 0, gotAsyncResponse)
     listener:start()
 
@@ -559,5 +627,5 @@ class TestAsyncLua(DNSDistTest, AsyncTests):
     """
     _asyncResponderSocketPath = asyncResponderSocketPath
     _dnsdistSocketPath = dnsdistSocketPath
-    _config_params = ['_testServerPort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_asyncResponderSocketPath', '_dnsdistSocketPath']
+    _config_params = ['_testServerPort', '_testServerPort', '_tlsServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_doqServerPort', '_serverCert', '_serverKey', '_asyncResponderSocketPath', '_dnsdistSocketPath']
     _verboseMode = True