]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
dnsdist: Remove flaky healthcheck regression test
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import os
5 import socket
6 import ssl
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import clientsubnetoption
14 import dns
15 import dns.message
16 import libnacl
17 import libnacl.utils
18
# Python 2/3 compatibility shims.
#
# The standard-library queue module is called "queue" on Python 3 and
# "Queue" on Python 2; try the Python 3 spelling first.
try:
    from queue import Queue
except ImportError:
    from Queue import Queue

# On Python 2, rebind range() to the lazy xrange(). On Python 3 xrange()
# does not exist and the builtin range() is already lazy, so nothing to do.
try:
    xrange
except NameError:
    pass
else:
    range = xrange
29
30
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # Address/port dnsdist is told to listen on (-l); tests send queries here.
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    # Port the UDP and TCP responder threads bind to on 127.0.0.1.
    _testServerPort = 5350
    # Tests put the responses the responders should send on this queue...
    _toResponderQueue = Queue()
    # ...and the responders put the queries they received on this one.
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Non-healthcheck queries answered, keyed by responder thread name.
    _responsesCounter = {}
    # dnsdist configuration, %-formatted with the attributes named in _config_params.
    _config_template = """
"""
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    _consoleKey = None
    # Queries whose name ends with this suffix are counted as health checks.
    _healthCheckName = 'a.root-servers.net.'
    # Health-check queries seen by the responders; reset before every test.
    _healthCheckCounter = 0
    # When no response was queued, reply SERVFAIL instead of not answering.
    _answerUnexpected = True

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder threads (daemonized) on _testServerPort."""
        print("Launching responders..")

        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        # The daemon attribute replaces the deprecated setDaemon() and works on 2.6+ and 3.x.
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls):
        """Write the configuration, validate it, then launch dnsdist.

        The configuration written to configs/dnsdist_<ClassName>.conf is
        cls._config_template %-formatted with the values of the attributes
        listed in cls._config_params. dnsdist's stdout/stderr go to
        configs/dnsdist_<ClassName>.log.

        Raises AssertionError if `--check-config` fails or does not print the
        expected "OK" line. Calls sys.exit() with dnsdist's return code if the
        process has already exited after the startup delay.
        """
        print("Launching dnsdist..")
        confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
        params = tuple(getattr(cls, param) for param in cls._config_params)
        print(params)
        with open(confFile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        try:
            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
        with open(logFile, 'w') as fdLog:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        # dnsdist already exited during the startup delay: it failed to start.
        if cls._dnsdist.poll() is not None:
            cls._dnsdist.kill()
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP socket used by sendUDPQuery(), connected to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders, then dnsdist, then the client UDP socket."""
        cls.startResponders()
        cls.startDNSDist()
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop dnsdist: terminate() first, kill() if it is still running after a delay."""
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
                cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one answered (non-healthcheck) query for the calling thread."""
        # current_thread() replaces the deprecated currentThread() alias.
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
        """Build the response a responder should send for `request`.

        Queries with a question count other than one are skipped (None).
        Health-check queries (name ending with _healthCheckName) are counted
        and answered with a response built from the request. For any other
        query, if a test queued a response on `fromQueue`, the request is
        pushed to `toQueue` and, unless `synthesize` is set, the queued
        response is returned with its ID rewritten to match the request.

        Fallback when no response was obtained: a response with RCODE
        `synthesize` if given, else SERVFAIL when _answerUnexpected is set,
        else None (do not answer).
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
            response = dns.message.make_response(request)
        else:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                toQueue.put(request, True, cls._queueTimeout)
                if synthesize is None:
                    response = fromQueue.get(True, cls._queueTimeout)
                    if response:
                        # Copy so the queued object is untouched when the ID is rewritten.
                        response = copy.copy(response)
                        response.id = request.id

        if not response:
            if synthesize is not None:
                response = dns.message.make_response(request)
                response.set_rcode(synthesize)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @staticmethod
    def _recvExactly(sock, length):
        """Read `length` bytes from `sock`, looping over short reads.

        A single recv() may return fewer bytes than requested on a stream
        socket. This keeps reading until `length` bytes have arrived or the
        peer closes the connection, in which case the bytes read so far
        (possibly b'') are returned.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
        """Answer DNS queries over UDP on 127.0.0.1:`port` until an error occurs.

        fromQueue holds the responses queued by the tests; received queries
        are put on toQueue (see _getResponse).
        """
        # trailingDataResponse=True means "ignore trailing data".
        # Other values are either False (meaning "raise an exception")
        # or are interpreted as a response RCODE for queries with trailing data.
        # callback is invoked for every -even healthcheck ones- query and should return a raw response
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        try:
            while True:
                data, addr = sock.recvfrom(4096)
                forceRcode = None
                try:
                    request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
                except dns.message.TrailingJunk:
                    if trailingDataResponse is False:
                        raise
                    print("UDP query with trailing data, synthesizing response")
                    request = dns.message.from_wire(data, ignore_trailing=True)
                    forceRcode = trailingDataResponse

                wire = None
                if callback:
                    wire = callback(request)
                else:
                    response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                    if response:
                        wire = response.to_wire()

                if not wire:
                    continue

                sock.settimeout(2.0)
                sock.sendto(wire, addr)
                sock.settimeout(None)
        finally:
            # The loop only ends on an exception; release the port either way.
            sock.close()

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
        """Answer DNS queries over TCP on 127.0.0.1:`port`, one query per connection.

        fromQueue holds the responses queued by the tests; received queries
        are put on toQueue (see _getResponse). With multipleResponses, any
        further responses already queued are also sent on the same connection.
        """
        # trailingDataResponse=True means "ignore trailing data".
        # Other values are either False (meaning "raise an exception")
        # or are interpreted as a response RCODE for queries with trailing data.
        # callback is invoked for every -even healthcheck ones- query and should return a raw response
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sock.close()
            sys.exit(1)

        sock.listen(100)
        try:
            while True:
                (conn, _) = sock.accept()
                conn.settimeout(5.0)
                data = cls._recvExactly(conn, 2)
                if len(data) != 2:
                    # EOF (or truncated length prefix): nothing to answer.
                    conn.close()
                    continue

                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(conn, datalen)
                forceRcode = None
                try:
                    request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
                except dns.message.TrailingJunk:
                    if trailingDataResponse is False:
                        raise
                    print("TCP query with trailing data, synthesizing response")
                    request = dns.message.from_wire(data, ignore_trailing=True)
                    forceRcode = trailingDataResponse

                # Reset for every connection: otherwise an unanswered query
                # would hit an unbound name or re-send the previous answer.
                wire = None
                if callback:
                    wire = callback(request)
                else:
                    response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                    if response:
                        wire = response.to_wire(max_size=65535)

                if not wire:
                    conn.close()
                    continue

                conn.sendall(struct.pack("!H", len(wire)))
                conn.sendall(wire)

                while multipleResponses:
                    if fromQueue.empty():
                        break

                    response = fromQueue.get(True, cls._queueTimeout)
                    if not response:
                        break

                    response = copy.copy(response)
                    response.id = request.id
                    wire = response.to_wire(max_size=65535)
                    try:
                        conn.sendall(struct.pack("!H", len(wire)))
                        conn.sendall(wire)
                    except socket.error:
                        # some of the tests are going to close
                        # the connection on us, just deal with it
                        break

                conn.close()
        finally:
            sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send `query` to dnsdist over UDP and return (receivedQuery, message).

        With useQueue, `response` is queued for the responder first, and
        receivedQuery is the query the responder saw (None if none was
        queued back). message is None on timeout.
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist (with `timeout` if set)."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """Return a TLS socket connected to 127.0.0.1:`port`, verified against `caCert`."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        # 2.7.9+
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        sslsock.connect(("127.0.0.1", port))
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """Send a length-prefixed DNS query on an existing stream connection.

        If `response` is given it is queued for the responder first.
        """
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        # sendall(): send() may transmit only part of the buffer.
        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """Read one length-prefixed DNS response from `sock`.

        Returns the parsed message (None if the connection was closed). With
        useQueue, returns (receivedQuery, message) instead, where receivedQuery
        is the query seen by the responder or None if none was queued.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if data:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if not useQueue:
            return message

        # Always return a tuple when useQueue is set so callers can unpack it.
        receivedQuery = None
        if not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, message)

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send `query` to dnsdist over a new TCP connection; return (receivedQuery, message).

        Timeouts and network errors are printed, not raised; message is then None.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """Send `query` over TCP and collect every response until the connection closes.

        Returns (receivedQuery, messages). Timeouts and network errors are
        printed and end the collection.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)
        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            if not rawQuery:
                wire = query.to_wire()
            else:
                wire = query

            sock.sendall(struct.pack("!H", len(wire)))
            sock.sendall(wire)
            while True:
                data = cls._recvExactly(sock, 2)
                if not data:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        """Reset the per-test counters and empty both responder queues."""
        # Clear the responses counters
        for key in self._responsesCounter:
            self._responsesCounter[key] = 0

        # Reset on the class: the responders increment cls._healthCheckCounter,
        # and assigning through self would only create a shadowing instance
        # attribute, leaving the real counter untouched.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @classmethod
    def clearToResponderQueue(cls):
        """Drop every response still queued for the responders."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Drop every query the responders queued back."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Empty both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new random key for the console (libnacl salsa key)."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode `command`; encrypt it with secretbox if a console key is set."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Decrypt `command` with secretbox if a console key is set; return it UTF-8 decoded."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """Run `command` on the dnsdist console and return its decoded output.

        Returns None if the server's nonce has an unexpected size. Raises
        socket.error if the server closes the connection before sending its
        nonce or the response size. The socket is always closed.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        try:
            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            # Each direction's nonce is half ours and half theirs.
            halfNonceSize = len(ourNonce) // 2
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if not data:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert that the two EDNS option lists are equal, element by element."""
        self.assertEqual(len(a), len(b))
        for optionA, optionB in zip(a, b):
            self.assertEqual(optionA, optionB)

    def checkMessageNoEDNS(self, expected, received):
        """Assert the messages are equal and `received` carries no EDNS."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        """Assert the messages are equal, `received` uses EDNS0 and payloads match."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """Assert EDNS0 without ECS; `withCookies` options, all with code 10 (COOKIE), are allowed."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        """Assert EDNS0 with an ECS option plus `additionalOptions` other options."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)