]> git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/dnsdisttests.py
Smarter startup delay: wait for listen port to come alive
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
CommitLineData
ca404e94
RG
1#!/usr/bin/env python2
2
95f0b802 3import copy
ffabdc3e 4import errno
ca404e94
RG
5import os
6import socket
a227f47d 7import ssl
ca404e94
RG
8import struct
9import subprocess
10import sys
11import threading
12import time
13import unittest
9d71a0cf 14
5df86a8a 15import clientsubnetoption
9d71a0cf 16
b1bec9f0
RG
17import dns
18import dns.message
9d71a0cf 19
1ea747c0
RG
20import libnacl
21import libnacl.utils
ca404e94 22
9d71a0cf
RG
23import h2.connection
24import h2.events
25import h2.config
26
6bd430bf 27from eqdnsmessage import AssertEqualDNSMessageMixin
0e6892c6 28from proxyprotocol import ProxyProtocol
6bd430bf 29
# Python 3 only: this module already relies on f-strings and
# threading.get_native_id(), so the Python 2 fallbacks (the `Queue` module and
# `range = xrange`) could never be reached and are not needed.
from queue import Queue
b4f23783
CH
40
41
class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Base class for the dnsdist regression tests.

    setUpClass() starts a UDP and a TCP responder thread, writes a dnsdist
    configuration rendered from _config_template, launches dnsdist with it and
    opens a UDP client socket. Queries sent to dnsdist are relayed to the
    responder threads. The responders answer with the responses the tests
    queued on _toResponderQueue and push the queries they received from
    dnsdist onto _fromResponderQueue. Tests can therefore check both what
    dnsdist forwarded and what it sent back.
    """

    # dnsdist's listening address/port (given with -l unless _skipListeningOnCL
    # is set) and the port the responder threads bind to on 127.0.0.1.
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    _testServerPort = 5350

    # Tests -> responders: responses to send.
    # Responders -> tests: queries received.
    _toResponderQueue = Queue()
    _fromResponderQueue = Queue()
    # Timeout (seconds) the responders use when exchanging items on the queues.
    _queueTimeout = 1

    # NOTE(review): not read by the code in this file, because startup now waits
    # on the listening port (waitForTCPSocket). Kept for subclasses; confirm
    # before removing.
    _dnsdistStartupDelay = 2.0
    # subprocess.Popen handle of the running dnsdist instance.
    _dnsdist = None
    # Number of non-health-check queries answered, keyed by responder thread name.
    _responsesCounter = {}

    # dnsdist configuration: _config_template is %-formatted with the values of
    # the class attributes named in _config_params, in that order.
    _config_template = """
    """
    _config_params = ['_testServerPort']
    # Each entry is passed to dnsdist as an --acl argument.
    _acl = ['127.0.0.1/32']

    # Console endpoint used by sendConsoleCommand(); a None key means commands
    # are exchanged without encryption.
    _consolePort = 5199
    _consoleKey = None

    # Queries whose name ends with _healthCheckName are treated as health checks.
    # They are answered directly and counted in _healthCheckCounter.
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    # Answer SERVFAIL when no response is available for a query; when False,
    # nothing is sent back.
    _answerUnexpected = True
    # Expected `--check-config` output (bytes). None means the default
    # "Configuration '<file>' OK!\n".
    _checkConfigExpectedOutput = None
    # Adds -v to the command line and skips the --check-config output comparison.
    _verboseMode = False
    # Do not pass -l on the command line (presumably because the configuration
    # declares its own listeners).
    _skipListeningOnCL = False

    # Native thread id -> keep-running flag of the background responder threads.
    # tearDownClass() sets the flags to False to stop them.
    _backgroundThreads = {}
    _UDPResponder = None
    _TCPResponder = None
ca404e94 75
ffabdc3e
OM
@classmethod
def waitForTCPSocket(cls, ipaddress, port):
    """
    Wait briefly until something accepts TCP connections on ipaddress:port.

    Makes up to 20 connection attempts. Each attempt has a 1 second connect
    timeout and is followed by a 0.1 second pause when it fails.
    "Connection refused" is the expected failure while dnsdist is still
    starting, so it is not reported; other socket errors are printed to stderr.
    Giving up is not an error: the instance may legitimately not listen there,
    and the caller checks the process status afterwards.

    Returns True as soon as a connection succeeds, False after the last failed
    attempt.
    """
    for try_number in range(0, 20):
        # The with-block closes the probe socket whether or not connect() succeeds.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1.0)
            try:
                sock.connect((ipaddress, port))
                return True
            except OSError as err:
                # Only OSError carries errno. A connect timeout has errno None
                # and is reported like any other unexpected error.
                if err.errno != errno.ECONNREFUSED:
                    print(f'Error occurred: {try_number} {err}', file=sys.stderr)
        time.sleep(0.1)
    # We assume the dnsdist instance does not listen. That's fine.
    return False
90
ca404e94
RG
@classmethod
def startResponders(cls):
    """
    Start the default UDP and TCP responder threads on cls._testServerPort.

    Each responder takes the responses to send from cls._toResponderQueue and
    pushes the queries it received onto cls._fromResponderQueue. Both threads
    are daemons, so they do not prevent the interpreter from exiting.
    """
    print("Launching responders..")

    # daemon=True replaces Thread.setDaemon(), which is deprecated since Python 3.10.
    cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder,
                                         args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue],
                                         daemon=True)
    cls._UDPResponder.start()
    cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder,
                                         args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue],
                                         daemon=True)
    cls._TCPResponder.start()
101
@classmethod
def startDNSDist(cls):
    """
    Write this test class's configuration, validate it and start dnsdist.

    The configuration goes to configs/dnsdist_<ClassName>.conf and is rendered
    from cls._config_template with the class attributes named in
    cls._config_params.

    dnsdist is first run with --check-config. Unless cls._verboseMode is set,
    its output must match cls._checkConfigExpectedOutput, or the default
    "Configuration '<file>' OK!" line. dnsdist is then launched in the
    background with stdout/stderr written to configs/dnsdist_<ClassName>.log.

    After waiting for the listening port, an AssertionError is raised (and the
    log is dumped to stdout) if the process has already exited.
    """
    print("Launching dnsdist..")
    confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
    params = tuple(getattr(cls, attrName) for attrName in cls._config_params)
    print(params)
    with open(confFile, 'w') as conf:
        conf.write("-- Autogenerated by dnsdisttests.py\n")
        conf.write(cls._config_template % params)
        conf.write("setSecurityPollSuffix('')")

    dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile]
    if not cls._skipListeningOnCL:
        dnsdistcmd += ['-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
    if cls._verboseMode:
        dnsdistcmd.append('-v')
    for acl in cls._acl:
        dnsdistcmd += ['--acl', acl]
    print(' '.join(dnsdistcmd))

    # Validate the configuration with --check-config first. That mode sets
    # client=true, which can expose bugs of its own.
    try:
        output = subprocess.check_output(dnsdistcmd + ['--check-config'], stderr=subprocess.STDOUT, close_fds=True)
    except subprocess.CalledProcessError as exc:
        raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))

    expectedOutput = cls._checkConfigExpectedOutput
    if expectedOutput is None:
        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
    if not cls._verboseMode and output != expectedOutput:
        raise AssertionError('dnsdist --check-config failed: %s' % output)

    logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
    with open(logFile, 'w') as fdLog:
        cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)

    cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort)

    if cls._dnsdist.poll() is None:
        return

    # dnsdist already exited: show its log before failing.
    print(f"\n*** startDNSDist log for {logFile} ***")
    with open(logFile, 'r') as fdLog:
        print(fdLog.read())
    print(f"*** End startDNSDist log for {logFile} ***")
    raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode))
ca404e94
RG
151
@classmethod
def setUpSockets(cls):
    """
    Create cls._sock, the UDP client socket used by sendUDPQuery().

    The socket is connected to dnsdist's port on 127.0.0.1 and has a default
    timeout of 2 seconds.
    """
    print("Setting up UDP socket..")
    cls._sock = udpSocket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    udpSocket.settimeout(2.0)
    udpSocket.connect(("127.0.0.1", cls._dnsDistPort))
158
e4284d05
OM
@classmethod
def killProcess(cls, p):
    """
    Stop the subprocess.Popen p.

    Sends terminate() first. If p is still running after ten 0.1 second polls,
    it escalates to kill() and waits for it. Returns immediately for a process
    that has already exited.
    """
    if p.poll() is not None:
        return

    try:
        p.terminate()
        exited = False
        for _ in range(10):
            if p.poll() is not None:
                exited = True
                break
            time.sleep(0.1)

        if not exited:
            print("kill...", p, file=sys.stderr)
            p.kill()
            p.wait()
    except OSError as e:
        # The process may die between poll() and kill(); ESRCH
        # ("no such process") is therefore expected and ignored.
        if e.errno != errno.ESRCH:
            raise
181
ca404e94
RG
@classmethod
def setUpClass(cls):
    """
    unittest class fixture: start the responders, then dnsdist, then the UDP
    client socket, in that order.
    """
    for startStep in (cls.startResponders, cls.startDNSDist, cls.setUpSockets):
        startStep()

    print("Launching tests..")
190
@classmethod
def tearDownClass(cls):
    """
    unittest class teardown: close the UDP client socket, ask every background
    responder thread to stop, and stop dnsdist.
    """
    cls._sock.close()
    # Tell the background threads to stop, if any. Each one notices at its next
    # 1 second socket timeout and then removes its own entry from the dict.
    # Iterate over a snapshot of the keys: a responder whose flag was already
    # cleared (e.g. by a previous test class) may delete its entry concurrently,
    # and iterating the live dict would then raise
    # "dictionary changed size during iteration".
    for threadId in list(cls._backgroundThreads):
        cls._backgroundThreads[threadId] = False
    cls.killProcess(cls._dnsdist)
7373e3a6 198
@classmethod
def _ResponderIncrementCounter(cls):
    """
    Count one answered (non-health-check) query for the calling responder
    thread in cls._responsesCounter, keyed by the thread's name.
    """
    # threading.current_thread() replaces currentThread(), which is deprecated
    # since Python 3.10.
    threadName = threading.current_thread().name
    cls._responsesCounter[threadName] = cls._responsesCounter.get(threadName, 0) + 1
205
@classmethod
def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
    """
    Build the response a responder thread should send back for request.

    - Queries without exactly one question are skipped: None is returned.
    - Health checks (the name ends with cls._healthCheckName) bump
      cls._healthCheckCounter and get a bare dns.message.make_response(request).
    - Any other query bumps the per-thread response counter. If a response is
      waiting on fromQueue, the query is pushed onto toQueue and a copy of that
      response, carrying the query's ID, is used.
    - If synthesize is not None, the response is replaced by a bare response
      with that RCODE.
    - If there is still no response and cls._answerUnexpected is set, a
      SERVFAIL response is returned.
    """
    questionCount = len(request.question)
    if questionCount != 1:
        print("Skipping query with question count %d" % (questionCount))
        return None

    response = None
    isHealthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
    if isHealthCheck:
        cls._healthCheckCounter += 1
        response = dns.message.make_response(request)
    else:
        cls._ResponderIncrementCounter()
        if not fromQueue.empty():
            # Hand the query over to the test, then take the response it queued.
            toQueue.put(request, True, cls._queueTimeout)
            response = fromQueue.get(True, cls._queueTimeout)
            if response:
                # Shallow copy, so the queued object keeps its own ID.
                response = copy.copy(response)
                response.id = request.id

    if synthesize is not None:
        response = dns.message.make_response(request)
        response.set_rcode(synthesize)

    if not response and cls._answerUnexpected:
        response = dns.message.make_response(request)
        response.set_rcode(dns.rcode.SERVFAIL)

    return response
235
@classmethod
def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
    """
    Serve DNS over UDP on 127.0.0.1:port until asked to stop.

    Meant to run as a background thread. Its native thread id is registered in
    cls._backgroundThreads; once that flag is set to False (tearDownClass), the
    loop exits at the next 1 second receive timeout.

    trailingDataResponse=True means "ignore trailing data". Other values are
    either False (meaning "raise an exception", which ends this thread) or are
    interpreted as a response RCODE for queries with trailing data.

    callback, when set, is invoked for every query (health checks included)
    and must return the raw wire response; a falsy return sends nothing.
    Without a callback the response comes from _getResponse(): responses are
    taken from fromQueue and received queries are pushed onto toQueue. Queries
    with request.edns > 1 get BADVERS.
    """
    threadId = threading.get_native_id()
    cls._backgroundThreads[threadId] = True
    ignoreTrailing = trailingDataResponse is True

    # The with-block closes the socket even if this thread dies on an exception.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        # Wake up every second to check whether we have been asked to stop.
        sock.settimeout(1.0)
        while True:
            try:
                data, addr = sock.recvfrom(4096)
            except socket.timeout:
                if not cls._backgroundThreads.get(threadId, False):
                    cls._backgroundThreads.pop(threadId, None)
                    break
                continue

            forceRcode = None
            try:
                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            except dns.message.TrailingJunk:
                print('trailing data exception in UDPResponder')
                if trailingDataResponse is False:
                    raise
                print("UDP query with trailing data, synthesizing response")
                request = dns.message.from_wire(data, ignore_trailing=True)
                forceRcode = trailingDataResponse

            wire = None
            if callback:
                wire = callback(request)
            else:
                if request.edns > 1:
                    forceRcode = dns.rcode.BADVERS
                response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                if response:
                    wire = response.to_wire()

            if wire:
                sock.sendto(wire, addr)
286
@classmethod
def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
    """
    Answer the single length-prefixed DNS query read from conn, then close it.

    trailingDataResponse and callback have the same meaning as for
    UDPResponder(); without a callback the response comes from _getResponse().

    With multipleResponses, every further response already waiting on
    fromQueue is also sent on the same connection, with the query's ID. This
    stops as soon as the queue is empty or the client closes the connection.

    The connection is closed on every exit path, including exceptions.
    """
    ignoreTrailing = trailingDataResponse is True
    try:
        data = conn.recv(2)
        if not data:
            return

        (datalen,) = struct.unpack("!H", data)
        data = conn.recv(datalen)
        forceRcode = None
        try:
            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
        except dns.message.TrailingJunk:
            if trailingDataResponse is False:
                raise
            print("TCP query with trailing data, synthesizing response")
            request = dns.message.from_wire(data, ignore_trailing=True)
            forceRcode = trailingDataResponse

        # Initialise explicitly: _getResponse() may return None (unexpected
        # query with _answerUnexpected unset, or question count != 1), and
        # `wire` would otherwise be unbound below.
        wire = None
        if callback:
            wire = callback(request)
        else:
            if request.edns > 1:
                forceRcode = dns.rcode.BADVERS
            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
            if response:
                wire = response.to_wire(max_size=65535)

        if not wire:
            return

        # sendall(): send() may write only part of the buffer.
        conn.sendall(struct.pack("!H", len(wire)))
        conn.sendall(wire)

        while multipleResponses:
            # Do not block, and stop as soon as the queue is empty: either the
            # next response is already here or we are done. Otherwise we might
            # read responses intended for the next connection.
            if fromQueue.empty():
                break

            response = fromQueue.get(False)
            if not response:
                break

            response = copy.copy(response)
            response.id = request.id
            wire = response.to_wire(max_size=65535)
            try:
                conn.sendall(struct.pack("!H", len(wire)))
                conn.sendall(wire)
            except socket.error:
                # Some of the tests are going to close the connection on us;
                # just deal with it.
                break
    finally:
        conn.close()
345
@classmethod
def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1'):
    """
    Serve DNS over TCP (or TLS, when tlsContext is given) on listeningAddr:port
    until asked to stop.

    Meant to run as a background thread. Its native thread id is registered in
    cls._backgroundThreads; once that flag is set to False (tearDownClass), the
    loop exits at the next 1 second accept timeout.

    trailingDataResponse=True means "ignore trailing data". Other values are
    either False (meaning "raise an exception") or are interpreted as a
    response RCODE for queries with trailing data. callback is invoked for
    every query (health checks included) and should return a raw response.

    Each accepted connection gets a 5 second timeout and is passed to
    handleTCPConnection(). With multipleConnections it runs in its own daemon
    thread; otherwise it runs inline.

    If binding fails, the error is printed and the thread ends via sys.exit(1).
    """
    threadId = threading.get_native_id()
    cls._backgroundThreads[threadId] = True

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    try:
        sock.bind((listeningAddr, port))
    except socket.error as e:
        print("Error binding in the TCP responder: %s" % str(e))
        # Release the socket and our stop flag before ending the thread.
        sock.close()
        cls._backgroundThreads.pop(threadId, None)
        sys.exit(1)

    sock.listen(100)
    sock.settimeout(1.0)
    if tlsContext:
        sock = tlsContext.wrap_socket(sock, server_side=True)

    while True:
        try:
            (conn, _) = sock.accept()
        except ssl.SSLError:
            continue
        except ConnectionResetError:
            continue
        except socket.timeout:
            if not cls._backgroundThreads.get(threadId, False):
                cls._backgroundThreads.pop(threadId, None)
                break
            continue

        conn.settimeout(5.0)
        if multipleConnections:
            # daemon=True replaces the deprecated Thread.setDaemon().
            thread = threading.Thread(name='TCP Connection Handler',
                                      target=cls.handleTCPConnection,
                                      args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback],
                                      daemon=True)
            thread.start()
        else:
            cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback)

    sock.close()
393
c4c72a2c
RG
@classmethod
def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
    """
    Serve one accepted DoH (HTTP/2) connection until the peer closes it.

    With useProxyProtocol, a PROXY protocol header is read first (after the
    server's HTTP/2 preface has been sent), and the raw header plus content is
    pushed onto toQueue.

    Each HTTP/2 stream whose request body is complete is parsed as a DNS query.
    It is answered either by callback(request, requestHeaders, fromQueue,
    toQueue), which returns (status, wire), or by _getResponse() with HTTP
    status 200. If a stream yields no wire response, the whole connection is
    closed.

    multipleResponses and tlsContext are not used here; they are accepted so
    all handlers share the same argument list. The connection is closed on
    every exit path.
    """
    ignoreTrailing = trailingDataResponse is True
    try:
        try:
            h2conn = h2.connection.H2Connection(config=config)
            h2conn.initiate_connection()
            conn.sendall(h2conn.data_to_send())
        except ssl.SSLEOFError as e:
            print("Unexpected EOF: %s" % (e))
            return

        if useProxyProtocol:
            # try to read the entire Proxy Protocol header
            proxy = ProxyProtocol()
            header = conn.recv(proxy.HEADER_SIZE)
            if not header:
                print('unable to get header')
                return

            if not proxy.parseHeader(header):
                print('unable to parse header')
                print(header)
                return

            proxyContent = conn.recv(proxy.contentLen)
            if not proxyContent:
                print('unable to get content')
                return

            payload = header + proxyContent
            toQueue.put(payload, True, cls._queueTimeout)

        # Per-stream state. HTTP/2 headers and data might arrive in different
        # recv() results, and several streams can be interleaved on one
        # connection, so both are keyed by stream id.
        dnsData = {}
        streamHeaders = {}
        while True:
            data = conn.recv(65535)
            if not data:
                break

            for event in h2conn.receive_data(data):
                if isinstance(event, h2.events.RequestReceived):
                    streamHeaders[event.stream_id] = event.headers
                if not isinstance(event, h2.events.DataReceived):
                    continue

                h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
                dnsData[event.stream_id] = dnsData.get(event.stream_id, b'') + event.data
                if not event.stream_ended:
                    continue

                query = dnsData.pop(event.stream_id)
                requestHeaders = streamHeaders.pop(event.stream_id, None)
                forceRcode = None
                status = 200
                try:
                    request = dns.message.from_wire(query, ignore_trailing=ignoreTrailing)
                except dns.message.TrailingJunk:
                    if trailingDataResponse is False:
                        raise
                    print("DOH query with trailing data, synthesizing response")
                    request = dns.message.from_wire(query, ignore_trailing=True)
                    forceRcode = trailingDataResponse

                # Reset for every stream, so a stream without a response never
                # re-sends the previous stream's answer.
                wire = None
                if callback:
                    status, wire = callback(request, requestHeaders, fromQueue, toQueue)
                else:
                    response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                    if response:
                        wire = response.to_wire(max_size=65535)

                if not wire:
                    # Nothing to answer: drop the connection (closed in finally)
                    # without sending anything else.
                    return

                headers = [
                    (':status', str(status)),
                    ('content-length', str(len(wire))),
                    ('content-type', 'application/dns-message'),
                ]
                h2conn.send_headers(stream_id=event.stream_id, headers=headers)
                h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)

            data_to_send = h2conn.data_to_send()
            if data_to_send:
                conn.sendall(data_to_send)
    finally:
        conn.close()
488
@classmethod
def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
    """
    Serve DNS over HTTP/2 (DoH) on 127.0.0.1:port until asked to stop.

    Meant to run as a background thread. Its native thread id is registered in
    cls._backgroundThreads; once that flag is set to False (tearDownClass), the
    loop exits at the next 1 second accept timeout.

    trailingDataResponse=True means "ignore trailing data". Other values are
    either False (meaning "raise an exception") or are interpreted as a
    response RCODE for queries with trailing data. callback is invoked for
    every query (health checks included) and should return
    (HTTP status, raw response).

    Every accepted connection gets a 5 second timeout and is served by
    handleDoHConnection() in its own daemon thread.
    """
    threadId = threading.get_native_id()
    cls._backgroundThreads[threadId] = True

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    try:
        sock.bind(("127.0.0.1", port))
    except socket.error as e:
        print("Error binding in the DoH responder: %s" % str(e))
        # Release the socket and our stop flag before ending the thread.
        sock.close()
        cls._backgroundThreads.pop(threadId, None)
        sys.exit(1)

    sock.listen(100)
    sock.settimeout(1.0)
    if tlsContext:
        sock = tlsContext.wrap_socket(sock, server_side=True)

    config = h2.config.H2Configuration(client_side=False)

    while True:
        try:
            (conn, _) = sock.accept()
        except ssl.SSLError:
            continue
        except ConnectionResetError:
            continue
        except socket.timeout:
            if not cls._backgroundThreads.get(threadId, False):
                cls._backgroundThreads.pop(threadId, None)
                break
            continue

        conn.settimeout(5.0)
        # daemon=True replaces the deprecated Thread.setDaemon().
        thread = threading.Thread(name='DoH Connection Handler',
                                  target=cls.handleDoHConnection,
                                  args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol],
                                  daemon=True)
        thread.start()

    sock.close()
535
@classmethod
def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
    """
    Send query to dnsdist over the class UDP socket and wait for one answer.

    If useQueue is set and response is not None, response is queued for the
    responders first. rawQuery means query is already wire-format bytes.

    A truthy timeout applies to this exchange only; afterwards the socket is
    switched to blocking mode with settimeout(None).

    Returns (receivedQuery, message): the query the responder saw (None if none
    was queued or useQueue is False) and the parsed answer (None on timeout).
    """
    if useQueue and response is not None:
        cls._toResponderQueue.put(response, True, timeout)

    if timeout:
        cls._sock.settimeout(timeout)

    try:
        wire = query if rawQuery else query.to_wire()
        cls._sock.send(wire)
        data = cls._sock.recv(4096)
    except socket.timeout:
        data = None
    finally:
        if timeout:
            # NOTE(review): this leaves the socket blocking rather than at the
            # 2 second timeout set in setUpSockets().
            cls._sock.settimeout(None)

    receivedQuery = None
    if useQueue and not cls._fromResponderQueue.empty():
        receivedQuery = cls._fromResponderQueue.get(True, timeout)
    message = dns.message.from_wire(data) if data else None
    return (receivedQuery, message)
562
@classmethod
def openTCPConnection(cls, timeout=None, port=None):
    """
    Return a TCP socket, with Nagle's algorithm disabled, connected to
    127.0.0.1:port. A falsy port means cls._dnsDistPort; a truthy timeout is
    applied before connecting.
    """
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    if timeout:
        conn.settimeout(timeout)

    conn.connect(("127.0.0.1", port or cls._dnsDistPort))
    return conn
0a2087eb 575
@classmethod
def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
    """
    Return a TLS socket connected to 127.0.0.1:port.

    The server certificate is verified against caCert (or the default trust
    store when None) and against serverName, which is also sent as SNI.
    Nagle's algorithm is disabled, and a truthy timeout is applied before
    connecting.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    if timeout:
        sock.settimeout(timeout)

    # ssl.create_default_context() exists on every Python 3 version. The former
    # ssl.wrap_socket() fallback (for Python < 2.7.9) was unreachable, and that
    # function was removed in Python 3.12.
    sslctx = ssl.create_default_context(cafile=caCert)
    sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
    sslsock.connect(("127.0.0.1", port))
    return sslsock
592
@classmethod
def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
    """
    Write query to sock as one length-prefixed DNS-over-TCP message.

    rawQuery means query is already wire-format bytes. A truthy response is
    queued for the responders first (timeout bounds that put).
    """
    wire = query if rawQuery else query.to_wire()

    if response:
        cls._toResponderQueue.put(response, True, timeout)

    # sendall(): send() may write only part of the buffer. On TLS sockets a
    # single send() writes at most one record (16 KB).
    sock.sendall(struct.pack("!H", len(wire)))
    sock.sendall(wire)
605
@classmethod
def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
    """
    Read one length-prefixed DNS response from sock.

    Without useQueue, returns the parsed message, or None if the connection
    was closed or the payload was empty.

    With useQueue, always returns (receivedQuery, message). receivedQuery is
    the query the responder saw, or None if it did not queue one.
    """
    def recvExactly(count):
        # recv() may return fewer bytes than requested (e.g. one TLS record at
        # a time). Keep reading until count bytes arrived or the peer closed.
        buf = b''
        while len(buf) < count:
            chunk = sock.recv(count - len(buf))
            if not chunk:
                break
            buf += chunk
        return buf

    print("reading data")
    message = None
    data = recvExactly(2)
    if data:
        (datalen,) = struct.unpack("!H", data)
        print(datalen)
        data = recvExactly(datalen)
        if data:
            print(data)
            message = dns.message.from_wire(data)

    print(useQueue)
    receivedQuery = None
    if useQueue and not cls._fromResponderQueue.empty():
        receivedQuery = cls._fromResponderQueue.get(True, timeout)
        print("Got from queue")
        print(receivedQuery)
    else:
        print("queue empty")

    # Callers asking for the queue unpack a tuple, so return one even when the
    # responder queued nothing (a bare message would break that unpacking).
    if useQueue:
        return (receivedQuery, message)
    return message
9396d955
RG
628
@classmethod
def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
    """
    Send query to dnsdist over a fresh TCP connection and read one answer.

    With useQueue, response is queued for the responders first. rawQuery means
    query is already wire-format bytes. Network errors and timeouts are printed
    rather than raised; message is then None.

    Returns (receivedQuery, message): the query the responder saw (None if none
    was queued or useQueue is False) and the parsed answer.
    """
    message = None
    if useQueue:
        cls._toResponderQueue.put(response, True, timeout)

    sock = cls.openTCPConnection(timeout)

    try:
        cls.sendTCPQueryOverConnection(sock, query, rawQuery)
        message = cls.recvTCPResponseOverConnection(sock)
    except socket.timeout as e:
        print("Timeout while sending or receiving TCP data: %s" % (str(e)))
    except socket.error as e:
        print("Network error: %s" % (str(e)))
    finally:
        sock.close()

    receivedQuery = None
    print(useQueue)
    if useQueue and not cls._fromResponderQueue.empty():
        # Fetch before printing, so the debug output shows the actual query
        # rather than the initial None.
        receivedQuery = cls._fromResponderQueue.get(True, timeout)
        print("Got from queue")
        print(receivedQuery)
    else:
        print("queue is empty")

    return (receivedQuery, message)
617dfe22 657
548c8b66
RG
@classmethod
def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
    """
    Send query over a fresh TCP connection and collect every response dnsdist
    sends back, until the connection is closed or a read times out.

    With useQueue, all of responses are queued for the responders first.
    rawQuery means query is already wire-format bytes. Network errors and
    timeouts are printed rather than raised.

    Returns (receivedQuery, messages): the query the responder saw (or None)
    and the list of parsed responses.
    """
    if useQueue:
        for response in responses:
            cls._toResponderQueue.put(response, True, timeout)
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    if timeout:
        sock.settimeout(timeout)

    sock.connect(("127.0.0.1", cls._dnsDistPort))
    messages = []

    def recvExactly(count):
        # recv() may return fewer bytes than requested. Keep reading until
        # count bytes arrived or the peer closed the connection.
        buf = b''
        while len(buf) < count:
            chunk = sock.recv(count - len(buf))
            if not chunk:
                break
            buf += chunk
        return buf

    try:
        wire = query if rawQuery else query.to_wire()
        # sendall(): send() may write only part of the buffer.
        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)
        while True:
            data = recvExactly(2)
            if not data:
                break
            (datalen,) = struct.unpack("!H", data)
            messages.append(dns.message.from_wire(recvExactly(datalen)))

    except socket.timeout as e:
        print("Timeout while receiving multiple TCP responses: %s" % (str(e)))
    except socket.error as e:
        print("Network error: %s" % (str(e)))
    finally:
        sock.close()

    receivedQuery = None
    if useQueue and not cls._fromResponderQueue.empty():
        receivedQuery = cls._fromResponderQueue.get(True, timeout)
    return (receivedQuery, messages)
