]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #7148 from Habbie/sdig-class-in
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import os
5 import socket
6 import ssl
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import clientsubnetoption
14 import dns
15 import dns.message
16 import libnacl
17 import libnacl.utils
18
# Python 2/3 compatibility: pick the thread-safe FIFO module name and the
# lazy range builtin that match the running interpreter.
if sys.version_info[0] < 3:
    # Legacy interpreter: the module is spelled ``Queue`` and ``range``
    # builds a list, so shadow it with the lazy ``xrange``.
    from Queue import Queue
    range = xrange
else:
    # Python 3: ``queue`` is the module name and the builtin ``range`` is
    # already lazy, so nothing needs to be rebound.
    from queue import Queue
29
30
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.

    Queries sent to dnsdist are relayed to the responder threads, which
    reply with the responses provided by the tests themselves through a
    queue. The responder threads also queue the queries they receive from
    dnsdist on a separate queue, allowing the tests to check that the
    queries sent by dnsdist were as expected.

    Subclasses customise a run through the class attributes below, most
    notably ``_config_template``, which is %-formatted with the values of
    the attributes named in ``_config_params``.
    """
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    _testServerPort = 5350
    # Responses queued by the tests, consumed by the responder threads.
    _toResponderQueue = Queue()
    # Queries received by the responder threads, consumed by the tests.
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Per-responder-thread count of answered (non health-check) queries,
    # keyed by thread name.
    _responsesCounter = {}
    _shutUp = True
    _config_template = """
