X-Git-Url: http://git.ipfire.org/?a=blobdiff_plain;f=regression-tests.dnsdist%2Fdnsdisttests.py;h=8765ea58445658ea0ac95288b08f19d6933c98f8;hb=d4bc1d60ffdf7dfbe01d278936bb393306aedce2;hp=83dc4b5a423566ad9fdd8bd7607184997eccc6cd;hpb=292750b58a7fcc9e1f48da8397a12cbc0a560abc;p=thirdparty%2Fpdns.git diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 83dc4b5a42..8765ea5844 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -16,6 +16,8 @@ import dns.message import libnacl import libnacl.utils +from eqdnsmessage import AssertEqualDNSMessageMixin + # Python2/3 compatibility hacks try: from queue import Queue @@ -28,7 +30,7 @@ except NameError: pass -class DNSDistTest(unittest.TestCase): +class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): """ Set up a dnsdist instance and responder threads. Queries sent to dnsdist are relayed to the responder threads, @@ -55,6 +57,7 @@ class DNSDistTest(unittest.TestCase): _healthCheckName = 'a.root-servers.net.' _healthCheckCounter = 0 _answerUnexpected = True + _checkConfigExpectedOutput = None @classmethod def startResponders(cls): @@ -89,7 +92,10 @@ class DNSDistTest(unittest.TestCase): output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True) except subprocess.CalledProcessError as exc: raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output)) - expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode() + if cls._checkConfigExpectedOutput is not None: + expectedOutput = cls._checkConfigExpectedOutput + else: + expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode() if output != expectedOutput: raise AssertionError('dnsdist --check-config failed: %s' % output) @@ -176,10 +182,11 @@ class DNSDistTest(unittest.TestCase): return response @classmethod - def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False): + def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None): # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. + # callback is invoked for every -even healthcheck ones- query and should return a raw response ignoreTrailing = trailingDataResponse is True sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -197,20 +204,28 @@ class DNSDistTest(unittest.TestCase): request = dns.message.from_wire(data, ignore_trailing=True) forceRcode = trailingDataResponse - response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) - if not response: - continue + wire = None + if callback: + wire = callback(request) + else: + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) + if response: + wire = response.to_wire() + + if not wire: + continue sock.settimeout(2.0) - sock.sendto(response.to_wire(), addr) + sock.sendto(wire, addr) sock.settimeout(None) sock.close() @classmethod - def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False): + def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None): # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. + # callback is invoked for every -even healthcheck ones- query and should return a raw response ignoreTrailing = trailingDataResponse is True sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -225,7 +240,7 @@ class DNSDistTest(unittest.TestCase): sock.listen(100) while True: (conn, _) = sock.accept() - conn.settimeout(2.0) + conn.settimeout(5.0) data = conn.recv(2) if not data: conn.close() @@ -243,12 +258,17 @@ class DNSDistTest(unittest.TestCase): request = dns.message.from_wire(data, ignore_trailing=True) forceRcode = trailingDataResponse - response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) - if not response: + if callback: + wire = callback(request) + else: + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) + if response: + wire = response.to_wire(max_size=65535) + + if not wire: conn.close() continue - wire = response.to_wire() conn.send(struct.pack("!H", len(wire))) conn.send(wire) @@ -262,7 +282,7 @@ class DNSDistTest(unittest.TestCase): response = copy.copy(response) response.id = request.id - wire = response.to_wire() + wire = response.to_wire(max_size=65535) try: conn.send(struct.pack("!H", len(wire))) conn.send(wire) @@ -440,6 +460,8 @@ class DNSDistTest(unittest.TestCase): while not self._fromResponderQueue.empty(): self._fromResponderQueue.get(False) + super(DNSDistTest, self).setUp() + @classmethod def clearToResponderQueue(cls): while not cls._toResponderQueue.empty(): @@ -530,6 +552,9 @@ class DNSDistTest(unittest.TestCase): if withCookies: for option in received.options: self.assertEquals(option.otype, 10) + else: + for option in received.options: + self.assertNotEquals(option.otype, 10) def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0): self.assertEquals(expected, received) @@ -563,3 +588,4 @@ class DNSDistTest(unittest.TestCase): def checkResponseNoEDNS(self, expected, received): self.checkMessageNoEDNS(expected, received) +