git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #6823 from klaus3000/load-ourserial-on-NOTIFY
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import os
5 import socket
6 import ssl
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import clientsubnetoption
14 import dns
15 import dns.message
16 import libnacl
17 import libnacl.utils
18
# Python 2/3 compatibility shims.
# The queue module is called `queue` on Python 3 and `Queue` on Python 2.
try:
    from queue import Queue
except ImportError:
    from Queue import Queue

# On Python 2, rebind `range` to the lazy `xrange`; on Python 3 the name
# `xrange` does not exist and the builtin `range` is kept as-is.
try:
    xrange
except NameError:
    pass
else:
    range = xrange
29
30
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.

    Subclasses customise the setup through class attributes, most notably
    `_config_template` (the dnsdist configuration, %-formatted with the
    values of the attributes listed in `_config_params`), `_acl` (passed to
    dnsdist as --acl arguments) and `_consoleKey` (when set, console
    commands are encrypted with it).
    """
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    _testServerPort = 5350
    # Responses queued by the tests, consumed by the responder threads.
    _toResponderQueue = Queue()
    # Queries received by the responder threads, consumed by the tests.
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Number of non-health-check queries seen, keyed by responder thread name.
    # NOTE: this dict object is shared by every subclass that does not
    # override it.
    _responsesCounter = {}
    _shutUp = True
    _config_template = """
