git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #6650 from pieterlexis/doc-nits
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import os
5 import socket
6 import ssl
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import clientsubnetoption
14 import dns
15 import dns.message
16 import libnacl
17 import libnacl.utils
18
# Python 2/3 compatibility shims: the ``Queue`` module was renamed to
# ``queue`` in Python 3, and Python 2's lazy ``xrange()`` became Python 3's
# ``range()``. Both branches bind the same two names.
if sys.version_info >= (3,):
    from queue import Queue
    range = range  # re-bind the builtin so the name is defined in both branches
else:
    from Queue import Queue
    range = xrange  # noqa: F821 -- only exists on Python 2
26
27
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # Port and address dnsdist listens on (passed via '-l').
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    # Port the UDP/TCP responder threads bind to on 127.0.0.1.
    _testServerPort = 5350
    # Responses queued by tests, consumed by the responders.
    _toResponderQueue = Queue()
    # Queries received by the responders, consumed by tests.
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Responses counted per responder thread name (shared dict, reset in setUp).
    _responsesCounter = {}
    _shutUp = True
    # %-format template for the dnsdist configuration; filled from _config_params.
    _config_template = """
    """
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    _consoleKey = None
    # Queries whose name ends with this are treated as health checks.
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    # When true, queries with no queued response are answered with SERVFAIL.
    _answerUnexpected = True

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder threads on ``_testServerPort``."""
        print("Launching responders..")

        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        # The 'daemon' attribute replaces setDaemon(), deprecated since Python 3.10.
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls, shutUp=True):
        """
        Write the configuration file, validate it and launch dnsdist.

        The configuration is ``_config_template`` formatted with the values of
        the class attributes named in ``_config_params``. When ``shutUp`` is
        true, dnsdist's stdout goes to /dev/null. Raises AssertionError if
        ``--check-config`` does not report success; exits the test process
        with dnsdist's return code if it died during the startup delay.
        """
        print("Launching dnsdist..")
        conffile = 'dnsdist_test.conf'
        params = tuple([getattr(cls, param) for param in cls._config_params])
        print(params)
        with open(conffile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        output = subprocess.check_output(testcmd, close_fds=True)
        expectedOutput = ("Configuration '%s' OK!\n" % conffile).encode('UTF-8')
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        if shutUp:
            with open(os.devnull, 'w') as fdDevNull:
                cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
        else:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        # poll() returning non-None means dnsdist already exited and was reaped;
        # there is nothing left to kill (and on Python 2 kill() would raise OSError
        # on the reaped pid instead of reaching sys.exit).
        if cls._dnsdist.poll() is not None:
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP client socket connected to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders, dnsdist and the client UDP socket."""
        cls.startResponders()
        cls.startDNSDist(cls._shutUp)
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Terminate dnsdist, escalating to kill() if it does not exit in time."""
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
                cls._dnsdist.wait()

    @staticmethod
    def _recvExactly(sock, size):
        """
        Read ``size`` bytes from ``sock``, looping over short reads.

        Returns fewer than ``size`` bytes only when the peer closed the
        connection first; an empty result means EOF before any data.
        Socket errors and timeouts propagate to the caller.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one response for the current thread in ``_responsesCounter``."""
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue):
        """
        Build the response a responder should send for ``request``.

        Health-check queries bump ``_healthCheckCounter``; others bump the
        per-thread response counter. A queued response (if any) is copied,
        given the request's ID, and the request is pushed to ``toQueue``.
        Without a queued response, health checks get an empty answer and
        other queries get SERVFAIL if ``_answerUnexpected``, else None.
        Returns None for queries that do not have exactly one question.
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
        else:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                response = fromQueue.get(True, cls._queueTimeout)
                if response:
                    response = copy.copy(response)
                    response.id = request.id
                    toQueue.put(request, True, cls._queueTimeout)

        if not response:
            if healthCheck:
                response = dns.message.make_response(request)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
        """Serve UDP queries on 127.0.0.1:``port`` forever (runs in a thread)."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        while True:
            data, addr = sock.recvfrom(4096)
            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            response = cls._getResponse(request, fromQueue, toQueue)

            if not response:
                continue

            sock.settimeout(2.0)
            sock.sendto(response.to_wire(), addr)
            sock.settimeout(None)
        sock.close()

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
        """
        Serve TCP queries on 127.0.0.1:``port`` forever (runs in a thread).

        One query is read per connection. With ``multipleResponses``, every
        further response queued in ``fromQueue`` is sent on the same
        connection after the first one.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            conn.settimeout(2.0)
            data = cls._recvExactly(conn, 2)
            if not data:
                conn.close()
                continue

            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(conn, datalen)
            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            response = cls._getResponse(request, fromQueue, toQueue)

            if not response:
                conn.close()
                continue

            wire = response.to_wire()
            # sendall: a plain send() may transmit only part of the buffer.
            conn.sendall(struct.pack("!H", len(wire)))
            conn.sendall(wire)

            while multipleResponses:
                if fromQueue.empty():
                    break

                response = fromQueue.get(True, cls._queueTimeout)
                if not response:
                    break

                response = copy.copy(response)
                response.id = request.id
                wire = response.to_wire()
                try:
                    conn.sendall(struct.pack("!H", len(wire)))
                    conn.sendall(wire)
                except socket.error:
                    # some of the tests are going to close
                    # the connection on us, just deal with it
                    break

            conn.close()

        sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send ``query`` to dnsdist over UDP and wait for its answer.

        With ``useQueue``, ``response`` is queued for the responder first.
        ``rawQuery`` means ``query`` is already wire-format bytes.
        Returns ``(receivedQuery, message)``: the query seen by the
        responder (or None) and the parsed answer (None on timeout).
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist's port on 127.0.0.1."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """Return a TLS socket connected to 127.0.0.1:``port``, verified against ``caCert``."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        # 2.7.9+
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        sslsock.connect(("127.0.0.1", port))
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """
        Send a length-prefixed ``query`` over an open stream socket.

        If ``response`` is given, it is queued for the responder first.
        ``rawQuery`` means ``query`` is already wire-format bytes.
        """
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """
        Read one length-prefixed DNS message from an open stream socket.

        Returns the parsed message (None on EOF). With ``useQueue``, always
        returns ``(receivedQuery, message)``, where ``receivedQuery`` is the
        query seen by the responder or None if none was queued.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if data:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if useQueue:
            # Always a tuple here, so callers can unpack even when the
            # responder did not see a query.
            receivedQuery = None
            if not cls._fromResponderQueue.empty():
                receivedQuery = cls._fromResponderQueue.get(True, timeout)
            return (receivedQuery, message)
        return message

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send ``query`` to dnsdist over a fresh TCP connection.

        Returns ``(receivedQuery, message)``; network errors and timeouts are
        printed and leave ``message`` as None.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send ``query`` over TCP and collect every answer until EOF or timeout.

        Returns ``(receivedQuery, messages)`` where ``messages`` is the list
        of parsed answers in arrival order.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            while True:
                data = cls._recvExactly(sock, 2)
                if not data:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        """Reset counters and drain both queues before every test."""
        # Zero the per-thread counters in place: the dict is shared by the class.
        for key in self._responsesCounter:
            self._responsesCounter[key] = 0

        # Reset on the class, not the instance: the responder threads increment
        # the class attribute, and 'self._healthCheckCounter = 0' would only
        # create a shadowing instance attribute.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @classmethod
    def clearToResponderQueue(cls):
        """Drop any responses still queued for the responders."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Drop any queries the responders queued for the tests."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Drain both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new random key for the encrypted console."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode ``command`` and seal it with ``_consoleKey`` (plain bytes if no key)."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Open ``command`` with ``_consoleKey`` (if set) and UTF-8 decode it."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """
        Send ``command`` to dnsdist's console and return its decoded reply.

        Nonces are exchanged first and combined into separate reading and
        writing nonces. Returns None if dnsdist sent a nonce of the wrong size.
        Raises socket.error on EOF while reading the nonce or the reply size.
        The socket is always closed.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        try:
            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            halfNonceSize = len(ourNonce) // 2
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if not data:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            # Console replies can be large: read until the announced size.
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert that two option sequences have equal length and pairwise-equal items."""
        self.assertEqual(len(a), len(b))
        for optionA, optionB in zip(a, b):
            self.assertEqual(optionA, optionB)

    def checkMessageNoEDNS(self, expected, received):
        """Assert ``received`` equals ``expected`` and carries no EDNS."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """Assert EDNS0 with no ECS; ``withCookies`` options allowed, all of code 10 (cookie)."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received):
        """Assert EDNS0 with exactly one ECS option matching ``expected``'s options."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), 1)
        self.assertEqual(received.options[0].otype, clientsubnetoption.ASSIGNED_OPTION_CODE)
        self.compareOptions(expected.options, received.options)

    def checkQueryEDNSWithECS(self, expected, received):
        self.checkMessageEDNSWithECS(expected, received)

    def checkResponseEDNSWithECS(self, expected, received):
        self.checkMessageEDNSWithECS(expected, received)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)