698
def setUp(self):
    """
    Per-test unittest hook, run before every test.

    Resets the response counters and drains both responder queues, in case a
    previous test failed and left items behind.
    """
    # _responsesCounter is a class-level dict that the responder threads update
    # through _ResponderIncrementCounter(), so clearing it in place resets their
    # counts.
    self._responsesCounter.clear()

    # NOTE(review): this binds an *instance* attribute. The responders increment
    # the class attribute (cls._healthCheckCounter in _getResponse), which is
    # therefore not reset here. Confirm how tests read it before changing this.
    self._healthCheckCounter = 0

    self.clearResponderQueues()

    super(DNSDistTest, self).setUp()
712
3bef39c3
RG
@classmethod
def clearToResponderQueue(cls):
    """Discard every response still waiting in cls._toResponderQueue."""
    pending = cls._toResponderQueue
    while not pending.empty():
        pending.get(False)

@classmethod
def clearFromResponderQueue(cls):
    """Discard every query still waiting in cls._fromResponderQueue."""
    pending = cls._fromResponderQueue
    while not pending.empty():
        pending.get(False)

@classmethod
def clearResponderQueues(cls):
    """Empty both responder queues (responses to send and queries received)."""
    cls.clearToResponderQueue()
    cls.clearFromResponderQueue()
