]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Under test, run dnsdist with --supervised so that EOF on stdin does not make it exit
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import os
5 import socket
6 import ssl
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import clientsubnetoption
14 import dns
15 import dns.message
16 import libnacl
17 import libnacl.utils
18
# Python 2/3 compatibility shims.
# The queue module is called 'queue' on Python 3 and 'Queue' on Python 2;
# only the Queue class is exported from here.
try:
    import queue as _queueModule
except ImportError:
    import Queue as _queueModule
Queue = _queueModule.Queue
del _queueModule

# On Python 2, rebind 'range' to the lazy xrange; on Python 3 the builtin
# range is already lazy and xrange does not exist.
try:
    xrange
except NameError:
    pass
else:
    range = xrange
29
30
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.

    Subclasses customize the instance through the class attributes below,
    notably _config_template, which is %-formatted with the values of the
    attributes named in _config_params.
    """
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    _testServerPort = 5350
    # Responses queued by the tests, consumed by the responder threads.
    _toResponderQueue = Queue()
    # Queries received by the responder threads, consumed by the tests.
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Number of non health-check queries seen, keyed by responder thread name.
    _responsesCounter = {}
    _config_template = """
"""
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    # Console key; when None, console commands are sent and read unencrypted.
    _consoleKey = None
    _healthCheckName = 'a.root-servers.net.'
    # Number of health-check queries answered by the responders.
    _healthCheckCounter = 0
    # When no response has been queued, answer with SERVFAIL instead of
    # not answering at all.
    _answerUnexpected = True

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder threads on _testServerPort."""
        print("Launching responders..")

        # Daemon threads: the responders loop forever and must not keep the
        # interpreter alive once the tests are done.
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls):
        """
        Write the configuration file, validate it with --check-config and
        launch dnsdist, logging to configs/dnsdist_<class name>.log.

        Raises AssertionError if the configuration check fails, and exits
        with a non-zero status if dnsdist is no longer running once the
        startup delay has elapsed.
        """
        print("Launching dnsdist..")
        confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
        params = tuple(getattr(cls, param) for param in cls._config_params)
        print(params)
        with open(confFile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        # --supervised keeps dnsdist from exiting when it reads EOF on stdin.
        dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        try:
            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
        with open(logFile, 'w') as fdLog:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        if cls._dnsdist.poll() is not None:
            # poll() has already reaped the process, so there is nothing left
            # to kill. Never exit with status 0 here: dnsdist dying during
            # startup is a failure even if its own exit code was 0.
            print("dnsdist exited prematurely (%d), see %s" % (cls._dnsdist.returncode, logFile))
            sys.exit(cls._dnsdist.returncode or 1)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP client socket connected to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):

        cls.startResponders()
        cls.startDNSDist()
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Close the UDP client socket and stop dnsdist (terminate, then kill)."""
        sock = getattr(cls, '_sock', None)
        if sock is not None:
            sock.close()

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist is not None:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                # give dnsdist a chance to exit gracefully before killing it
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
            cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one query for the calling responder thread."""
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
        """
        Build the response a responder thread should send for request.

        Queries whose question count is not 1 are skipped (None). Health-check
        queries for _healthCheckName get an empty response. For any other
        query, if the tests queued a response on fromQueue, the request is
        pushed onto toQueue so the tests can inspect it; unless synthesize is
        set, the queued response is then returned with its ID rewritten to
        match the request (the queued response is left untouched when
        synthesize is set). If no response was obtained, a response with the
        synthesize RCODE, or SERVFAIL when _answerUnexpected is set, is built.
        Returns None when nothing should be sent.
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
            response = dns.message.make_response(request)
        else:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                toQueue.put(request, True, cls._queueTimeout)
                if synthesize is None:
                    response = fromQueue.get(True, cls._queueTimeout)
                    if response:
                        # copy so the object owned by the test keeps its ID
                        response = copy.copy(response)
                        response.id = request.id

        if not response:
            if synthesize is not None:
                response = dns.message.make_response(request)
                response.set_rcode(synthesize)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
        """Answer DNS queries over UDP on 127.0.0.1:port forever (thread target)."""
        # trailingDataResponse=True means "ignore trailing data".
        # Other values are either False (meaning "raise an exception")
        # or are interpreted as a response RCODE for queries with trailing data.
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        while True:
            data, addr = sock.recvfrom(4096)
            forceRcode = None
            try:
                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            except dns.message.TrailingJunk:
                if trailingDataResponse is False:
                    raise
                print("UDP query with trailing data, synthesizing response")
                request = dns.message.from_wire(data, ignore_trailing=True)
                forceRcode = trailingDataResponse

            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
            if not response:
                continue

            sock.settimeout(2.0)
            sock.sendto(response.to_wire(), addr)
            sock.settimeout(None)
        sock.close()

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
        """
        Answer length-prefixed DNS queries over TCP on 127.0.0.1:port forever
        (thread target). One query is read per connection; with
        multipleResponses, every further response queued on fromQueue is sent
        on the same connection too.
        """
        # trailingDataResponse=True means "ignore trailing data".
        # Other values are either False (meaning "raise an exception")
        # or are interpreted as a response RCODE for queries with trailing data.
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            conn.settimeout(2.0)
            # recv() on a stream socket may return fewer bytes than requested,
            # so read the length prefix and the payload in full.
            data = cls._recvExactly(conn, 2)
            if len(data) != 2:
                conn.close()
                continue

            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(conn, datalen)
            forceRcode = None
            try:
                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            except dns.message.TrailingJunk:
                if trailingDataResponse is False:
                    raise
                print("TCP query with trailing data, synthesizing response")
                request = dns.message.from_wire(data, ignore_trailing=True)
                forceRcode = trailingDataResponse

            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
            if not response:
                conn.close()
                continue

            wire = response.to_wire()
            try:
                conn.sendall(struct.pack("!H", len(wire)))
                conn.sendall(wire)
            except socket.error:
                # the other end may already have closed the connection;
                # don't let that kill the responder thread
                conn.close()
                continue

            while multipleResponses:
                if fromQueue.empty():
                    break

                response = fromQueue.get(True, cls._queueTimeout)
                if not response:
                    break

                response = copy.copy(response)
                response.id = request.id
                wire = response.to_wire()
                try:
                    conn.sendall(struct.pack("!H", len(wire)))
                    conn.sendall(wire)
                except socket.error:
                    # some of the tests are going to close
                    # the connection on us, just deal with it
                    break

            conn.close()

        sock.close()

    @staticmethod
    def _recvExactly(sock, length):
        """
        Read exactly length bytes from the stream socket sock.

        A single recv() may return fewer bytes than requested, so keep reading
        until length bytes have arrived or the peer closes the connection. On
        EOF, the (possibly short, possibly empty) data read so far is returned.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send query to dnsdist over UDP, after queuing response for the
        responders when useQueue is set.

        Returns (receivedQuery, message): the query seen by a responder (or
        None) and dnsdist's parsed answer (None on timeout).
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist's listening port."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        try:
            sock.connect(("127.0.0.1", cls._dnsDistPort))
        except Exception:
            # don't leak the socket when the connection fails
            sock.close()
            raise
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """Return a TLS socket connected to 127.0.0.1:port, verified against caCert."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        # 2.7.9+
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        try:
            sslsock.connect(("127.0.0.1", port))
        except Exception:
            # don't leak the socket when the connection or handshake fails
            sslsock.close()
            raise
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """
        Send query, prefixed with its 2-byte length, over an existing
        connection. If response is given, queue it for the responders first.
        """
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        # two separate writes, as before, but never partial ones
        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """
        Read one length-prefixed DNS message from sock.

        Returns the parsed message (None on EOF or empty payload). When
        useQueue is set, returns a (receivedQuery, message) tuple instead,
        receivedQuery being the next query seen by the responders or None.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if len(data) == 2:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if useQueue:
            # always return a tuple here, even if the responders saw nothing
            receivedQuery = None
            if not cls._fromResponderQueue.empty():
                receivedQuery = cls._fromResponderQueue.get(True, timeout)
            return (receivedQuery, message)
        return message

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send query to dnsdist over a new TCP connection and read one answer.

        Returns (receivedQuery, message); network errors and timeouts are
        printed and yield a None message.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery=rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send query to dnsdist over a new TCP connection and read answers
        until the connection is closed or times out.

        Returns (receivedQuery, messages), messages being the list of parsed
        answers received.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery=rawQuery)
            while True:
                data = cls._recvExactly(sock, 2)
                if len(data) != 2:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        # This function is called before every tests

        # Clear the responses counters
        for key in self._responsesCounter:
            self._responsesCounter[key] = 0

        # Reset the counter on the class, where the responder threads
        # increment it: assigning through self would only create an
        # instance attribute shadowing it.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @classmethod
    def clearToResponderQueue(cls):
        """Drop every response still waiting for the responders."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Drop every query recorded by the responders."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Empty both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new random key for the dnsdist console."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode command and encrypt it with _consoleKey, if one is set."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Decrypt command with _consoleKey, if one is set, and UTF-8 decode it."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """
        Send command to the dnsdist console and return its decoded answer.

        Returns None if the server's nonce has an unexpected size; raises
        socket.error on EOF while reading the nonce or the response size.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        try:
            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            # each direction uses a nonce built from one half of each side's nonce
            halfNonceSize = len(ourNonce) // 2
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if len(data) != 4:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert that the two EDNS option lists are equal, element by element."""
        self.assertEqual(len(a), len(b))
        for optionA, optionB in zip(a, b):
            self.assertEqual(optionA, optionB)

    def checkMessageNoEDNS(self, expected, received):
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            # 10 is the EDNS COOKIE option code
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)
560
561 def checkResponseNoEDNS(self, expected, received):
562 self.checkMessageNoEDNS(expected, received)