]> git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #7062 from rgacogne/dnsdist-tls-stored-sessions
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
CommitLineData
ca404e94
RG
1#!/usr/bin/env python2
2
95f0b802 3import copy
ca404e94
RG
4import os
5import socket
a227f47d 6import ssl
ca404e94
RG
7import struct
8import subprocess
9import sys
10import threading
11import time
12import unittest
5df86a8a 13import clientsubnetoption
b1bec9f0
RG
14import dns
15import dns.message
1ea747c0
RG
16import libnacl
17import libnacl.utils
ca404e94 18
b4f23783
CH
# Python 2/3 compatibility shims: export `Queue` and `range` under the same
# names on both interpreter generations (xrange is used on Python 2).
if sys.version_info[0] != 2:
    from queue import Queue
    range = range  # re-bind the builtin so it can be re-exported
else:
    from Queue import Queue
    range = xrange  # noqa: F821 -- only exists on Python 2
26
27
ca404e94
RG
class DNSDistTest(unittest.TestCase):
    """
    Base class for the dnsdist regression tests.

    It runs a dnsdist instance in front of UDP and TCP responder threads.
    A test puts the response it wants the responders to send on one queue;
    the responders put every query they receive from dnsdist on a second
    queue, so the test can check what dnsdist actually forwarded.
    """

    # dnsdist process, where it listens and how it is started
    _dnsdist = None
    _dnsDistListeningAddr = "127.0.0.1"
    _dnsDistPort = 5340
    _dnsdistStartupDelay = 2.0
    _shutUp = True

    # backend side: responder port, shared queues and counters
    _testServerPort = 5350
    _toResponderQueue = Queue()
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _responsesCounter = {}
    _answerUnexpected = True

    # health-check queries are recognised by this name suffix and counted apart
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0

    # dnsdist configuration: template, the class attributes substituted into it, ACLs
    _config_template = """
    """
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']

    # console access
    _consolePort = 5199
    _consoleKey = None
