_dnsdistStartupDelay = 2.0
_dnsdist = None
_responsesCounter = {}
- _shutUp = True
_config_template = """
"""
_config_params = ['_testServerPort']
cls._TCPResponder.start()
@classmethod
- def startDNSDist(cls, shutUp=True):
+ def startDNSDist(cls):
print("Launching dnsdist..")
- conffile = 'dnsdist_test.conf'
+ confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
params = tuple([getattr(cls, param) for param in cls._config_params])
print(params)
- with open(conffile, 'w') as conf:
+ with open(confFile, 'w') as conf:
conf.write("-- Autogenerated by dnsdisttests.py\n")
conf.write(cls._config_template % params)
- dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
+ dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', confFile,
'-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
for acl in cls._acl:
dnsdistcmd.extend(['--acl', acl])
output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
except subprocess.CalledProcessError as exc:
raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
- if output != b'Configuration \'dnsdist_test.conf\' OK!\n':
+ expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
+ if output != expectedOutput:
raise AssertionError('dnsdist --check-config failed: %s' % output)
- if shutUp:
- with open(os.devnull, 'w') as fdDevNull:
- cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
- else:
- cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)
+ logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
+ with open(logFile, 'w') as fdLog:
+ cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)
if 'DNSDIST_FAST_TESTS' in os.environ:
delay = 0.5
def setUpClass(cls):
cls.startResponders()
- cls.startDNSDist(cls._shutUp)
+ cls.startDNSDist()
cls.setUpSockets()
print("Launching tests..")
cls._responsesCounter[threading.currentThread().name] = 1
@classmethod
- def _getResponse(cls, request, fromQueue, toQueue):
+ def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
response = None
if len(request.question) != 1:
print("Skipping query with question count %d" % (len(request.question)))
healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
if healthCheck:
cls._healthCheckCounter += 1
+ response = dns.message.make_response(request)
else:
cls._ResponderIncrementCounter()
if not fromQueue.empty():
- response = fromQueue.get(True, cls._queueTimeout)
- if response:
- response = copy.copy(response)
- response.id = request.id
- toQueue.put(request, True, cls._queueTimeout)
+ toQueue.put(request, True, cls._queueTimeout)
+ if synthesize is None:
+ response = fromQueue.get(True, cls._queueTimeout)
+ if response:
+ response = copy.copy(response)
+ response.id = request.id
if not response:
- if healthCheck:
+ if synthesize is not None:
response = dns.message.make_response(request)
+ response.set_rcode(synthesize)
elif cls._answerUnexpected:
response = dns.message.make_response(request)
response.set_rcode(dns.rcode.SERVFAIL)
return response
@classmethod
- def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
+ def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
+ # trailingDataResponse=True means "ignore trailing data".
+ # Other values are either False (meaning "raise an exception")
+ # or are interpreted as a response RCODE for queries with trailing data.
+ ignoreTrailing = trailingDataResponse is True
+
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind(("127.0.0.1", port))
while True:
data, addr = sock.recvfrom(4096)
- request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
- response = cls._getResponse(request, fromQueue, toQueue)
-
+ forceRcode = None
+ try:
+ request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+ except dns.message.TrailingJunk as e:
+ if trailingDataResponse is False or forceRcode is True:
+ raise
+ print("UDP query with trailing data, synthesizing response")
+ request = dns.message.from_wire(data, ignore_trailing=True)
+ forceRcode = trailingDataResponse
+
+ response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
if not response:
continue
sock.close()
@classmethod
- def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
+ def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
+ # trailingDataResponse=True means "ignore trailing data".
+ # Other values are either False (meaning "raise an exception")
+ # or are interpreted as a response RCODE for queries with trailing data.
+ ignoreTrailing = trailingDataResponse is True
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
try:
(datalen,) = struct.unpack("!H", data)
data = conn.recv(datalen)
- request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
- response = cls._getResponse(request, fromQueue, toQueue)
-
+ forceRcode = None
+ try:
+ request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+ except dns.message.TrailingJunk as e:
+ if trailingDataResponse is False or forceRcode is True:
+ raise
+ print("TCP query with trailing data, synthesizing response")
+ request = dns.message.from_wire(data, ignore_trailing=True)
+ forceRcode = trailingDataResponse
+
+ response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
if not response:
conn.close()
continue