]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #8713 from rgacogne/auth-strict-caches-size
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
index 83dc4b5a423566ad9fdd8bd7607184997eccc6cd..01eb5332edab98a2dc12818aba7df39f7b7a6ce4 100644 (file)
@@ -16,6 +16,8 @@ import dns.message
 import libnacl
 import libnacl.utils
 
+from eqdnsmessage import AssertEqualDNSMessageMixin
+
 # Python2/3 compatibility hacks
 try:
   from queue import Queue
@@ -28,7 +30,7 @@ except NameError:
   pass
 
 
-class DNSDistTest(unittest.TestCase):
+class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     """
     Set up a dnsdist instance and responder threads.
     Queries sent to dnsdist are relayed to the responder threads,
@@ -55,6 +57,7 @@ class DNSDistTest(unittest.TestCase):
     _healthCheckName = 'a.root-servers.net.'
     _healthCheckCounter = 0
     _answerUnexpected = True
+    _checkConfigExpectedOutput = None
 
     @classmethod
     def startResponders(cls):
@@ -89,7 +92,10 @@ class DNSDistTest(unittest.TestCase):
             output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
         except subprocess.CalledProcessError as exc:
             raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
-        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
+        if cls._checkConfigExpectedOutput is not None:
+          expectedOutput = cls._checkConfigExpectedOutput
+        else:
+          expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
         if output != expectedOutput:
             raise AssertionError('dnsdist --check-config failed: %s' % output)
 
@@ -176,10 +182,11 @@ class DNSDistTest(unittest.TestCase):
         return response
 
     @classmethod
-    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
+    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
+        # callback is invoked for every -even healthcheck ones- query and should return a raw response
         ignoreTrailing = trailingDataResponse is True
 
         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -197,20 +204,28 @@ class DNSDistTest(unittest.TestCase):
                 request = dns.message.from_wire(data, ignore_trailing=True)
                 forceRcode = trailingDataResponse
 
-            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
-            if not response:
-                continue
+            wire = None
+            if callback:
+              wire = callback(request)
+            else:
+              response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+              if response:
+                wire = response.to_wire()
+
+            if not wire:
+              continue
 
             sock.settimeout(2.0)
-            sock.sendto(response.to_wire(), addr)
+            sock.sendto(wire, addr)
             sock.settimeout(None)
         sock.close()
 
     @classmethod
-    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
+    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
+        # callback is invoked for every -even healthcheck ones- query and should return a raw response
         ignoreTrailing = trailingDataResponse is True
 
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -225,7 +240,7 @@ class DNSDistTest(unittest.TestCase):
         sock.listen(100)
         while True:
             (conn, _) = sock.accept()
-            conn.settimeout(2.0)
+            conn.settimeout(5.0)
             data = conn.recv(2)
             if not data:
                 conn.close()
@@ -243,12 +258,17 @@ class DNSDistTest(unittest.TestCase):
                 request = dns.message.from_wire(data, ignore_trailing=True)
                 forceRcode = trailingDataResponse
 
-            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
-            if not response:
+            if callback:
+              wire = callback(request)
+            else:
+              response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+              if response:
+                wire = response.to_wire(max_size=65535)
+
+            if not wire:
                 conn.close()
                 continue
 
-            wire = response.to_wire()
             conn.send(struct.pack("!H", len(wire)))
             conn.send(wire)
 
@@ -262,7 +282,7 @@ class DNSDistTest(unittest.TestCase):
 
                 response = copy.copy(response)
                 response.id = request.id
-                wire = response.to_wire()
+                wire = response.to_wire(max_size=65535)
                 try:
                     conn.send(struct.pack("!H", len(wire)))
                     conn.send(wire)
@@ -440,6 +460,8 @@ class DNSDistTest(unittest.TestCase):
         while not self._fromResponderQueue.empty():
             self._fromResponderQueue.get(False)
 
+        super(DNSDistTest, self).setUp()
+
     @classmethod
     def clearToResponderQueue(cls):
         while not cls._toResponderQueue.empty():
@@ -563,3 +585,4 @@ class DNSDistTest(unittest.TestCase):
 
     def checkResponseNoEDNS(self, expected, received):
         self.checkMessageNoEDNS(expected, received)
+