]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #7479 from phonedph1/addrecords
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import os
5 import socket
6 import ssl
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import clientsubnetoption
14 import dns
15 import dns.message
16 import libnacl
17 import libnacl.utils
18
# Python 2/3 compatibility hacks.
# The synchronized queue class lives in `queue` on Python 3 and in `Queue`
# on Python 2.
try:
    from queue import Queue
except ImportError:
    from Queue import Queue

# On Python 2, rebind `range` to the lazy `xrange`; on Python 3 the name
# `xrange` does not exist and the builtin `range` is kept.
try:
    xrange
except NameError:
    pass
else:
    range = xrange
29
30
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    _testServerPort = 5350
    # Tests put the responses to serve on _toResponderQueue; responders put
    # the queries they received from dnsdist on _fromResponderQueue.
    # Both are defined here, so they are shared by every subclass.
    _toResponderQueue = Queue()
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Number of non-health-check queries answered, keyed by responder thread name.
    _responsesCounter = {}
    # dnsdist configuration, %-formatted with the attributes listed in
    # _config_params (in that order).
    _config_template = """
"""
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    _consoleKey = None
    # Queries whose name ends with this are counted as health checks.
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    # When no response was queued by the test, reply SERVFAIL instead of dropping.
    _answerUnexpected = True

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder threads on cls._testServerPort."""
        print("Launching responders..")

        # Daemon threads: the responders loop forever and must not keep the
        # interpreter alive once the tests are done.
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls):
        """
        Write this class' configuration, validate it and start dnsdist.

        The configuration is written to configs/dnsdist_<ClassName>.conf and
        checked with --check-config before dnsdist is started in the
        background with its output sent to configs/dnsdist_<ClassName>.log.

        Raises AssertionError if the configuration check fails. Exits the
        process with dnsdist's return code if dnsdist has already stopped
        once the startup delay has elapsed.
        """
        print("Launching dnsdist..")
        confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
        params = tuple([getattr(cls, param) for param in cls._config_params])
        print(params)
        with open(confFile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', confFile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        try:
            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
        with open(logFile, 'w') as fdLog:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        # poll() returns the exit code once the process has terminated.
        if cls._dnsdist.poll() is not None:
            cls._dnsdist.kill()
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP socket used by sendUDPQuery(), connected to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders, then dnsdist, then the client UDP socket."""
        cls.startResponders()
        cls.startDNSDist()
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Terminate dnsdist, killing it if it is still running after a short delay."""
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
                cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one answered query for the calling responder thread."""
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
        """
        Compute the response a responder should send for `request`.

        - Queries without exactly one question are skipped (returns None).
        - Health-check queries (name ending with _healthCheckName) are counted
          and answered with dns.message.make_response(request).
        - Other queries are counted per responder thread. If the test queued a
          response on fromQueue, the request is pushed to toQueue and, unless
          `synthesize` is set, the queued response is returned with the
          request's ID.
        - Without a queued response, a response with rcode `synthesize` is
          built when `synthesize` is set, otherwise a SERVFAIL one if
          _answerUnexpected is true. Returns None when nothing should be sent.
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
            response = dns.message.make_response(request)
        else:
            cls._ResponderIncrementCounter()
            # Only record the query when the test is expecting one, i.e. when
            # it queued a response to serve.
            if not fromQueue.empty():
                toQueue.put(request, True, cls._queueTimeout)
                if synthesize is None:
                    response = fromQueue.get(True, cls._queueTimeout)
                    if response:
                        # Do not alter the test's own copy of the response.
                        response = copy.copy(response)
                        response.id = request.id

        if not response:
            if synthesize is not None:
                response = dns.message.make_response(request)
                response.set_rcode(synthesize)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @staticmethod
    def _recvExactly(sock, length):
        """
        Read `length` bytes from the stream socket `sock`.

        A single recv() may return fewer bytes than requested, so keep reading
        until `length` bytes have been received. Returns fewer bytes (possibly
        b'') only if the peer closes the connection first. Socket timeouts and
        errors propagate to the caller.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
        """
        Answer DNS queries over UDP on 127.0.0.1:`port`, forever.

        Responses are computed by _getResponse() from `fromQueue`, and
        received queries are pushed to `toQueue`.

        trailingDataResponse=True means "ignore trailing data".
        Other values are either False (meaning "raise an exception")
        or are interpreted as a response RCODE for queries with trailing data.
        """
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        while True:
            data, addr = sock.recvfrom(4096)
            forceRcode = None
            try:
                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            except dns.message.TrailingJunk:
                if trailingDataResponse is False:
                    raise
                print("UDP query with trailing data, synthesizing response")
                request = dns.message.from_wire(data, ignore_trailing=True)
                forceRcode = trailingDataResponse

            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
            if not response:
                continue

            sock.settimeout(2.0)
            sock.sendto(response.to_wire(), addr)
            sock.settimeout(None)

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
        """
        Answer DNS queries over TCP on 127.0.0.1:`port`, forever.

        One query is read per connection (2-byte length prefix + message).
        When `multipleResponses` is true, every other response still queued
        on `fromQueue` is sent on the same connection after the first one.

        trailingDataResponse=True means "ignore trailing data".
        Other values are either False (meaning "raise an exception")
        or are interpreted as a response RCODE for queries with trailing data.
        """
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            conn.settimeout(2.0)
            data = cls._recvExactly(conn, 2)
            if not data:
                conn.close()
                continue

            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(conn, datalen)
            forceRcode = None
            try:
                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            except dns.message.TrailingJunk:
                if trailingDataResponse is False:
                    raise
                print("TCP query with trailing data, synthesizing response")
                request = dns.message.from_wire(data, ignore_trailing=True)
                forceRcode = trailingDataResponse

            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
            if not response:
                conn.close()
                continue

            wire = response.to_wire()
            conn.sendall(struct.pack("!H", len(wire)))
            conn.sendall(wire)

            while multipleResponses:
                if fromQueue.empty():
                    break

                response = fromQueue.get(True, cls._queueTimeout)
                if not response:
                    break

                response = copy.copy(response)
                response.id = request.id
                wire = response.to_wire()
                try:
                    conn.sendall(struct.pack("!H", len(wire)))
                    conn.sendall(wire)
                except socket.error:
                    # some of the tests are going to close
                    # the connection on us, just deal with it
                    break

            conn.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` to dnsdist over UDP and wait for the answer.

        When `useQueue` is true, `response` is queued for the responders first.
        `query` is sent as-is when `rawQuery` is true, otherwise it is
        converted with to_wire().

        Returns (receivedQuery, message): the query seen by a responder (or
        None) and the parsed response (None on timeout).
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist, with `timeout` set if given."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """
        Return a TLS socket connected to 127.0.0.1:`port`.

        When ssl.create_default_context exists (2.7.9+), the certificate is
        validated against `caCert` for `serverName`. Otherwise ssl.wrap_socket
        is used with CERT_REQUIRED, which does not receive `serverName`.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        # 2.7.9+
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        sslsock.connect(("127.0.0.1", port))
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """
        Send `query` with its 2-byte length prefix over an open stream socket.

        If `response` is given, it is queued for the responders first.
        """
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """
        Read one length-prefixed DNS message from an open stream socket.

        The message is None if the connection was closed before any data.
        Returns the message alone when `useQueue` is false, otherwise
        (receivedQuery, message) where receivedQuery is the query seen by a
        responder, or None if none was queued.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if data:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if not useQueue:
            return message

        # Always return a tuple when useQueue is set, so callers can unpack
        # the result even if no query reached the responders.
        receivedQuery = None
        if not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, message)

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` to dnsdist over a new TCP connection and read one answer.

        Timeouts and network errors are printed, not raised; the message is
        None in that case. Returns (receivedQuery, message).
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` over a new TCP connection and read every answer until
        the connection is closed (or times out).

        Returns (receivedQuery, messages) where messages is the list of
        parsed responses, in the order received.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)
        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            if not rawQuery:
                wire = query.to_wire()
            else:
                wire = query

            sock.sendall(struct.pack("!H", len(wire)))
            sock.sendall(wire)
            while True:
                data = cls._recvExactly(sock, 2)
                if not data:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        """Reset the counters and queues before every test."""
        # Clear the responses counters. Iterate over a snapshot of the keys:
        # responder threads may add new entries while we are iterating.
        for key in list(self._responsesCounter):
            self._responsesCounter[key] = 0

        # The responders increment the counter on the class (cls._healthCheckCounter),
        # so reset it there; assigning self._healthCheckCounter would only
        # create an instance attribute shadowing it.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @classmethod
    def clearToResponderQueue(cls):
        """Drop every response still waiting for the responders."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Drop every query recorded by the responders."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Empty both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new random key for the encrypted console."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode `command` and encrypt it unless no console key is set."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Decrypt `command` (unless no console key is set) and UTF-8 decode it."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """
        Run `command` on the dnsdist console and return its decoded output.

        The client and server exchange nonces, then the command is sent with a
        4-byte length prefix and the length-prefixed response is read back.
        Returns None if the server's nonce has an unexpected size.
        Raises socket.error on EOF while reading the nonce or response size.
        The socket is always closed before returning.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            # Each direction uses a nonce built from half of each side's nonce.
            halfNonceSize = int(len(ourNonce) / 2)
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if not data:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            # Responses can be large: a single recv() may return only part of it.
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert that the two option sequences are equal, element by element."""
        self.assertEqual(len(a), len(b))
        for expectedOption, receivedOption in zip(a, b):
            self.assertEqual(expectedOption, receivedOption)

    def checkMessageNoEDNS(self, expected, received):
        """Assert `received` equals `expected` and carries no EDNS at all."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        """Assert `received` equals `expected` and uses EDNS version 0."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """
        Assert `received` equals `expected`, uses EDNS version 0 and carries
        exactly `withCookies` options, all of them of option code 10.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        """
        Assert `received` equals `expected`, uses EDNS version 0 and carries an
        ECS option plus `additionalOptions` other options, identical to those
        of `expected`.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        """Query flavour of checkMessageEDNSWithECS()."""
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        """Response flavour of checkMessageEDNSWithECS()."""
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        """Query flavour of checkMessageEDNSWithoutECS(), without cookies."""
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        """Response flavour of checkMessageEDNSWithoutECS()."""
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        """Query flavour of checkMessageNoEDNS()."""
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        """Response flavour of checkMessageNoEDNS()."""
        self.checkMessageNoEDNS(expected, received)
555
556 def checkResponseNoEDNS(self, expected, received):
557 self.checkMessageNoEDNS(expected, received)