]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #12819 from rgacogne/ddist-reuseaddr-udp
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import base64
4 import copy
5 import errno
6 import os
7 import socket
8 import ssl
9 import struct
10 import subprocess
11 import sys
12 import threading
13 import time
14 import unittest
15
16 import clientsubnetoption
17
18 import dns
19 import dns.message
20
21 import libnacl
22 import libnacl.utils
23
24 import h2.connection
25 import h2.events
26 import h2.config
27
28 import pycurl
29 from io import BytesIO
30
31 from eqdnsmessage import AssertEqualDNSMessageMixin
32 from proxyprotocol import ProxyProtocol
33
# The module relies on Python 3-only syntax (f-strings) and APIs
# (threading.get_native_id, ConnectionResetError), so the former Python 2
# fallbacks (Queue.Queue, range = xrange) could never be reached.
from queue import Queue
44
45
class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    which reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.

    Subclasses customise the setup by overriding the class attributes
    below, most importantly _config_template and _config_params.
    """
    # Port passed to dnsdist's -l option and used by the client helpers.
    _dnsDistPort = 5340
    # Address passed to dnsdist's -l option.
    _dnsDistListeningAddr = "127.0.0.1"
    # Port the UDP/TCP responder threads bind to on 127.0.0.1.
    _testServerPort = 5350
    # Responses queued by the tests; the responders read them (their `fromQueue`).
    _toResponderQueue = Queue()
    # Queries received by the responders (their `toQueue`), read back by the tests.
    _fromResponderQueue = Queue()
    # Timeout, in seconds, of the responders' blocking queue put()/get().
    _queueTimeout = 1
    # subprocess.Popen handle of the running dnsdist process.
    _dnsdist = None
    # Responder thread name -> number of non-health-check queries it handled.
    _responsesCounter = {}
    # dnsdist configuration template, %-formatted with the values of _config_params.
    _config_template = """
    """
    # Names of class attributes whose values fill _config_template, in order.
    _config_params = ['_testServerPort']
    # Netmasks passed to dnsdist as --acl arguments.
    _acl = ['127.0.0.1/32']
    # Port sendConsoleCommand() connects to on 127.0.0.1.
    _consolePort = 5199
    # Console encryption key; None means console commands are sent unencrypted.
    _consoleKey = None
    # Queries whose name ends with this are answered as health checks by the responders.
    _healthCheckName = 'a.root-servers.net.'
    # Incremented by the responders for every health-check query they answer.
    _healthCheckCounter = 0
    # When no response was queued, answer with SERVFAIL instead of returning nothing.
    _answerUnexpected = True
    # Expected --check-config output (bytes); None means "Configuration '<file>' OK!\n".
    _checkConfigExpectedOutput = None
    # Add -v to the dnsdist command line and skip the --check-config output comparison.
    _verboseMode = False
    # Do not pass -l to dnsdist (presumably because the configuration declares its own listeners).
    _skipListeningOnCL = False
    # When both are set, startDNSDist() waits for this address/port instead of the -l one.
    _alternateListeningAddr = None
    _alternateListeningPort = None
    # Native thread id -> keep-running flag of background responder threads;
    # tearDownClass() sets the flags to False to make the threads exit.
    _backgroundThreads = {}
    # threading.Thread objects created by startResponders().
    _UDPResponder = None
    _TCPResponder = None
    # Extra seconds to sleep once dnsdist has been started.
    _extraStartupSleep = 0
81
82 @classmethod
83 def waitForTCPSocket(cls, ipaddress, port):
84 for try_number in range(0, 20):
85 try:
86 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
87 sock.settimeout(1.0)
88 sock.connect((ipaddress, port))
89 sock.close()
90 return
91 except Exception as err:
92 if err.errno != errno.ECONNREFUSED:
93 print(f'Error occurred: {try_number} {err}', file=sys.stderr)
94 time.sleep(0.1)
95 # We assume the dnsdist instance does not listen. That's fine.
96
97 @classmethod
98 def startResponders(cls):
99 print("Launching responders..")
100
101 cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
102 cls._UDPResponder.setDaemon(True)
103 cls._UDPResponder.start()
104 cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
105 cls._TCPResponder.setDaemon(True)
106 cls._TCPResponder.start()
107 cls.waitForTCPSocket("127.0.0.1", cls._testServerPort);
108
109 @classmethod
110 def startDNSDist(cls):
111 print("Launching dnsdist..")
112 confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
113 params = tuple([getattr(cls, param) for param in cls._config_params])
114 print(params)
115 with open(confFile, 'w') as conf:
116 conf.write("-- Autogenerated by dnsdisttests.py\n")
117 conf.write(cls._config_template % params)
118 conf.write("setSecurityPollSuffix('')")
119
120 if cls._skipListeningOnCL:
121 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile ]
122 else:
123 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
124 '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
125
126 if cls._verboseMode:
127 dnsdistcmd.append('-v')
128
129 for acl in cls._acl:
130 dnsdistcmd.extend(['--acl', acl])
131 print(' '.join(dnsdistcmd))
132
133 # validate config with --check-config, which sets client=true, possibly exposing bugs.
134 testcmd = dnsdistcmd + ['--check-config']
135 try:
136 output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
137 except subprocess.CalledProcessError as exc:
138 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
139 if cls._checkConfigExpectedOutput is not None:
140 expectedOutput = cls._checkConfigExpectedOutput
141 else:
142 expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
143 if not cls._verboseMode and output != expectedOutput:
144 raise AssertionError('dnsdist --check-config failed: %s' % output)
145
146 logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
147 with open(logFile, 'w') as fdLog:
148 cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)
149
150 if cls._alternateListeningAddr and cls._alternateListeningPort:
151 cls.waitForTCPSocket(cls._alternateListeningAddr, cls._alternateListeningPort)
152 else:
153 cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort)
154
155 if cls._dnsdist.poll() is not None:
156 print(f"\n*** startDNSDist log for {logFile} ***")
157 with open(logFile, 'r') as fdLog:
158 print(fdLog.read())
159 print(f"*** End startDNSDist log for {logFile} ***")
160 raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode))
161 time.sleep(cls._extraStartupSleep)
162
163 @classmethod
164 def setUpSockets(cls):
165 print("Setting up UDP socket..")
166 cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
167 cls._sock.settimeout(2.0)
168 cls._sock.connect(("127.0.0.1", cls._dnsDistPort))
169
170 @classmethod
171 def killProcess(cls, p):
172 # Don't try to kill it if it's already dead
173 if p.poll() is not None:
174 return
175 try:
176 p.terminate()
177 for count in range(20):
178 x = p.poll()
179 if x is not None:
180 break
181 time.sleep(0.1)
182 if x is None:
183 print("kill...", p, file=sys.stderr)
184 p.kill()
185 p.wait()
186 except OSError as e:
187 # There is a race-condition with the poll() and
188 # kill() statements, when the process is dead on the
189 # kill(), this is fine
190 if e.errno != errno.ESRCH:
191 raise
192
193 @classmethod
194 def setUpClass(cls):
195
196 cls.startResponders()
197 cls.startDNSDist()
198 cls.setUpSockets()
199
200 print("Launching tests..")
201
202 @classmethod
203 def tearDownClass(cls):
204 cls._sock.close()
205 # tell the background threads to stop, if any
206 for backgroundThread in cls._backgroundThreads:
207 cls._backgroundThreads[backgroundThread] = False
208 cls.killProcess(cls._dnsdist)
209
210 @classmethod
211 def _ResponderIncrementCounter(cls):
212 if threading.currentThread().name in cls._responsesCounter:
213 cls._responsesCounter[threading.currentThread().name] += 1
214 else:
215 cls._responsesCounter[threading.currentThread().name] = 1
216
217 @classmethod
218 def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
219 response = None
220 if len(request.question) != 1:
221 print("Skipping query with question count %d" % (len(request.question)))
222 return None
223 healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
224 if healthCheck:
225 cls._healthCheckCounter += 1
226 response = dns.message.make_response(request)
227 else:
228 cls._ResponderIncrementCounter()
229 if not fromQueue.empty():
230 toQueue.put(request, True, cls._queueTimeout)
231 response = fromQueue.get(True, cls._queueTimeout)
232 if response:
233 response = copy.copy(response)
234 response.id = request.id
235
236 if synthesize is not None:
237 response = dns.message.make_response(request)
238 response.set_rcode(synthesize)
239
240 if not response:
241 if cls._answerUnexpected:
242 response = dns.message.make_response(request)
243 response.set_rcode(dns.rcode.SERVFAIL)
244
245 return response
246
247 @classmethod
248 def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
249 cls._backgroundThreads[threading.get_native_id()] = True
250 # trailingDataResponse=True means "ignore trailing data".
251 # Other values are either False (meaning "raise an exception")
252 # or are interpreted as a response RCODE for queries with trailing data.
253 # callback is invoked for every -even healthcheck ones- query and should return a raw response
254 ignoreTrailing = trailingDataResponse is True
255
256 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
257 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
258 sock.bind(("127.0.0.1", port))
259 sock.settimeout(0.5)
260 while True:
261 try:
262 data, addr = sock.recvfrom(4096)
263 except socket.timeout:
264 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
265 del cls._backgroundThreads[threading.get_native_id()]
266 break
267 else:
268 continue
269
270 forceRcode = None
271 try:
272 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
273 except dns.message.TrailingJunk as e:
274 print('trailing data exception in UDPResponder')
275 if trailingDataResponse is False or forceRcode is True:
276 raise
277 print("UDP query with trailing data, synthesizing response")
278 request = dns.message.from_wire(data, ignore_trailing=True)
279 forceRcode = trailingDataResponse
280
281 wire = None
282 if callback:
283 wire = callback(request)
284 else:
285 if request.edns > 1:
286 forceRcode = dns.rcode.BADVERS
287 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
288 if response:
289 wire = response.to_wire()
290
291 if not wire:
292 continue
293
294 sock.sendto(wire, addr)
295
296 sock.close()
297
298 @classmethod
299 def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
300 ignoreTrailing = trailingDataResponse is True
301 data = conn.recv(2)
302 if not data:
303 conn.close()
304 return
305
306 (datalen,) = struct.unpack("!H", data)
307 data = conn.recv(datalen)
308 forceRcode = None
309 try:
310 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
311 except dns.message.TrailingJunk as e:
312 if trailingDataResponse is False or forceRcode is True:
313 raise
314 print("TCP query with trailing data, synthesizing response")
315 request = dns.message.from_wire(data, ignore_trailing=True)
316 forceRcode = trailingDataResponse
317
318 if callback:
319 wire = callback(request)
320 else:
321 if request.edns > 1:
322 forceRcode = dns.rcode.BADVERS
323 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
324 if response:
325 wire = response.to_wire(max_size=65535)
326
327 if not wire:
328 conn.close()
329 return
330
331 conn.send(struct.pack("!H", len(wire)))
332 conn.send(wire)
333
334 while multipleResponses:
335 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
336 # otherwise we might read responses intended for the next connection
337 if fromQueue.empty():
338 break
339
340 response = fromQueue.get(False)
341 if not response:
342 break
343
344 response = copy.copy(response)
345 response.id = request.id
346 wire = response.to_wire(max_size=65535)
347 try:
348 conn.send(struct.pack("!H", len(wire)))
349 conn.send(wire)
350 except socket.error as e:
351 # some of the tests are going to close
352 # the connection on us, just deal with it
353 break
354
355 conn.close()
356
357 @classmethod
358 def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1'):
359 cls._backgroundThreads[threading.get_native_id()] = True
360 # trailingDataResponse=True means "ignore trailing data".
361 # Other values are either False (meaning "raise an exception")
362 # or are interpreted as a response RCODE for queries with trailing data.
363 # callback is invoked for every -even healthcheck ones- query and should return a raw response
364
365 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
366 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
367 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
368 try:
369 sock.bind((listeningAddr, port))
370 except socket.error as e:
371 print("Error binding in the TCP responder: %s" % str(e))
372 sys.exit(1)
373
374 sock.listen(100)
375 sock.settimeout(0.5)
376 if tlsContext:
377 sock = tlsContext.wrap_socket(sock, server_side=True)
378
379 while True:
380 try:
381 (conn, _) = sock.accept()
382 except ssl.SSLError:
383 continue
384 except ConnectionResetError:
385 continue
386 except socket.timeout:
387 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
388 del cls._backgroundThreads[threading.get_native_id()]
389 break
390 else:
391 continue
392
393 conn.settimeout(5.0)
394 if multipleConnections:
395 thread = threading.Thread(name='TCP Connection Handler',
396 target=cls.handleTCPConnection,
397 args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback])
398 thread.setDaemon(True)
399 thread.start()
400 else:
401 cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback)
402
403 sock.close()
404
405 @classmethod
406 def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
407 ignoreTrailing = trailingDataResponse is True
408 try:
409 h2conn = h2.connection.H2Connection(config=config)
410 h2conn.initiate_connection()
411 conn.sendall(h2conn.data_to_send())
412 except ssl.SSLEOFError as e:
413 print("Unexpected EOF: %s" % (e))
414 return
415
416 dnsData = {}
417
418 if useProxyProtocol:
419 # try to read the entire Proxy Protocol header
420 proxy = ProxyProtocol()
421 header = conn.recv(proxy.HEADER_SIZE)
422 if not header:
423 print('unable to get header')
424 conn.close()
425 return
426
427 if not proxy.parseHeader(header):
428 print('unable to parse header')
429 print(header)
430 conn.close()
431 return
432
433 proxyContent = conn.recv(proxy.contentLen)
434 if not proxyContent:
435 print('unable to get content')
436 conn.close()
437 return
438
439 payload = header + proxyContent
440 toQueue.put(payload, True, cls._queueTimeout)
441
442 # be careful, HTTP/2 headers and data might be in different recv() results
443 requestHeaders = None
444 while True:
445 data = conn.recv(65535)
446 if not data:
447 break
448
449 events = h2conn.receive_data(data)
450 for event in events:
451 if isinstance(event, h2.events.RequestReceived):
452 requestHeaders = event.headers
453 if isinstance(event, h2.events.DataReceived):
454 h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
455 if not event.stream_id in dnsData:
456 dnsData[event.stream_id] = b''
457 dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
458 if event.stream_ended:
459 forceRcode = None
460 status = 200
461 try:
462 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
463 except dns.message.TrailingJunk as e:
464 if trailingDataResponse is False or forceRcode is True:
465 raise
466 print("DOH query with trailing data, synthesizing response")
467 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
468 forceRcode = trailingDataResponse
469
470 if callback:
471 status, wire = callback(request, requestHeaders, fromQueue, toQueue)
472 else:
473 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
474 if response:
475 wire = response.to_wire(max_size=65535)
476
477 if not wire:
478 conn.close()
479 conn = None
480 break
481
482 headers = [
483 (':status', str(status)),
484 ('content-length', str(len(wire))),
485 ('content-type', 'application/dns-message'),
486 ]
487 h2conn.send_headers(stream_id=event.stream_id, headers=headers)
488 h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
489
490 data_to_send = h2conn.data_to_send()
491 if data_to_send:
492 conn.sendall(data_to_send)
493
494 if conn is None:
495 break
496
497 if conn is not None:
498 conn.close()
499
500 @classmethod
501 def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
502 cls._backgroundThreads[threading.get_native_id()] = True
503 # trailingDataResponse=True means "ignore trailing data".
504 # Other values are either False (meaning "raise an exception")
505 # or are interpreted as a response RCODE for queries with trailing data.
506 # callback is invoked for every -even healthcheck ones- query and should return a raw response
507
508 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
509 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
510 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
511 try:
512 sock.bind(("127.0.0.1", port))
513 except socket.error as e:
514 print("Error binding in the TCP responder: %s" % str(e))
515 sys.exit(1)
516
517 sock.listen(100)
518 sock.settimeout(0.5)
519 if tlsContext:
520 sock = tlsContext.wrap_socket(sock, server_side=True)
521
522 config = h2.config.H2Configuration(client_side=False)
523
524 while True:
525 try:
526 (conn, _) = sock.accept()
527 except ssl.SSLError:
528 continue
529 except ConnectionResetError:
530 continue
531 except socket.timeout:
532 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
533 del cls._backgroundThreads[threading.get_native_id()]
534 break
535 else:
536 continue
537
538 conn.settimeout(5.0)
539 thread = threading.Thread(name='DoH Connection Handler',
540 target=cls.handleDoHConnection,
541 args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
542 thread.setDaemon(True)
543 thread.start()
544
545 sock.close()
546
547 @classmethod
548 def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
549 if useQueue and response is not None:
550 cls._toResponderQueue.put(response, True, timeout)
551
552 if timeout:
553 cls._sock.settimeout(timeout)
554
555 try:
556 if not rawQuery:
557 query = query.to_wire()
558 cls._sock.send(query)
559 data = cls._sock.recv(4096)
560 except socket.timeout:
561 data = None
562 finally:
563 if timeout:
564 cls._sock.settimeout(None)
565
566 receivedQuery = None
567 message = None
568 if useQueue and not cls._fromResponderQueue.empty():
569 receivedQuery = cls._fromResponderQueue.get(True, timeout)
570 if data:
571 message = dns.message.from_wire(data)
572 return (receivedQuery, message)
573
574 @classmethod
575 def openTCPConnection(cls, timeout=None, port=None):
576 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
577 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
578 if timeout:
579 sock.settimeout(timeout)
580
581 if not port:
582 port = cls._dnsDistPort
583
584 sock.connect(("127.0.0.1", port))
585 return sock
586
587 @classmethod
588 def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
589 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
590 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
591 if timeout:
592 sock.settimeout(timeout)
593
594 # 2.7.9+
595 if hasattr(ssl, 'create_default_context'):
596 sslctx = ssl.create_default_context(cafile=caCert)
597 sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
598 else:
599 sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
600
601 sslsock.connect(("127.0.0.1", port))
602 return sslsock
603
604 @classmethod
605 def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
606 if not rawQuery:
607 wire = query.to_wire()
608 else:
609 wire = query
610
611 if response:
612 cls._toResponderQueue.put(response, True, timeout)
613
614 sock.send(struct.pack("!H", len(wire)))
615 sock.send(wire)
616
617 @classmethod
618 def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
619 print("reading data")
620 message = None
621 data = sock.recv(2)
622 if data:
623 (datalen,) = struct.unpack("!H", data)
624 print(datalen)
625 data = sock.recv(datalen)
626 if data:
627 print(data)
628 message = dns.message.from_wire(data)
629
630 print(useQueue)
631 if useQueue and not cls._fromResponderQueue.empty():
632 receivedQuery = cls._fromResponderQueue.get(True, timeout)
633 print("Got from queue")
634 print(receivedQuery)
635 return (receivedQuery, message)
636 else:
637 print("queue empty")
638 return message
639
640 @classmethod
641 def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True):
642 conn = cls.openTLSConnection(port, serverName, caFile)
643 cls.sendTCPQueryOverConnection(conn, query, response=response)
644 if useQueue:
645 return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
646 return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
647
648 @classmethod
649 def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
650 message = None
651 if useQueue:
652 cls._toResponderQueue.put(response, True, timeout)
653
654 sock = cls.openTCPConnection(timeout)
655
656 try:
657 cls.sendTCPQueryOverConnection(sock, query, rawQuery)
658 message = cls.recvTCPResponseOverConnection(sock)
659 except socket.timeout as e:
660 print("Timeout while sending or receiving TCP data: %s" % (str(e)))
661 except socket.error as e:
662 print("Network error: %s" % (str(e)))
663 finally:
664 sock.close()
665
666 receivedQuery = None
667 print(useQueue)
668 if useQueue and not cls._fromResponderQueue.empty():
669 print("Got from queue")
670 print(receivedQuery)
671 receivedQuery = cls._fromResponderQueue.get(True, timeout)
672 else:
673 print("queue is empty")
674
675 return (receivedQuery, message)
676
677 @classmethod
678 def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
679 if useQueue:
680 for response in responses:
681 cls._toResponderQueue.put(response, True, timeout)
682 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
683 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
684 if timeout:
685 sock.settimeout(timeout)
686
687 sock.connect(("127.0.0.1", cls._dnsDistPort))
688 messages = []
689
690 try:
691 if not rawQuery:
692 wire = query.to_wire()
693 else:
694 wire = query
695
696 sock.send(struct.pack("!H", len(wire)))
697 sock.send(wire)
698 while True:
699 data = sock.recv(2)
700 if not data:
701 break
702 (datalen,) = struct.unpack("!H", data)
703 data = sock.recv(datalen)
704 messages.append(dns.message.from_wire(data))
705
706 except socket.timeout as e:
707 print("Timeout while receiving multiple TCP responses: %s" % (str(e)))
708 except socket.error as e:
709 print("Network error: %s" % (str(e)))
710 finally:
711 sock.close()
712
713 receivedQuery = None
714 if useQueue and not cls._fromResponderQueue.empty():
715 receivedQuery = cls._fromResponderQueue.get(True, timeout)
716 return (receivedQuery, messages)
717
718 def setUp(self):
719 # This function is called before every test
720
721 # Clear the responses counters
722 self._responsesCounter.clear()
723
724 self._healthCheckCounter = 0
725
726 # Make sure the queues are empty, in case
727 # a previous test failed
728 self.clearResponderQueues()
729
730 super(DNSDistTest, self).setUp()
731
732 @classmethod
733 def clearToResponderQueue(cls):
734 while not cls._toResponderQueue.empty():
735 cls._toResponderQueue.get(False)
736
737 @classmethod
738 def clearFromResponderQueue(cls):
739 while not cls._fromResponderQueue.empty():
740 cls._fromResponderQueue.get(False)
741
742 @classmethod
743 def clearResponderQueues(cls):
744 cls.clearToResponderQueue()
745 cls.clearFromResponderQueue()
746
747 @staticmethod
748 def generateConsoleKey():
749 return libnacl.utils.salsa_key()
750
751 @classmethod
752 def _encryptConsole(cls, command, nonce):
753 command = command.encode('UTF-8')
754 if cls._consoleKey is None:
755 return command
756 return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)
757
758 @classmethod
759 def _decryptConsole(cls, command, nonce):
760 if cls._consoleKey is None:
761 result = command
762 else:
763 result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
764 return result.decode('UTF-8')
765
766 @classmethod
767 def sendConsoleCommand(cls, command, timeout=5.0):
768 ourNonce = libnacl.utils.rand_nonce()
769 theirNonce = None
770 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
771 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
772 if timeout:
773 sock.settimeout(timeout)
774
775 sock.connect(("127.0.0.1", cls._consolePort))
776 sock.send(ourNonce)
777 theirNonce = sock.recv(len(ourNonce))
778 if len(theirNonce) != len(ourNonce):
779 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
780 if len(theirNonce) == 0:
781 raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
782 return None
783
784 halfNonceSize = int(len(ourNonce) / 2)
785 readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
786 writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
787 msg = cls._encryptConsole(command, writingNonce)
788 sock.send(struct.pack("!I", len(msg)))
789 sock.send(msg)
790 data = sock.recv(4)
791 if not data:
792 raise socket.error("Got EOF while reading the response size")
793
794 (responseLen,) = struct.unpack("!I", data)
795 data = sock.recv(responseLen)
796 response = cls._decryptConsole(data, readingNonce)
797 sock.close()
798 return response
799
800 def compareOptions(self, a, b):
801 self.assertEqual(len(a), len(b))
802 for idx in range(len(a)):
803 self.assertEqual(a[idx], b[idx])
804
    def checkMessageNoEDNS(self, expected, received):
        """
        Assert that `received` equals `expected` and carries no EDNS:
        edns == -1 and no options. Assertions run in order and stop at the
        first failure.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)