727
1ea747c0
RG
@staticmethod
def generateConsoleKey():
    """Return a new random console key from libnacl.utils.salsa_key()."""
    return libnacl.utils.salsa_key()

@classmethod
def _encryptConsole(cls, command, nonce):
    """
    UTF-8 encode command and, when cls._consoleKey is set, seal it with
    libnacl.crypto_secretbox under nonce.
    """
    encoded = command.encode('UTF-8')
    if cls._consoleKey is not None:
        return libnacl.crypto_secretbox(encoded, nonce, cls._consoleKey)
    return encoded

@classmethod
def _decryptConsole(cls, command, nonce):
    """
    Inverse of _encryptConsole(): open command with
    libnacl.crypto_secretbox_open when cls._consoleKey is set, then decode the
    result as UTF-8.
    """
    if cls._consoleKey is not None:
        command = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
    return command.decode('UTF-8')
1ea747c0
RG
746
@classmethod
def sendConsoleCommand(cls, command, timeout=1.0):
    """
    Run command on dnsdist's console (127.0.0.1:cls._consolePort) and return
    its decoded output.

    As implemented here, the exchange works like this:
    - Both sides send a nonce of the same size.
    - The reading and writing nonces are built from halves of ours and theirs.
    - The command and the response are each a 4-byte big-endian length
      followed by the payload, which is secretbox-encrypted when
      cls._consoleKey is set (see _encryptConsole / _decryptConsole).

    Returns None if the server sent a nonce of unexpected size.
    Raises socket.error on EOF while reading the nonce or the response size.
    """
    def recvExactly(count):
        # recv() may return fewer bytes than requested, which truncated large
        # console outputs. Keep reading until count bytes arrived or EOF.
        buf = b''
        while len(buf) < count:
            chunk = sock.recv(count - len(buf))
            if not chunk:
                break
            buf += chunk
        return buf

    ourNonce = libnacl.utils.rand_nonce()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    if timeout:
        sock.settimeout(timeout)

    # The with-block closes the socket on every exit path, including the early
    # return and exceptions.
    with sock:
        sock.connect(("127.0.0.1", cls._consolePort))
        sock.sendall(ourNonce)
        theirNonce = recvExactly(len(ourNonce))
        if len(theirNonce) != len(ourNonce):
            print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
            if len(theirNonce) == 0:
                raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
            return None

        halfNonceSize = int(len(ourNonce) / 2)
        readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
        writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
        msg = cls._encryptConsole(command, writingNonce)
        sock.sendall(struct.pack("!I", len(msg)))
        sock.sendall(msg)
        data = recvExactly(4)
        if not data:
            raise socket.error("Got EOF while reading the response size")

        (responseLen,) = struct.unpack("!I", data)
        data = recvExactly(responseLen)
        return cls._decryptConsole(data, readingNonce)
