import libnacl
import libnacl.utils
+from eqdnsmessage import AssertEqualDNSMessageMixin
+
# Python2/3 compatibility hacks
try:
from queue import Queue
pass
-class DNSDistTest(unittest.TestCase):
+class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
"""
Set up a dnsdist instance and responder threads.
Queries sent to dnsdist are relayed to the responder threads,
_healthCheckName = 'a.root-servers.net.'
_healthCheckCounter = 0
_answerUnexpected = True
+ _checkConfigExpectedOutput = None
@classmethod
def startResponders(cls):
output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
except subprocess.CalledProcessError as exc:
raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
- expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
+ if cls._checkConfigExpectedOutput is not None:
+ expectedOutput = cls._checkConfigExpectedOutput
+ else:
+ expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
if output != expectedOutput:
raise AssertionError('dnsdist --check-config failed: %s' % output)
return response
@classmethod
- def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
+ def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
# trailingDataResponse=True means "ignore trailing data".
# Other values are either False (meaning "raise an exception")
# or are interpreted as a response RCODE for queries with trailing data.
+ # callback is invoked for every -even healthcheck ones- query and should return a raw response
ignoreTrailing = trailingDataResponse is True
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
request = dns.message.from_wire(data, ignore_trailing=True)
forceRcode = trailingDataResponse
- response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
- if not response:
- continue
+ wire = None
+ if callback:
+ wire = callback(request)
+ else:
+ response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+ if response:
+ wire = response.to_wire()
+
+ if not wire:
+ continue
sock.settimeout(2.0)
- sock.sendto(response.to_wire(), addr)
+ sock.sendto(wire, addr)
sock.settimeout(None)
sock.close()
@classmethod
- def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
+ def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
# trailingDataResponse=True means "ignore trailing data".
# Other values are either False (meaning "raise an exception")
# or are interpreted as a response RCODE for queries with trailing data.
+ # callback is invoked for every -even healthcheck ones- query and should return a raw response
ignoreTrailing = trailingDataResponse is True
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.listen(100)
while True:
(conn, _) = sock.accept()
- conn.settimeout(2.0)
+ conn.settimeout(5.0)
data = conn.recv(2)
if not data:
conn.close()
request = dns.message.from_wire(data, ignore_trailing=True)
forceRcode = trailingDataResponse
- response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
- if not response:
+ if callback:
+ wire = callback(request)
+ else:
+ response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+ if response:
+ wire = response.to_wire(max_size=65535)
+
+ if not wire:
conn.close()
continue
- wire = response.to_wire()
conn.send(struct.pack("!H", len(wire)))
conn.send(wire)
response = copy.copy(response)
response.id = request.id
- wire = response.to_wire()
+ wire = response.to_wire(max_size=65535)
try:
conn.send(struct.pack("!H", len(wire)))
conn.send(wire)
while not self._fromResponderQueue.empty():
self._fromResponderQueue.get(False)
+ super(DNSDistTest, self).setUp()
+
@classmethod
def clearToResponderQueue(cls):
while not cls._toResponderQueue.empty():
def checkResponseNoEDNS(self, expected, received):
self.checkMessageNoEDNS(expected, received)
+