ca404e94
RG
56
57 @classmethod
58 def startResponders(cls):
59 print("Launching responders..")
ec5f5c6b 60
5df86a8a 61 cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
ca404e94
RG
62 cls._UDPResponder.setDaemon(True)
63 cls._UDPResponder.start()
5df86a8a 64 cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
ca404e94
RG
65 cls._TCPResponder.setDaemon(True)
66 cls._TCPResponder.start()
67
68 @classmethod
69 def startDNSDist(cls, shutUp=True):
70 print("Launching dnsdist..")
18a0e7c6
CH
71 conffile = 'dnsdist_test.conf'
72 params = tuple([getattr(cls, param) for param in cls._config_params])
73 print(params)
74 with open(conffile, 'w') as conf:
75 conf.write("-- Autogenerated by dnsdisttests.py\n")
76 conf.write(cls._config_template % params)
77
78 dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
b052847c 79 '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
18a0e7c6
CH
80 for acl in cls._acl:
81 dnsdistcmd.extend(['--acl', acl])
82 print(' '.join(dnsdistcmd))
83
6b44773a
CH
84 # validate config with --check-config, which sets client=true, possibly exposing bugs.
85 testcmd = dnsdistcmd + ['--check-config']
86 output = subprocess.check_output(testcmd, close_fds=True)
87 if output != b'Configuration \'dnsdist_test.conf\' OK!\n':
88 raise AssertionError('dnsdist --check-config failed: %s' % output)
89
ca404e94
RG
90 if shutUp:
91 with open(os.devnull, 'w') as fdDevNull:
bd64cc44 92 cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
ca404e94 93 else:
18a0e7c6 94 cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)
ca404e94 95
0a2087eb
RG
96 if 'DNSDIST_FAST_TESTS' in os.environ:
97 delay = 0.5
98 else:
617dfe22
RG
99 delay = cls._dnsdistStartupDelay
100
0a2087eb 101 time.sleep(delay)
ca404e94
RG
102
103 if cls._dnsdist.poll() is not None:
0a2087eb 104 cls._dnsdist.kill()
ca404e94
RG
105 sys.exit(cls._dnsdist.returncode)
106
107 @classmethod
108 def setUpSockets(cls):
109 print("Setting up UDP socket..")
110 cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1ade83b2 111 cls._sock.settimeout(2.0)
ca404e94
RG
112 cls._sock.connect(("127.0.0.1", cls._dnsDistPort))
113
114 @classmethod
115 def setUpClass(cls):
116
117 cls.startResponders()
b1bec9f0 118 cls.startDNSDist(cls._shutUp)
ca404e94
RG
119 cls.setUpSockets()
120
121 print("Launching tests..")
122
123 @classmethod
124 def tearDownClass(cls):
0a2087eb
RG
125 if 'DNSDIST_FAST_TESTS' in os.environ:
126 delay = 0.1
127 else:
b1bec9f0 128 delay = 1.0
ca404e94
RG
129 if cls._dnsdist:
130 cls._dnsdist.terminate()
0a2087eb
RG
131 if cls._dnsdist.poll() is None:
132 time.sleep(delay)
133 if cls._dnsdist.poll() is None:
134 cls._dnsdist.kill()
1ade83b2 135 cls._dnsdist.wait()
ca404e94
RG
136
137 @classmethod
fe1c60f2 138 def _ResponderIncrementCounter(cls):
ec5f5c6b
RG
139 if threading.currentThread().name in cls._responsesCounter:
140 cls._responsesCounter[threading.currentThread().name] += 1
141 else:
142 cls._responsesCounter[threading.currentThread().name] = 1
143
fe1c60f2 144 @classmethod
5df86a8a 145 def _getResponse(cls, request, fromQueue, toQueue):
fe1c60f2
RG
146 response = None
147 if len(request.question) != 1:
148 print("Skipping query with question count %d" % (len(request.question)))
149 return None
98650fde
RG
150 healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
151 if healthCheck:
152 cls._healthCheckCounter += 1
153 else:
fe1c60f2 154 cls._ResponderIncrementCounter()
5df86a8a
RG
155 if not fromQueue.empty():
156 response = fromQueue.get(True, cls._queueTimeout)
fe1c60f2
RG
157 if response:
158 response = copy.copy(response)
159 response.id = request.id
5df86a8a 160 toQueue.put(request, True, cls._queueTimeout)
fe1c60f2 161
e44df0f1
RG
162 if not response:
163 if healthCheck:
164 response = dns.message.make_response(request)
165 elif cls._answerUnexpected:
166 response = dns.message.make_response(request)
167 response.set_rcode(dns.rcode.SERVFAIL)
fe1c60f2
RG
168
169 return response
170
ec5f5c6b 171 @classmethod
5df86a8a 172 def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
ca404e94
RG
173 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
174 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
ec5f5c6b 175 sock.bind(("127.0.0.1", port))
ca404e94
RG
176 while True:
177 data, addr = sock.recvfrom(4096)
55baa1f2 178 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
5df86a8a 179 response = cls._getResponse(request, fromQueue, toQueue)
55baa1f2 180
fe1c60f2 181 if not response:
ca404e94 182 continue
87c605c4 183
1ade83b2 184 sock.settimeout(2.0)
ca404e94 185 sock.sendto(response.to_wire(), addr)
1ade83b2 186 sock.settimeout(None)
ca404e94
RG
187 sock.close()
188
189 @classmethod
5df86a8a 190 def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
ca404e94
RG
191 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
192 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
193 try:
ec5f5c6b 194 sock.bind(("127.0.0.1", port))
ca404e94
RG
195 except socket.error as e:
196 print("Error binding in the TCP responder: %s" % str(e))
197 sys.exit(1)
198
199 sock.listen(100)
200 while True:
b1bec9f0 201 (conn, _) = sock.accept()
1ade83b2 202 conn.settimeout(2.0)
ca404e94 203 data = conn.recv(2)
98650fde
RG
204 if not data:
205 conn.close()
206 continue
207
ca404e94
RG
208 (datalen,) = struct.unpack("!H", data)
209 data = conn.recv(datalen)
55baa1f2 210 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
5df86a8a 211 response = cls._getResponse(request, fromQueue, toQueue)
55baa1f2 212
fe1c60f2 213 if not response:
548c8b66 214 conn.close()
ca404e94 215 continue
ca404e94
RG
216
217 wire = response.to_wire()
218 conn.send(struct.pack("!H", len(wire)))
219 conn.send(wire)
548c8b66
RG
220
221 while multipleResponses:
5df86a8a 222 if fromQueue.empty():
548c8b66
RG
223 break
224
5df86a8a 225 response = fromQueue.get(True, cls._queueTimeout)
548c8b66
RG
226 if not response:
227 break
228
229 response = copy.copy(response)
230 response.id = request.id
231 wire = response.to_wire()
284d460c
RG
232 try:
233 conn.send(struct.pack("!H", len(wire)))
234 conn.send(wire)
235 except socket.error as e:
236 # some of the tests are going to close
237 # the connection on us, just deal with it
238 break
548c8b66 239
ca404e94 240 conn.close()
548c8b66 241
ca404e94
RG
242 sock.close()
243
244 @classmethod
55baa1f2 245 def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
ca404e94 246 if useQueue:
617dfe22 247 cls._toResponderQueue.put(response, True, timeout)
ca404e94
RG
248
249 if timeout:
250 cls._sock.settimeout(timeout)
251
252 try:
55baa1f2
RG
253 if not rawQuery:
254 query = query.to_wire()
255 cls._sock.send(query)
ca404e94 256 data = cls._sock.recv(4096)
b1bec9f0 257 except socket.timeout:
ca404e94
RG
258 data = None
259 finally:
260 if timeout:
261 cls._sock.settimeout(None)
262
263 receivedQuery = None
264 message = None
265 if useQueue and not cls._fromResponderQueue.empty():
617dfe22 266 receivedQuery = cls._fromResponderQueue.get(True, timeout)
ca404e94
RG
267 if data:
268 message = dns.message.from_wire(data)
269 return (receivedQuery, message)
270
271 @classmethod
9396d955 272 def openTCPConnection(cls, timeout=None):
ca404e94 273 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ca404e94
RG
274 if timeout:
275 sock.settimeout(timeout)
276
0a2087eb 277 sock.connect(("127.0.0.1", cls._dnsDistPort))
9396d955 278 return sock
0a2087eb 279
9396d955 280 @classmethod
a227f47d
RG
281 def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
282 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
283 if timeout:
284 sock.settimeout(timeout)
285
286 # 2.7.9+
287 if hasattr(ssl, 'create_default_context'):
288 sslctx = ssl.create_default_context(cafile=caCert)
289 sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
290 else:
291 sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
292
293 sslsock.connect(("127.0.0.1", port))
294 return sslsock
295
296 @classmethod
297 def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
9396d955
RG
298 if not rawQuery:
299 wire = query.to_wire()
300 else:
301 wire = query
55baa1f2 302
a227f47d
RG
303 if response:
304 cls._toResponderQueue.put(response, True, timeout)
305
9396d955
RG
306 sock.send(struct.pack("!H", len(wire)))
307 sock.send(wire)
308
309 @classmethod
a227f47d 310 def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
9396d955
RG
311 message = None
312 data = sock.recv(2)
313 if data:
314 (datalen,) = struct.unpack("!H", data)
315 data = sock.recv(datalen)
ca404e94 316 if data:
9396d955 317 message = dns.message.from_wire(data)
a227f47d
RG
318
319 if useQueue and not cls._fromResponderQueue.empty():
320 receivedQuery = cls._fromResponderQueue.get(True, timeout)
321 return (receivedQuery, message)
322 else:
323 return message
9396d955
RG
324
325 @classmethod
326 def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
327 message = None
328 if useQueue:
329 cls._toResponderQueue.put(response, True, timeout)
330
331 sock = cls.openTCPConnection(timeout)
332
333 try:
334 cls.sendTCPQueryOverConnection(sock, query, rawQuery)
335 message = cls.recvTCPResponseOverConnection(sock)
ca404e94
RG
336 except socket.timeout as e:
337 print("Timeout: %s" % (str(e)))
ca404e94
RG
338 except socket.error as e:
339 print("Network error: %s" % (str(e)))
ca404e94
RG
340 finally:
341 sock.close()
342
343 receivedQuery = None
ca404e94 344 if useQueue and not cls._fromResponderQueue.empty():
617dfe22 345 receivedQuery = cls._fromResponderQueue.get(True, timeout)
9396d955 346
ca404e94 347 return (receivedQuery, message)
617dfe22 348
548c8b66
RG
349 @classmethod
350 def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
351 if useQueue:
352 for response in responses:
353 cls._toResponderQueue.put(response, True, timeout)
354 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
355 if timeout:
356 sock.settimeout(timeout)
357
358 sock.connect(("127.0.0.1", cls._dnsDistPort))
359 messages = []
360
361 try:
362 if not rawQuery:
363 wire = query.to_wire()
364 else:
365 wire = query
366
367 sock.send(struct.pack("!H", len(wire)))
368 sock.send(wire)
369 while True:
370 data = sock.recv(2)
371 if not data:
372 break
373 (datalen,) = struct.unpack("!H", data)
374 data = sock.recv(datalen)
375 messages.append(dns.message.from_wire(data))
376
377 except socket.timeout as e:
378 print("Timeout: %s" % (str(e)))
379 except socket.error as e:
380 print("Network error: %s" % (str(e)))
381 finally:
382 sock.close()
383
384 receivedQuery = None
385 if useQueue and not cls._fromResponderQueue.empty():
386 receivedQuery = cls._fromResponderQueue.get(True, timeout)
387 return (receivedQuery, messages)
388
617dfe22
RG
389 def setUp(self):
390 # This function is called before every tests
391
392 # Clear the responses counters
393 for key in self._responsesCounter:
394 self._responsesCounter[key] = 0
395
98650fde
RG
396 self._healthCheckCounter = 0
397
617dfe22
RG
398 # Make sure the queues are empty, in case
399 # a previous test failed
400 while not self._toResponderQueue.empty():
401 self._toResponderQueue.get(False)
402
403 while not self._fromResponderQueue.empty():
fe1c60f2 404 self._fromResponderQueue.get(False)
1ea747c0 405
3bef39c3
RG
406 @classmethod
407 def clearToResponderQueue(cls):
408 while not cls._toResponderQueue.empty():
409 cls._toResponderQueue.get(False)
410
411 @classmethod
412 def clearFromResponderQueue(cls):
413 while not cls._fromResponderQueue.empty():
414 cls._fromResponderQueue.get(False)
415
416 @classmethod
417 def clearResponderQueues(cls):
418 cls.clearToResponderQueue()
419 cls.clearFromResponderQueue()
420
1ea747c0
RG
421 @staticmethod
422 def generateConsoleKey():
423 return libnacl.utils.salsa_key()
424
425 @classmethod
426 def _encryptConsole(cls, command, nonce):
b4f23783 427 command = command.encode('UTF-8')
1ea747c0
RG
428 if cls._consoleKey is None:
429 return command
430 return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)
431
432 @classmethod
433 def _decryptConsole(cls, command, nonce):
434 if cls._consoleKey is None:
b4f23783
CH
435 result = command
436 else:
437 result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
438 return result.decode('UTF-8')
1ea747c0
RG
439
440 @classmethod
441 def sendConsoleCommand(cls, command, timeout=1.0):
442 ourNonce = libnacl.utils.rand_nonce()
443 theirNonce = None
444 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
445 if timeout:
446 sock.settimeout(timeout)
447
448 sock.connect(("127.0.0.1", cls._consolePort))
449 sock.send(ourNonce)
450 theirNonce = sock.recv(len(ourNonce))
7b925432 451 if len(theirNonce) != len(ourNonce):
05a5b575 452 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
bdfa6902
RG
453 if len(theirNonce) == 0:
454 raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
7b925432 455 return None
1ea747c0 456
b4f23783 457 halfNonceSize = int(len(ourNonce) / 2)
333ea16e
RG
458 readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
459 writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
333ea16e 460 msg = cls._encryptConsole(command, writingNonce)
1ea747c0
RG
461 sock.send(struct.pack("!I", len(msg)))
462 sock.send(msg)
463 data = sock.recv(4)
9c9b4998
RG
464 if not data:
465 raise socket.error("Got EOF while reading the response size")
466
1ea747c0
RG
467 (responseLen,) = struct.unpack("!I", data)
468 data = sock.recv(responseLen)
333ea16e 469 response = cls._decryptConsole(data, readingNonce)
1ea747c0 470 return response
5df86a8a
RG
471
472 def compareOptions(self, a, b):
473 self.assertEquals(len(a), len(b))
b4f23783 474 for idx in range(len(a)):
5df86a8a
RG
475 self.assertEquals(a[idx], b[idx])
476
477 def checkMessageNoEDNS(self, expected, received):
478 self.assertEquals(expected, received)
479 self.assertEquals(received.edns, -1)
480 self.assertEquals(len(received.options), 0)
481
e7c732b8
RG
482 def checkMessageEDNSWithoutOptions(self, expected, received):
483 self.assertEquals(expected, received)
484 self.assertEquals(received.edns, 0)
485
5df86a8a
RG
486 def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
487 self.assertEquals(expected, received)
488 self.assertEquals(received.edns, 0)
489 self.assertEquals(len(received.options), withCookies)
490 if withCookies:
491 for option in received.options:
492 self.assertEquals(option.otype, 10)
493
cbf4e13a 494 def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
5df86a8a
RG
495 self.assertEquals(expected, received)
496 self.assertEquals(received.edns, 0)
cbf4e13a
RG
497 self.assertEquals(len(received.options), 1 + additionalOptions)
498 hasECS = False
499 for option in received.options:
500 if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
501 hasECS = True
502 else:
503 self.assertNotEquals(additionalOptions, 0)
504
5df86a8a 505 self.compareOptions(expected.options, received.options)
cbf4e13a 506 self.assertTrue(hasECS)
5df86a8a 507
cbf4e13a
RG
508 def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
509 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
5df86a8a 510
cbf4e13a
RG
511 def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
512 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
5df86a8a
RG
513
514 def checkQueryEDNSWithoutECS(self, expected, received):
515 self.checkMessageEDNSWithoutECS(expected, received)
516
517 def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
518 self.checkMessageEDNSWithoutECS(expected, received, withCookies)
519
520 def checkQueryNoEDNS(self, expected, received):
521 self.checkMessageNoEDNS(expected, received)
522
523 def checkResponseNoEDNS(self, expected, received):
524 self.checkMessageNoEDNS(expected, received)