git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #8795 from omoerbeek/rec-lua-docs-policytag
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import os
5 import socket
6 import ssl
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import clientsubnetoption
14 import dns
15 import dns.message
16 import libnacl
17 import libnacl.utils
18
19 from eqdnsmessage import AssertEqualDNSMessageMixin
20
21 # Python2/3 compatibility hacks
22 try:
23 from queue import Queue
24 except ImportError:
25 from Queue import Queue
26
if sys.version_info[0] < 3:
    # Python 2 only: rebind range() to the lazy xrange(), matching the
    # Python 3 builtin. On Python 3 the builtin range is used untouched.
    range = xrange  # noqa: F821 (xrange exists on Python 2 only)
31
32
class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # Address/port dnsdist is told to listen on (-l) and that tests connect to.
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    # Port the UDP/TCP responder threads bind to on 127.0.0.1.
    _testServerPort = 5350
    # Responses queued by the tests, consumed by the responders.
    _toResponderQueue = Queue()
    # Queries received by the responders, consumed by the tests.
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Number of non-health-check queries seen, keyed by responder thread name.
    _responsesCounter = {}
    # Lua configuration template, %-formatted with the class attributes
    # named in _config_params.
    _config_template = """
"""
    _config_params = ['_testServerPort']
    # Each entry is passed to dnsdist as an --acl argument.
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    # Console encryption key; None means console traffic is not encrypted here.
    _consoleKey = None
    # Queries whose name ends with this suffix are treated as health checks.
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    # When no response is queued, answer with SERVFAIL (True) or not at all (False).
    _answerUnexpected = True
    # Expected --check-config output; None means the default "OK" message.
    _checkConfigExpectedOutput = None

    @staticmethod
    def _recvExactly(sock, count):
        """
        Read `count` bytes from a stream socket, looping over short reads.

        Returns fewer than `count` bytes only if the peer closed the
        connection first. Socket timeouts/errors propagate to the caller.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder daemon threads on cls._testServerPort."""
        print("Launching responders..")

        responderArgs = (cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue)
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=responderArgs)
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=responderArgs)
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls):
        """
        Write configs/dnsdist_<ClassName>.conf, validate it with --check-config,
        then start dnsdist in supervised mode, logging to configs/dnsdist_<ClassName>.log.

        Raises AssertionError if --check-config fails or prints unexpected output.
        Exits the process if dnsdist is no longer running after the startup delay.
        """
        print("Launching dnsdist..")
        confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
        params = tuple([getattr(cls, param) for param in cls._config_params])
        print(params)
        with open(confFile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        try:
            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
        if cls._checkConfigExpectedOutput is not None:
            expectedOutput = cls._checkConfigExpectedOutput
        else:
            expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
        with open(logFile, 'w') as fdLog:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        if cls._dnsdist.poll() is not None:
            # poll() has already reaped the process, so there is nothing left
            # to kill (doing so can raise or hit a recycled PID on some Pythons).
            print("dnsdist exited early with code %d, see %s" % (cls._dnsdist.returncode, logFile))
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP socket used by sendUDPQuery(), connected to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):

        cls.startResponders()
        cls.startDNSDist()
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Terminate dnsdist, escalating to kill() if it has not exited after a short delay."""
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
                cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one non-health-check query for the calling responder thread."""
        threadName = threading.current_thread().name
        cls._responsesCounter[threadName] = cls._responsesCounter.get(threadName, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
        """
        Build the response a responder should send for `request`.

        - Queries with a question count other than 1 are skipped (returns None).
        - Health-check queries (name ending in cls._healthCheckName) get an
          empty response built from the request.
        - Otherwise, if a response is queued on `fromQueue`, the request is put
          on `toQueue` and, unless `synthesize` is set, the queued response is
          returned with its ID rewritten to match the request.
        - If no response was obtained: with `synthesize` set, an empty response
          with that RCODE is returned; else if cls._answerUnexpected, SERVFAIL;
          else None (meaning: do not answer).
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
            response = dns.message.make_response(request)
        else:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                toQueue.put(request, True, cls._queueTimeout)
                if synthesize is None:
                    response = fromQueue.get(True, cls._queueTimeout)
                    if response:
                        # copy so the queued object is not mutated by the ID rewrite
                        response = copy.copy(response)
                        response.id = request.id

        if not response:
            if synthesize is not None:
                response = dns.message.make_response(request)
                response.set_rcode(synthesize)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
        """
        Serve UDP queries on 127.0.0.1:`port` until an exception ends the loop.

        trailingDataResponse=True means "ignore trailing data".
        Other values are either False (meaning "raise an exception")
        or are interpreted as a response RCODE for queries with trailing data.
        callback is invoked for every -even healthcheck ones- query and should
        return a raw response; a falsy return value means no answer is sent.
        """
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            sock.bind(("127.0.0.1", port))
            while True:
                data, addr = sock.recvfrom(4096)
                forceRcode = None
                try:
                    request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
                except dns.message.TrailingJunk:
                    if trailingDataResponse is False:
                        raise
                    print("UDP query with trailing data, synthesizing response")
                    request = dns.message.from_wire(data, ignore_trailing=True)
                    forceRcode = trailingDataResponse

                wire = None
                if callback:
                    wire = callback(request)
                else:
                    response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                    if response:
                        wire = response.to_wire()

                if not wire:
                    continue

                sock.settimeout(2.0)
                sock.sendto(wire, addr)
                sock.settimeout(None)
        finally:
            sock.close()

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
        """
        Serve TCP (length-prefixed) queries on 127.0.0.1:`port`, one query per connection.

        trailingDataResponse=True means "ignore trailing data".
        Other values are either False (meaning "raise an exception")
        or are interpreted as a response RCODE for queries with trailing data.
        callback is invoked for every -even healthcheck ones- query and should
        return a raw response; a falsy return value closes the connection unanswered.
        With multipleResponses, every further response still queued on
        `fromQueue` is sent on the same connection after the first one.
        """
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sock.close()
            sys.exit(1)

        sock.listen(100)
        try:
            while True:
                (conn, _) = sock.accept()
                try:
                    conn.settimeout(5.0)
                    data = cls._recvExactly(conn, 2)
                    if len(data) != 2:
                        # peer closed before sending a full length prefix
                        continue

                    (datalen,) = struct.unpack("!H", data)
                    data = cls._recvExactly(conn, datalen)
                    forceRcode = None
                    try:
                        request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
                    except dns.message.TrailingJunk:
                        if trailingDataResponse is False:
                            raise
                        print("TCP query with trailing data, synthesizing response")
                        request = dns.message.from_wire(data, ignore_trailing=True)
                        forceRcode = trailingDataResponse

                    # reset per connection so a previous connection's answer is never resent
                    wire = None
                    if callback:
                        wire = callback(request)
                    else:
                        response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                        if response:
                            wire = response.to_wire(max_size=65535)

                    if not wire:
                        continue

                    conn.sendall(struct.pack("!H", len(wire)))
                    conn.sendall(wire)

                    while multipleResponses:
                        if fromQueue.empty():
                            break

                        response = fromQueue.get(True, cls._queueTimeout)
                        if not response:
                            break

                        response = copy.copy(response)
                        response.id = request.id
                        wire = response.to_wire(max_size=65535)
                        try:
                            conn.sendall(struct.pack("!H", len(wire)))
                            conn.sendall(wire)
                        except socket.error:
                            # some of the tests are going to close
                            # the connection on us, just deal with it
                            break
                finally:
                    conn.close()
        finally:
            sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` to dnsdist over UDP and return (receivedQuery, message).

        With useQueue, `response` is queued for the responders first and the
        query they received (if any) is returned as receivedQuery; otherwise
        receivedQuery is None. `message` is None on timeout. With rawQuery,
        `query` is sent as-is (already wire-encoded).
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket (TCP_NODELAY set) connected to dnsdist's port."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """Return a TLS socket connected to 127.0.0.1:`port`, verified against `caCert`."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        # 2.7.9+
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        sslsock.connect(("127.0.0.1", port))
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """
        Send one length-prefixed query on an already open stream socket.

        If `response` is given it is queued for the responders first.
        """
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """
        Read one length-prefixed DNS message from an open stream socket.

        Returns the message (None if the connection was closed first). With
        useQueue, returns (receivedQuery, message) instead, receivedQuery
        being None when the responders did not queue a query.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if len(data) == 2:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if not useQueue:
            return message

        # always a tuple when useQueue is set, so callers can unpack it
        receivedQuery = None
        if not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, message)

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` to dnsdist over a fresh TCP connection; return (receivedQuery, message).

        Timeouts and socket errors are printed and yield message=None.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send `query` over TCP and collect every response until EOF or timeout.

        Returns (receivedQuery, messages) where `messages` is a list.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            while True:
                data = cls._recvExactly(sock, 2)
                if len(data) != 2:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        # This function is called before every test

        # Clear the responses counters. Iterate over a snapshot of the keys
        # since responder threads may add entries concurrently.
        for key in list(self._responsesCounter):
            self._responsesCounter[key] = 0

        # The responders increment the counter on the class
        # (cls._healthCheckCounter += 1); assigning through self would only
        # create a shadowing instance attribute and never reset it.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

        super(DNSDistTest, self).setUp()

    @classmethod
    def clearToResponderQueue(cls):
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode `command` and seal it with cls._consoleKey (plain bytes if no key)."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Open `command` with cls._consoleKey (if set) and decode it as UTF-8."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """
        Run `command` on the dnsdist console and return the decoded response.

        Exchanges nonces with the console first; returns None if the peer's
        nonce has the wrong (non-zero) size. Raises socket.error on EOF.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            if timeout:
                sock.settimeout(timeout)

            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            # Each direction's nonce combines halves of both peers' nonces.
            halfNonceSize = len(ourNonce) // 2
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if len(data) != 4:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert that two EDNS option lists are equal, element by element."""
        self.assertEqual(len(a), len(b))
        for left, right in zip(a, b):
            self.assertEqual(left, right)

    def checkMessageNoEDNS(self, expected, received):
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """Check EDNS0 with no ECS; option code 10 must appear exactly `withCookies` times."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)
        else:
            for option in received.options:
                self.assertNotEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        """Check EDNS0 carrying an ECS option plus `additionalOptions` other options."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)
588
589 def checkResponseNoEDNS(self, expected, received):
590 self.checkMessageNoEDNS(expected, received)
591