"""
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    _consoleKey = None
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    _answerUnexpected = True

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder threads on ``_testServerPort``."""
        print("Launching responders..")

        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        # The responders loop forever; as daemon threads they do not keep the
        # interpreter alive. The ``daemon`` attribute replaces the deprecated
        # setDaemon() and is available on Python 2.6+ as well.
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls, shutUp=True):
        """
        Write the configuration file, validate it, then launch dnsdist.

        :param shutUp: when true, dnsdist's stdout is sent to /dev/null.
        :raises AssertionError: if ``dnsdist --check-config`` fails or does
            not report the configuration as OK.

        If dnsdist is no longer running once the startup delay has elapsed,
        the process exits with dnsdist's return code.
        """
        print("Launching dnsdist..")
        conffile = 'dnsdist_test.conf'
        params = tuple([getattr(cls, param) for param in cls._config_params])
        print(params)
        with open(conffile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        try:
            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
        # Derive the expected message from the file name actually written
        # above instead of repeating it.
        expectedOutput = ('Configuration \'%s\' OK!\n' % conffile).encode('UTF-8')
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        if shutUp:
            with open(os.devnull, 'w') as fdDevNull:
                cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
        else:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        if cls._dnsdist.poll() is not None:
            cls._dnsdist.kill()
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP socket the tests use to query dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders and dnsdist, then open the query socket."""
        cls.startResponders()
        cls.startDNSDist(cls._shutUp)
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop dnsdist: terminate it, and kill it if it is still alive after a grace delay."""
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
                cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one answered query for the calling responder thread."""
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue):
        """
        Build the response a responder should send for ``request``.

        Queries with a question count other than one are skipped (None).
        Queries whose name ends with ``_healthCheckName`` bump
        ``_healthCheckCounter``; other queries bump the per-thread counter.
        When ``fromQueue`` holds a response, it is copied, given the
        request's ID, and the request is pushed onto ``toQueue``. Otherwise
        health checks get a plain response, other queries get SERVFAIL when
        ``_answerUnexpected`` is set, and None is returned (no answer) in
        all other cases.
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
        else:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                response = fromQueue.get(True, cls._queueTimeout)
                if response:
                    # Copy so the test's object keeps its own ID.
                    response = copy.copy(response)
                    response.id = request.id
                    toQueue.put(request, True, cls._queueTimeout)

        if not response:
            if healthCheck:
                response = dns.message.make_response(request)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
        """Answer DNS over UDP on 127.0.0.1:``port`` forever using ``_getResponse``."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        while True:
            data, addr = sock.recvfrom(4096)
            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            response = cls._getResponse(request, fromQueue, toQueue)

            if not response:
                continue

            sock.settimeout(2.0)
            sock.sendto(response.to_wire(), addr)
            sock.settimeout(None)
        sock.close()

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
        """
        Answer DNS over TCP on 127.0.0.1:``port`` forever.

        Each accepted connection carries one length-prefixed query, answered
        via ``_getResponse``. When ``multipleResponses`` is true, any further
        responses already waiting in ``fromQueue`` are sent on the same
        connection. The connection is then closed.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            # Raised inside the responder thread, so it ends that thread.
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            conn.settimeout(2.0)
            data = conn.recv(2)
            if not data:
                conn.close()
                continue

            (datalen,) = struct.unpack("!H", data)
            data = conn.recv(datalen)
            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            response = cls._getResponse(request, fromQueue, toQueue)

            if not response:
                conn.close()
                continue

            wire = response.to_wire()
            # sendall: a plain send() may transmit only part of the buffer.
            conn.sendall(struct.pack("!H", len(wire)))
            conn.sendall(wire)

            while multipleResponses:
                if fromQueue.empty():
                    break

                response = fromQueue.get(True, cls._queueTimeout)
                if not response:
                    break

                response = copy.copy(response)
                response.id = request.id
                wire = response.to_wire()
                try:
                    conn.sendall(struct.pack("!H", len(wire)))
                    conn.sendall(wire)
                except socket.error:
                    # some of the tests are going to close
                    # the connection on us, just deal with it
                    break

            conn.close()

        sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send ``query`` to dnsdist over UDP.

        :param response: response queued for the responder when ``useQueue``.
        :param rawQuery: ``query`` is already wire-format bytes.
        :returns: ``(receivedQuery, message)``. ``receivedQuery`` is the
            query seen by a responder (or None). ``message`` is the parsed
            reply, or None on timeout.
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist, closing it if the connection fails."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        try:
            sock.connect(("127.0.0.1", cls._dnsDistPort))
        except socket.error:
            sock.close()
            raise
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """Return a TLS socket connected to 127.0.0.1:``port``, verified against ``caCert``."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        # 2.7.9+
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        try:
            sslsock.connect(("127.0.0.1", port))
        except socket.error:
            sslsock.close()
            raise
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """Send one length-prefixed query on ``sock``, queueing ``response`` for the responder if given."""
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """
        Read one length-prefixed DNS message from ``sock``.

        :returns: the parsed message (None if the peer closed first). When
            ``useQueue`` is true, always returns ``(receivedQuery, message)``.
            ``receivedQuery`` is the next query seen by a responder, or None
            if none was queued.
        """
        message = None
        data = sock.recv(2)
        if data:
            (datalen,) = struct.unpack("!H", data)
            data = sock.recv(datalen)
            if data:
                message = dns.message.from_wire(data)

        if useQueue:
            # Keep the return shape stable, so callers can always unpack a
            # tuple when they asked for the queue.
            receivedQuery = None
            if not cls._fromResponderQueue.empty():
                receivedQuery = cls._fromResponderQueue.get(True, timeout)
            return (receivedQuery, message)
        return message

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send ``query`` to dnsdist over a fresh TCP connection.

        Network errors and timeouts are printed, not raised.

        :returns: ``(receivedQuery, message)``, either of which may be None.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send ``query`` over TCP and collect every message until the peer closes.

        :returns: ``(receivedQuery, messages)``. ``messages`` is the list of
            parsed responses, in order.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)
        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            if not rawQuery:
                wire = query.to_wire()
            else:
                wire = query

            sock.sendall(struct.pack("!H", len(wire)))
            sock.sendall(wire)
            while True:
                data = sock.recv(2)
                if not data:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = sock.recv(datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        """Reset the counters and drain both queues before every test."""
        # Clear the responses counters
        for key in self._responsesCounter:
            self._responsesCounter[key] = 0

        # The responders increment the counter on the class; assigning
        # through ``self`` would only create a shadowing instance attribute.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @classmethod
    def clearToResponderQueue(cls):
        """Drop every response still waiting for a responder."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Drop every query the responders have queued."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Drain both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new random console key (libnacl salsa key)."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode ``command`` and seal it with ``_consoleKey`` when one is set."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Open ``command`` with ``_consoleKey`` when one is set, then UTF-8 decode it."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """
        Run ``command`` on the dnsdist console at 127.0.0.1:``_consolePort``.

        Nonces are exchanged first. The read and write nonces each combine
        one half of our nonce with one half of the server's.

        :returns: the decoded reply, or None if the server's nonce has an
            unexpected length.
        :raises socket.error: on EOF while reading the nonce or reply size.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = sock.recv(len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            halfNonceSize = int(len(ourNonce) / 2)
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = sock.recv(4)
            if not data:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            data = sock.recv(responseLen)
            response = cls._decryptConsole(data, readingNonce)
            return response
        finally:
            # The console connection is single-use; never leak it.
            sock.close()

    def compareOptions(self, a, b):
        """Assert that the two EDNS option lists are equal element by element."""
        self.assertEqual(len(a), len(b))
        for left, right in zip(a, b):
            self.assertEqual(left, right)

    def checkMessageNoEDNS(self, expected, received):
        """Assert ``received`` equals ``expected`` and carries no EDNS at all."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        """Assert ``received`` equals ``expected`` and uses EDNS version 0."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """Assert EDNS0 with exactly ``withCookies`` options, all of which have option code 10."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        """Assert EDNS0 with an ECS option plus ``additionalOptions`` others, matching ``expected``."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)