]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #8945 from rgacogne/ddist-x-forwarded-for
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
index f5e0b92ee674c1ea8986b7c7658556b3714abd76..8765ea58445658ea0ac95288b08f19d6933c98f8 100644 (file)
@@ -16,6 +16,8 @@ import dns.message
 import libnacl
 import libnacl.utils
 
+from eqdnsmessage import AssertEqualDNSMessageMixin
+
 # Python2/3 compatibility hacks
 try:
   from queue import Queue
@@ -28,7 +30,7 @@ except NameError:
   pass
 
 
-class DNSDistTest(unittest.TestCase):
+class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     """
     Set up a dnsdist instance and responder threads.
     Queries sent to dnsdist are relayed to the responder threads,
@@ -55,6 +57,7 @@ class DNSDistTest(unittest.TestCase):
     _healthCheckName = 'a.root-servers.net.'
     _healthCheckCounter = 0
     _answerUnexpected = True
+    _checkConfigExpectedOutput = None
 
     @classmethod
     def startResponders(cls):
@@ -89,7 +92,10 @@ class DNSDistTest(unittest.TestCase):
             output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
         except subprocess.CalledProcessError as exc:
             raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
-        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
+        if cls._checkConfigExpectedOutput is not None:
+          expectedOutput = cls._checkConfigExpectedOutput
+        else:
+          expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
         if output != expectedOutput:
             raise AssertionError('dnsdist --check-config failed: %s' % output)
 
@@ -454,6 +460,8 @@ class DNSDistTest(unittest.TestCase):
         while not self._fromResponderQueue.empty():
             self._fromResponderQueue.get(False)
 
+        super(DNSDistTest, self).setUp()
+
     @classmethod
     def clearToResponderQueue(cls):
         while not cls._toResponderQueue.empty():
@@ -544,6 +552,9 @@ class DNSDistTest(unittest.TestCase):
         if withCookies:
             for option in received.options:
                 self.assertEquals(option.otype, 10)
+        else:
+            for option in received.options:
+                self.assertNotEquals(option.otype, 10)
 
     def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
         self.assertEquals(expected, received)
@@ -577,3 +588,4 @@ class DNSDistTest(unittest.TestCase):
 
     def checkResponseNoEDNS(self, expected, received):
         self.checkMessageNoEDNS(expected, received)
+