809
    def checkMessageEDNSWithoutOptions(self, expected, received):
        """
        Assert that `received` equals `expected`, uses EDNS version 0 and has
        the same EDNS payload size as `expected`.

        NOTE(review): despite the name, this does not check that `received`
        carries no EDNS options.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)
814
815 def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
816 self.assertEqual(expected, received)
817 self.assertEqual(received.edns, 0)
818 self.assertEqual(expected.payload, received.payload)
819 self.assertEqual(len(received.options), withCookies)
820 if withCookies:
821 for option in received.options:
822 self.assertEqual(option.otype, 10)
823 else:
824 for option in received.options:
825 self.assertNotEqual(option.otype, 10)
826
827 def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
828 self.assertEqual(expected, received)
829 self.assertEqual(received.edns, 0)
830 self.assertEqual(expected.payload, received.payload)
831 self.assertEqual(len(received.options), 1 + additionalOptions)
832 hasECS = False
833 for option in received.options:
834 if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
835 hasECS = True
836 else:
837 self.assertNotEqual(additionalOptions, 0)
838
839 self.compareOptions(expected.options, received.options)
840 self.assertTrue(hasECS)
841
    def checkMessageEDNS(self, expected, received):
        """
        Assert that `received` equals `expected`, uses EDNS version 0 with the
        same payload size, and carries the same EDNS options in the same order.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)
        self.assertEqual(len(expected.options), len(received.options))
        self.compareOptions(expected.options, received.options)