"""
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    _consoleKey = None
    # Queries whose name ends with this suffix are treated as health checks.
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    # Whether queries arriving with no queued response are answered with
    # SERVFAIL (True) or left unanswered (False).
    _answerUnexpected = True

    @classmethod
    def startResponders(cls):
        """
        Start the UDP and TCP responder threads on `_testServerPort`.

        Both threads are daemonic. Note the queue direction: responders take
        the responses to send from `_toResponderQueue` and push the queries
        they received onto `_fromResponderQueue`.
        """
        print("Launching responders..")

        responderArgs = [cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=responderArgs)
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=responderArgs)
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls, shutUp=True):
        """
        Render the configuration file, validate it and launch dnsdist.

        The binary path is read from the DNSDISTBIN environment variable.
        The configuration is first validated with --check-config, raising
        AssertionError if dnsdist does not report it as OK. dnsdist is then
        started in the background, with stdout discarded when `shutUp` is
        true. If the process has already exited once the startup delay has
        elapsed, the whole test run is aborted with dnsdist's exit code.
        """
        print("Launching dnsdist..")
        conffile = 'dnsdist_test.conf'
        params = tuple(getattr(cls, param) for param in cls._config_params)
        print(params)
        with open(conffile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        output = subprocess.check_output(testcmd, close_fds=True)
        # Derive the expected message from the file actually written rather
        # than repeating its name as a separate literal.
        expectedOutput = ("Configuration '%s' OK!\n" % conffile).encode('UTF-8')
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        if shutUp:
            with open(os.devnull, 'w') as fdDevNull:
                cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
        else:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        if cls._dnsdist.poll() is not None:
            cls._dnsdist.kill()
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP client socket, connected to dnsdist's listening port."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders, then dnsdist, then the client socket."""
        cls.startResponders()
        cls.startDNSDist(cls._shutUp)
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """
        Stop dnsdist: send SIGTERM, and if it is still running after a
        short delay, kill it. The process is then reaped.
        """
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
            cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one more query for the calling responder thread."""
        threadName = threading.current_thread().name
        cls._responsesCounter[threadName] = cls._responsesCounter.get(threadName, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue):
        """
        Decide what a responder should answer to `request`.

        `fromQueue` holds the responses queued by the tests; `toQueue`
        receives the queries for the tests to inspect.

        Queries with a question count other than one are skipped (None is
        returned, so nothing is sent back). A query whose name ends with
        `_healthCheckName` increments `_healthCheckCounter`. Any other query
        increments the calling thread's entry in `_responsesCounter` and
        takes the next queued response, if any. That response is copied and
        given the request's ID, and the request is pushed onto `toQueue`.
        When no response was obtained, health checks get a bare response
        built from the request. Other queries get SERVFAIL if
        `_answerUnexpected` is set; otherwise None is returned.
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
        else:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                response = fromQueue.get(True, cls._queueTimeout)
                if response:
                    # Copy so the queued object is never mutated by us.
                    response = copy.copy(response)
                    response.id = request.id
                    toQueue.put(request, True, cls._queueTimeout)

        if not response:
            if healthCheck:
                response = dns.message.make_response(request)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @staticmethod
    def _recvExactly(sock, length):
        """
        Read exactly `length` bytes from the stream socket `sock`.

        A single recv() on a stream socket may return fewer bytes than
        requested, so keep reading until `length` bytes have been collected
        or the peer closes the connection. On EOF the short (possibly empty)
        result is returned, so callers can detect it by checking the length.
        Socket errors, including timeouts, propagate to the caller.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
        """
        Answer DNS queries over UDP on 127.0.0.1:`port`; never returns.

        Every datagram is parsed as a DNS message and handled by
        `_getResponse`. Datagrams for which it produces no response are
        dropped without reply.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        while True:
            data, addr = sock.recvfrom(4096)
            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            response = cls._getResponse(request, fromQueue, toQueue)

            if not response:
                continue

            sock.settimeout(2.0)
            sock.sendto(response.to_wire(), addr)
            sock.settimeout(None)

    @classmethod
    def _handleTCPConnection(cls, conn, fromQueue, toQueue, ignoreTrailing, multipleResponses):
        """
        Serve one length-prefixed DNS query on the accepted connection `conn`.

        Nothing is sent when the length prefix is truncated or when
        `_getResponse` produces no response. With `multipleResponses`, every
        additional response already queued in `fromQueue` is also sent on
        the same connection. The caller owns and closes `conn`.
        """
        data = cls._recvExactly(conn, 2)
        if len(data) != 2:
            return

        (datalen,) = struct.unpack("!H", data)
        data = cls._recvExactly(conn, datalen)
        request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
        response = cls._getResponse(request, fromQueue, toQueue)

        if not response:
            return

        wire = response.to_wire()
        conn.sendall(struct.pack("!H", len(wire)))
        conn.sendall(wire)

        while multipleResponses:
            if fromQueue.empty():
                break

            response = fromQueue.get(True, cls._queueTimeout)
            if not response:
                break

            response = copy.copy(response)
            response.id = request.id
            wire = response.to_wire()
            try:
                conn.sendall(struct.pack("!H", len(wire)))
                conn.sendall(wire)
            except socket.error:
                # some of the tests are going to close
                # the connection on us, just deal with it
                break

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
        """
        Answer DNS queries over TCP on 127.0.0.1:`port`; never returns.

        Connections are handled one at a time, one query per connection.
        Failing to bind exits the thread. A socket error, including a
        timeout, on a single connection is logged and the connection
        closed, so one misbehaving peer cannot kill the responder.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            conn.settimeout(2.0)
            try:
                cls._handleTCPConnection(conn, fromQueue, toQueue, ignoreTrailing, multipleResponses)
            except socket.error as e:
                print("Network error in the TCP responder: %s" % str(e))
            finally:
                conn.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` to dnsdist over UDP and wait for the answer.

        With `useQueue`, `response` is queued for the responders first.
        `rawQuery` means `query` is already wire-format bytes. Returns
        `(receivedQuery, message)`: the query the responder saw (or None)
        and the parsed response from dnsdist (None on timeout).
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist, with `timeout` applied if set."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """
        Return a TLS socket connected to 127.0.0.1:`port`, verifying the
        server certificate against `caCert`.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        # 2.7.9+
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        sslsock.connect(("127.0.0.1", port))
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """
        Send `query` on an established stream socket, prefixed with its
        2-byte big-endian length. If `response` is given, it is queued for
        the responders first.
        """
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        # sendall: a plain send() may transmit only part of the buffer.
        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """
        Read one length-prefixed DNS message from `sock`.

        Returns the parsed message, or None if the connection was closed
        before any data arrived. With `useQueue`, returns
        `(receivedQuery, message)` instead, where receivedQuery is the next
        query recorded by a responder, or None if there is none.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if data:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if not useQueue:
            return message

        # Always return a tuple when useQueue is set, so callers can unpack
        # the result even when no query was recorded.
        receivedQuery = None
        if not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, message)

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` to dnsdist over a fresh TCP connection and read one answer.

        Socket errors and timeouts while exchanging data are printed, not
        raised. Returns `(receivedQuery, message)`, either of which may be
        None.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` over a fresh TCP connection and collect every message
        dnsdist sends back until it closes the connection.

        Returns `(receivedQuery, messages)`, where messages is a possibly
        empty list of parsed responses.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)
        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            while True:
                data = cls._recvExactly(sock, 2)
                if not data:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        """
        Reset per-test state; unittest calls this before every test.

        Zeroes the response and health-check counters, and drains both
        responder queues in case a previous test failed midway.
        """
        # Clear the responses counters
        for key in self._responsesCounter:
            self._responsesCounter[key] = 0

        # The responders increment the counter on the class
        # (`cls._healthCheckCounter += 1`). Assigning through `self` would
        # only create a shadowing instance attribute and leave it untouched.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @classmethod
    def clearToResponderQueue(cls):
        """Drop every response still waiting in `_toResponderQueue`."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Drop every query still waiting in `_fromResponderQueue`."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Drain both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new random key for `_consoleKey`."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """
        UTF-8 encode `command` and, when `_consoleKey` is set, seal it with
        libnacl's crypto_secretbox under `nonce`.
        """
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """
        Reverse `_encryptConsole`: open the secretbox when `_consoleKey` is
        set, then decode the result as UTF-8.
        """
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """
        Run `command` on dnsdist's console and return its decoded output.

        Both sides exchange nonces; the reading and writing nonces are each
        built from one half of ours and one half of theirs. Returns None if
        the peer sent a nonce of the wrong size. Raises socket.error on EOF
        while reading the nonce or the response size. The socket is always
        closed.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            halfNonceSize = int(len(ourNonce) / 2)
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if not data:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert that two EDNS option lists are equal, element by element."""
        self.assertEqual(len(a), len(b))
        for optionA, optionB in zip(a, b):
            self.assertEqual(optionA, optionB)

    def checkMessageNoEDNS(self, expected, received):
        """Assert equality and that `received` carries no EDNS (dnspython edns == -1) and no options."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        """Assert equality and that `received` uses EDNS version 0."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """
        Assert equality, EDNS version 0, and exactly `withCookies` options,
        each of which must have option code 10 (EDNS COOKIE).
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        """
        Assert equality, EDNS version 0, an ECS option, and exactly
        `additionalOptions` other options, with options identical to
        `expected`'s.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)