5df86a8a
RG
780
def compareOptions(self, a, b):
    """Assert that option sequences a and b have the same length and equal elements position by position."""
    self.assertEqual(len(a), len(b))
    for left, right in zip(a, b):
        self.assertEqual(left, right)

def checkMessageNoEDNS(self, expected, received):
    """Assert received equals expected and carries no EDNS (received.edns == -1) and no options."""
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, -1)
    self.assertEqual(len(received.options), 0)

def checkMessageEDNSWithoutOptions(self, expected, received):
    """Assert received equals expected, uses EDNS version 0 and has the same EDNS payload value."""
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)

def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
    """
    Assert received equals expected, uses EDNS version 0 with the same payload
    value, and has exactly withCookies options.

    When withCookies is non-zero, every option must have code 10 (EDNS COOKIE);
    otherwise no option may.
    """
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)
    self.assertEqual(len(received.options), withCookies)
    cookieCode = 10
    for option in received.options:
        if withCookies:
            self.assertEqual(option.otype, cookieCode)
        else:
            self.assertNotEqual(option.otype, cookieCode)

def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
    """
    Assert received equals expected, uses EDNS version 0 with the same payload
    value, and carries an ECS option plus exactly additionalOptions others.
    The options must match expected.options position by position.
    """
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)
    self.assertEqual(len(received.options), 1 + additionalOptions)
    ecsCode = clientsubnetoption.ASSIGNED_OPTION_CODE
    hasECS = any(option.otype == ecsCode for option in received.options)
    if any(option.otype != ecsCode for option in received.options):
        # Options other than ECS are only allowed when the caller announced some.
        self.assertNotEqual(additionalOptions, 0)

    self.compareOptions(expected.options, received.options)
    self.assertTrue(hasECS)

def checkMessageEDNS(self, expected, received):
    """Assert received equals expected, uses EDNS version 0 with the same payload value, and has identical options."""
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)
    self.assertEqual(len(expected.options), len(received.options))
    self.compareOptions(expected.options, received.options)

def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
    """Query-side alias of checkMessageEDNSWithECS()."""
    self.checkMessageEDNSWithECS(expected, received, additionalOptions)

def checkQueryEDNS(self, expected, received):
    """Query-side alias of checkMessageEDNS()."""
    self.checkMessageEDNS(expected, received)

def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
    """Response-side alias of checkMessageEDNSWithECS()."""
    self.checkMessageEDNSWithECS(expected, received, additionalOptions)

def checkQueryEDNSWithoutECS(self, expected, received):
    """Query-side alias of checkMessageEDNSWithoutECS() (no cookies expected)."""
    self.checkMessageEDNSWithoutECS(expected, received)

def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
    """Response-side alias of checkMessageEDNSWithoutECS()."""
    self.checkMessageEDNSWithoutECS(expected, received, withCookies)

def checkQueryNoEDNS(self, expected, received):
    """Query-side alias of checkMessageNoEDNS()."""
    self.checkMessageNoEDNS(expected, received)

def checkResponseNoEDNS(self, expected, received):
    """Response-side alias of checkMessageNoEDNS()."""
    self.checkMessageNoEDNS(expected, received)
b0d08f82 850
1d896c34
RG
def generateNewCertificateAndKey(self):
    """
    Generate a fresh server key and a certificate signed by the test CA.

    Runs openssl in the current working directory and produces:
    - server.key and server.csr (from configServer.conf);
    - server.pem, valid for 1 day, signed with ca.pem/ca.key using the
      v3_req extensions;
    - server.chain (server.pem followed by ca.pem);
    - server.p12, a PKCS#12 bundle protected with 'passw0rd'.

    Raises AssertionError if any openssl invocation exits with a non-zero
    status.
    """
    def runOpenSSL(cmd, step):
        # Popen() never raises CalledProcessError, so openssl failures used to
        # go unnoticed. run(check=True) raises it on a non-zero exit status.
        # input=b'' feeds an empty stdin, as before.
        try:
            subprocess.run(cmd, input=b'', stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True, check=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('openssl %s failed (%d): %s' % (step, exc.returncode, exc.output))

    # generate and sign a new cert
    runOpenSSL(['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', 'server.key', '-out', 'server.csr', '-config', 'configServer.conf'], 'req')
    runOpenSSL(['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', 'server.csr', '-out', 'server.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req'], 'x509')

    with open('server.chain', 'w') as outFile:
        for inFileName in ['server.pem', 'ca.pem']:
            with open(inFileName) as inFile:
                outFile.write(inFile.read())

    runOpenSSL(['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', 'server.pem', '-CAfile', 'ca.pem', '-inkey', 'server.key', '-out', 'server.p12'], 'pkcs12')
880
0e6892c6
RG
def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=None, v6=False, sourcePort=None, destinationPort=None):
    """
    Assert that receivedProxyPayload is a PROXY protocol v2 "PROXY" header
    with the expected addresses, ports and additional values.

    source, destination: expected addresses. v6 selects IPv6 vs IPv4 as the
    expected family; isTCP selects stream vs datagram as the expected protocol.

    sourcePort is only checked when truthy. destinationPort defaults to
    self._dnsDistPort when falsy.

    values: expected additional values (as parsed by
    ProxyProtocol.parseAdditionalValues), compared regardless of order.
    Defaults to none; the caller's list is left unmodified.
    """
    proxy = ProxyProtocol()
    self.assertTrue(proxy.parseHeader(receivedProxyPayload))
    # PROXY v2 header: version 2, command 0x1 (PROXY).
    self.assertEqual(proxy.version, 0x02)
    self.assertEqual(proxy.command, 0x01)
    # Address family 0x1 = IPv4, 0x2 = IPv6.
    self.assertEqual(proxy.family, 0x02 if v6 else 0x01)
    # Transport protocol 0x1 = STREAM (TCP), 0x2 = DGRAM (UDP).
    self.assertEqual(proxy.protocol, 0x01 if isTCP else 0x02)
    self.assertGreater(proxy.contentLen, 0)

    self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
    self.assertEqual(proxy.source, source)
    self.assertEqual(proxy.destination, destination)
    if sourcePort:
        self.assertEqual(proxy.sourcePort, sourcePort)
    self.assertEqual(proxy.destinationPort, destinationPort if destinationPort else self._dnsDistPort)

    self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
    # sorted() copies, so neither the parsed values nor the caller's list is mutated.
    self.assertEqual(sorted(proxy.values), sorted(values or []))