848
849 def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
850 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
851
852 def checkQueryEDNS(self, expected, received):
853 self.checkMessageEDNS(expected, received)
854
855 def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
856 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
857
858 def checkQueryEDNSWithoutECS(self, expected, received):
859 self.checkMessageEDNSWithoutECS(expected, received)
860
861 def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
862 self.checkMessageEDNSWithoutECS(expected, received, withCookies)
863
864 def checkQueryNoEDNS(self, expected, received):
865 self.checkMessageNoEDNS(expected, received)
866
867 def checkResponseNoEDNS(self, expected, received):
868 self.checkMessageNoEDNS(expected, received)
869
870 def generateNewCertificateAndKey(self):
871 # generate and sign a new cert
872 cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', 'server.key', '-out', 'server.csr', '-config', 'configServer.conf']
873 output = None
874 try:
875 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
876 output = process.communicate(input='')
877 except subprocess.CalledProcessError as exc:
878 raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output))
879 cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', 'server.csr', '-out', 'server.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
880 output = None
881 try:
882 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
883 output = process.communicate(input='')
884 except subprocess.CalledProcessError as exc:
885 raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output))
886
887 with open('server.chain', 'w') as outFile:
888 for inFileName in ['server.pem', 'ca.pem']:
889 with open(inFileName) as inFile:
890 outFile.write(inFile.read())
891
892 cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', 'server.pem', '-CAfile', 'ca.pem', '-inkey', 'server.key', '-out', 'server.p12']
893 output = None
894 try:
895 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
896 output = process.communicate(input='')
897 except subprocess.CalledProcessError as exc:
898 raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc.returncode, exc.output))
899
900 def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None):
901 proxy = ProxyProtocol()
902 self.assertTrue(proxy.parseHeader(receivedProxyPayload))
903 self.assertEqual(proxy.version, 0x02)
904 self.assertEqual(proxy.command, 0x01)
905 if v6:
906 self.assertEqual(proxy.family, 0x02)
907 else:
908 self.assertEqual(proxy.family, 0x01)
909 if not isTCP:
910 self.assertEqual(proxy.protocol, 0x02)
911 else:
912 self.assertEqual(proxy.protocol, 0x01)
913 self.assertGreater(proxy.contentLen, 0)
914
915 self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
916 self.assertEqual(proxy.source, source)
917 self.assertEqual(proxy.destination, destination)
918 if sourcePort:
919 self.assertEqual(proxy.sourcePort, sourcePort)
920 if destinationPort:
921 self.assertEqual(proxy.destinationPort, destinationPort)
922 else:
923 self.assertEqual(proxy.destinationPort, self._dnsDistPort)
924
925 self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
926 proxy.values.sort()
927 values.sort()
928 self.assertEqual(proxy.values, values)
929
930 @classmethod
931 def getDOHGetURL(cls, baseurl, query, rawQuery=False):
932 if rawQuery:
933 wire = query
934 else:
935 wire = query.to_wire()
936 param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
937 return baseurl + "?dns=" + param
938
939 @classmethod
940 def openDOHConnection(cls, port, caFile, timeout=2.0):
941 conn = pycurl.Curl()
942 conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
943
944 conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
945 "Accept: application/dns-message"])
946 return conn
947
948 @classmethod
949 def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
950 url = cls.getDOHGetURL(baseurl, query, rawQuery)
951 conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
952 response_headers = BytesIO()
953 #conn.setopt(pycurl.VERBOSE, True)
954 conn.setopt(pycurl.URL, url)
955 conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
956 if useHTTPS:
957 conn.setopt(pycurl.SSL_VERIFYPEER, 1)
958 conn.setopt(pycurl.SSL_VERIFYHOST, 2)
959 if caFile:
960 conn.setopt(pycurl.CAINFO, caFile)
961
962 conn.setopt(pycurl.HTTPHEADER, customHeaders)
963 conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
964
965 if response:
966 if toQueue:
967 toQueue.put(response, True, timeout)
968 else:
969 cls._toResponderQueue.put(response, True, timeout)
970
971 receivedQuery = None
972 message = None
973 cls._response_headers = ''
974 data = conn.perform_rb()
975 cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
976 if cls._rcode == 200 and not rawResponse:
977 message = dns.message.from_wire(data)
978 elif rawResponse:
979 message = data
980
981 if useQueue:
982 if fromQueue:
983 if not fromQueue.empty():
984 receivedQuery = fromQueue.get(True, timeout)
985 else:
986 if not cls._fromResponderQueue.empty():
987 receivedQuery = cls._fromResponderQueue.get(True, timeout)
988
989 cls._response_headers = response_headers.getvalue()
990 return (receivedQuery, message)
991
992 @classmethod
993 def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
994 url = baseurl
995 conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
996 response_headers = BytesIO()
997 #conn.setopt(pycurl.VERBOSE, True)
998 conn.setopt(pycurl.URL, url)
999 conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
1000 if useHTTPS:
1001 conn.setopt(pycurl.SSL_VERIFYPEER, 1)
1002 conn.setopt(pycurl.SSL_VERIFYHOST, 2)
1003 if caFile:
1004 conn.setopt(pycurl.CAINFO, caFile)
1005
1006 conn.setopt(pycurl.HTTPHEADER, customHeaders)
1007 conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
1008 conn.setopt(pycurl.POST, True)
1009 data = query
1010 if not rawQuery:
1011 data = data.to_wire()
1012
1013 conn.setopt(pycurl.POSTFIELDS, data)
1014
1015 if response:
1016 cls._toResponderQueue.put(response, True, timeout)
1017
1018 receivedQuery = None
1019 message = None
1020 cls._response_headers = ''
1021 data = conn.perform_rb()
1022 cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
1023 if cls._rcode == 200 and not rawResponse:
1024 message = dns.message.from_wire(data)
1025 elif rawResponse:
1026 message = data
1027
1028 if useQueue and not cls._fromResponderQueue.empty():
1029 receivedQuery = cls._fromResponderQueue.get(True, timeout)
1030
1031 cls._response_headers = response_headers.getvalue()
1032 return (receivedQuery, message)
1033
1034 def sendDOHQueryWrapper(self, query, response, useQueue=True):
1035 return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
1036
1037 def sendDOTQueryWrapper(self, query, response, useQueue=True):
1038 return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue)