git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
dnsdist: Add a regression test for DoQ certs/keys reloading
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import base64
4 import copy
5 import errno
6 import os
7 import socket
8 import ssl
9 import struct
10 import subprocess
11 import sys
12 import threading
13 import time
14 import unittest
15
16 import clientsubnetoption
17
18 import dns
19 import dns.message
20
21 import libnacl
22 import libnacl.utils
23
24 import h2.connection
25 import h2.events
26 import h2.config
27
28 import pycurl
29 from io import BytesIO
30
31 from doqclient import quic_query
32 from doh3client import doh3_query
33
34 from eqdnsmessage import AssertEqualDNSMessageMixin
35 from proxyprotocol import ProxyProtocol
36
37 # Python2/3 compatibility hacks
38 try:
39 from queue import Queue
40 except ImportError:
41 from Queue import Queue
42
# Python 2 compatibility: make `range` the lazy `xrange`. On Python 3
# `xrange` does not exist, the NameError is swallowed and the builtin
# `range` is kept.
# NOTE(review): this module already uses Python 3.8+ syntax (f-strings with
# `=` specifiers, threading.get_native_id), so this branch looks dead.
try:
    range = xrange
except NameError:
    pass
47
def getWorkerID():
    """Return the numeric pytest-xdist worker index, or 0 outside of xdist.

    The value of PYTEST_XDIST_WORKER has its first two characters
    (presumably the xdist "gw" prefix) dropped and the rest parsed as int.
    """
    workerName = os.environ.get('PYTEST_XDIST_WORKER')
    if workerName is None:
        return 0
    return int(workerName[2:])
53
# Last port handed out, per pytest-xdist worker ID.
workerPorts = {}

def pickAvailablePort():
    """Return the next port number for the current pytest-xdist worker.

    Worker N starts at 11000 + N * 1000 and every later call returns the
    previous port plus one. The port is not checked for availability, and
    nothing prevents a worker from running into the next worker's range.
    """
    workerID = getWorkerID()
    previousPort = workerPorts.get(workerID)
    if previousPort is None:
        port = 11000 + (workerID * 1000)
    else:
        port = previousPort + 1
    workerPorts[workerID] = port
    return port
65
class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    which reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.

    The queues, counters and dicts below are class attributes, so they are
    shared by every subclass that does not override them.
    """
    # Address passed to dnsdist with -l (unless _skipListeningOnCL is set).
    _dnsDistListeningAddr = "127.0.0.1"
    # Responses the tests want the backend responders to serve.
    _toResponderQueue = Queue()
    # Queries (and PROXY protocol payloads) as received by the responders.
    _fromResponderQueue = Queue()
    # Timeout, in seconds, used by the responders for queue put/get.
    _queueTimeout = 1
    # subprocess.Popen handle of the running dnsdist.
    _dnsdist = None
    # Number of non-health-check queries answered, keyed by responder thread name.
    _responsesCounter = {}
    # dnsdist Lua configuration, %-formatted with the attributes named in _config_params.
    _config_template = """
    """
    _config_params = ['_testServerPort']
    # Each entry is passed to dnsdist as an --acl argument.
    _acl = ['127.0.0.1/32']
    # Console key; when None, console commands are sent unencrypted.
    _consoleKey = None
    # Queries whose name ends with this are answered as health checks.
    _healthCheckName = 'a.root-servers.net.'
    # Incremented (on the responder's class) for each health-check query.
    _healthCheckCounter = 0
    # Answer SERVFAIL when no response has been queued for a query.
    _answerUnexpected = True
    # Expected --check-config output; None means the default "OK" line.
    _checkConfigExpectedOutput = None
    # Run dnsdist with -v; also disables the --check-config output comparison.
    _verboseMode = False
    # Do not pass -l to dnsdist (presumably the configuration declares its own listeners).
    _skipListeningOnCL = False
    # When both are set, startDNSDist waits on this address/port instead.
    _alternateListeningAddr = None
    _alternateListeningPort = None
    # Native thread ID -> "keep running" flag for the background responder threads.
    _backgroundThreads = {}
    _UDPResponder = None
    _TCPResponder = None
    # Extra seconds to sleep once dnsdist is up.
    _extraStartupSleep = 0
    # Picked at class definition; startDNSDist/startResponders pick fresh ones.
    _dnsDistPort = pickAvailablePort()
    _consolePort = pickAvailablePort()
    _testServerPort = pickAvailablePort()
101
102 @classmethod
103 def waitForTCPSocket(cls, ipaddress, port):
104 for try_number in range(0, 20):
105 try:
106 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
107 sock.settimeout(1.0)
108 sock.connect((ipaddress, port))
109 sock.close()
110 return
111 except Exception as err:
112 if err.errno != errno.ECONNREFUSED:
113 print(f'Error occurred: {try_number} {err}', file=sys.stderr)
114 time.sleep(0.1)
115 # We assume the dnsdist instance does not listen. That's fine.
116
117 @classmethod
118 def startResponders(cls):
119 print("Launching responders..")
120 cls._testServerPort = pickAvailablePort()
121
122 cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
123 cls._UDPResponder.daemon = True
124 cls._UDPResponder.start()
125 cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
126 cls._TCPResponder.daemon = True
127 cls._TCPResponder.start()
128 cls.waitForTCPSocket("127.0.0.1", cls._testServerPort);
129
130 @classmethod
131 def startDNSDist(cls):
132 cls._dnsDistPort = pickAvailablePort()
133 cls._consolePort = pickAvailablePort()
134
135 print("Launching dnsdist..")
136 confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
137 params = tuple([getattr(cls, param) for param in cls._config_params])
138 print(params)
139 with open(confFile, 'w') as conf:
140 conf.write("-- Autogenerated by dnsdisttests.py\n")
141 conf.write(f"-- dnsdist will listen on {cls._dnsDistPort}")
142 conf.write(cls._config_template % params)
143 conf.write("setSecurityPollSuffix('')")
144
145 if cls._skipListeningOnCL:
146 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile ]
147 else:
148 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
149 '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
150
151 if cls._verboseMode:
152 dnsdistcmd.append('-v')
153
154 for acl in cls._acl:
155 dnsdistcmd.extend(['--acl', acl])
156 print(' '.join(dnsdistcmd))
157
158 # validate config with --check-config, which sets client=true, possibly exposing bugs.
159 testcmd = dnsdistcmd + ['--check-config']
160 try:
161 output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
162 except subprocess.CalledProcessError as exc:
163 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
164 if cls._checkConfigExpectedOutput is not None:
165 expectedOutput = cls._checkConfigExpectedOutput
166 else:
167 expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
168 if not cls._verboseMode and output != expectedOutput:
169 raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output, expectedOutput))
170
171 logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
172 with open(logFile, 'w') as fdLog:
173 cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)
174
175 if cls._alternateListeningAddr and cls._alternateListeningPort:
176 cls.waitForTCPSocket(cls._alternateListeningAddr, cls._alternateListeningPort)
177 else:
178 cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort)
179
180 if cls._dnsdist.poll() is not None:
181 print(f"\n*** startDNSDist log for {logFile} ***")
182 with open(logFile, 'r') as fdLog:
183 print(fdLog.read())
184 print(f"*** End startDNSDist log for {logFile} ***")
185 raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode))
186 time.sleep(cls._extraStartupSleep)
187
188 @classmethod
189 def setUpSockets(cls):
190 print("Setting up UDP socket..")
191 cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
192 cls._sock.settimeout(2.0)
193 cls._sock.connect(("127.0.0.1", cls._dnsDistPort))
194
195 @classmethod
196 def killProcess(cls, p):
197 # Don't try to kill it if it's already dead
198 if p.poll() is not None:
199 return
200 try:
201 p.terminate()
202 for count in range(20):
203 x = p.poll()
204 if x is not None:
205 break
206 time.sleep(0.1)
207 if x is None:
208 print("kill...", p, file=sys.stderr)
209 p.kill()
210 p.wait()
211 except OSError as e:
212 # There is a race-condition with the poll() and
213 # kill() statements, when the process is dead on the
214 # kill(), this is fine
215 if e.errno != errno.ESRCH:
216 raise
217
218 @classmethod
219 def setUpClass(cls):
220
221 cls.startResponders()
222 cls.startDNSDist()
223 cls.setUpSockets()
224
225 print("Launching tests..")
226
227 @classmethod
228 def tearDownClass(cls):
229 cls._sock.close()
230 # tell the background threads to stop, if any
231 for backgroundThread in cls._backgroundThreads:
232 cls._backgroundThreads[backgroundThread] = False
233 cls.killProcess(cls._dnsdist)
234
235 @classmethod
236 def _ResponderIncrementCounter(cls):
237 if threading.current_thread().name in cls._responsesCounter:
238 cls._responsesCounter[threading.current_thread().name] += 1
239 else:
240 cls._responsesCounter[threading.current_thread().name] = 1
241
242 @classmethod
243 def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
244 response = None
245 if len(request.question) != 1:
246 print("Skipping query with question count %d" % (len(request.question)))
247 return None
248 healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
249 if healthCheck:
250 cls._healthCheckCounter += 1
251 response = dns.message.make_response(request)
252 else:
253 cls._ResponderIncrementCounter()
254 if not fromQueue.empty():
255 toQueue.put(request, True, cls._queueTimeout)
256 response = fromQueue.get(True, cls._queueTimeout)
257 if response:
258 response = copy.copy(response)
259 response.id = request.id
260
261 if synthesize is not None:
262 response = dns.message.make_response(request)
263 response.set_rcode(synthesize)
264
265 if not response:
266 if cls._answerUnexpected:
267 response = dns.message.make_response(request)
268 response.set_rcode(dns.rcode.SERVFAIL)
269
270 return response
271
272 @classmethod
273 def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
274 cls._backgroundThreads[threading.get_native_id()] = True
275 # trailingDataResponse=True means "ignore trailing data".
276 # Other values are either False (meaning "raise an exception")
277 # or are interpreted as a response RCODE for queries with trailing data.
278 # callback is invoked for every -even healthcheck ones- query and should return a raw response
279 ignoreTrailing = trailingDataResponse is True
280
281 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
282 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
283 sock.bind(("127.0.0.1", port))
284 sock.settimeout(0.5)
285 while True:
286 try:
287 data, addr = sock.recvfrom(4096)
288 except socket.timeout:
289 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
290 del cls._backgroundThreads[threading.get_native_id()]
291 break
292 else:
293 continue
294
295 forceRcode = None
296 try:
297 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
298 except dns.message.TrailingJunk as e:
299 print('trailing data exception in UDPResponder')
300 if trailingDataResponse is False or forceRcode is True:
301 raise
302 print("UDP query with trailing data, synthesizing response")
303 request = dns.message.from_wire(data, ignore_trailing=True)
304 forceRcode = trailingDataResponse
305
306 wire = None
307 if callback:
308 wire = callback(request)
309 else:
310 if request.edns > 1:
311 forceRcode = dns.rcode.BADVERS
312 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
313 if response:
314 wire = response.to_wire()
315
316 if not wire:
317 continue
318
319 sock.sendto(wire, addr)
320
321 sock.close()
322
323 @classmethod
324 def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, partialWrite=False):
325 ignoreTrailing = trailingDataResponse is True
326 try:
327 data = conn.recv(2)
328 except Exception as err:
329 data = None
330 print(f'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
331 if not data:
332 conn.close()
333 return
334
335 (datalen,) = struct.unpack("!H", data)
336 data = conn.recv(datalen)
337 forceRcode = None
338 try:
339 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
340 except dns.message.TrailingJunk as e:
341 if trailingDataResponse is False or forceRcode is True:
342 raise
343 print("TCP query with trailing data, synthesizing response")
344 request = dns.message.from_wire(data, ignore_trailing=True)
345 forceRcode = trailingDataResponse
346
347 if callback:
348 wire = callback(request)
349 else:
350 if request.edns > 1:
351 forceRcode = dns.rcode.BADVERS
352 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
353 if response:
354 wire = response.to_wire(max_size=65535)
355
356 if not wire:
357 conn.close()
358 return
359
360 wireLen = struct.pack("!H", len(wire))
361 if partialWrite:
362 for b in wireLen:
363 conn.send(bytes([b]))
364 time.sleep(0.5)
365 else:
366 conn.send(wireLen)
367 conn.send(wire)
368
369 while multipleResponses:
370 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
371 # otherwise we might read responses intended for the next connection
372 if fromQueue.empty():
373 break
374
375 response = fromQueue.get(False)
376 if not response:
377 break
378
379 response = copy.copy(response)
380 response.id = request.id
381 wire = response.to_wire(max_size=65535)
382 try:
383 conn.send(struct.pack("!H", len(wire)))
384 conn.send(wire)
385 except socket.error as e:
386 # some of the tests are going to close
387 # the connection on us, just deal with it
388 break
389
390 conn.close()
391
392 @classmethod
393 def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1', partialWrite=False):
394 cls._backgroundThreads[threading.get_native_id()] = True
395 # trailingDataResponse=True means "ignore trailing data".
396 # Other values are either False (meaning "raise an exception")
397 # or are interpreted as a response RCODE for queries with trailing data.
398 # callback is invoked for every -even healthcheck ones- query and should return a raw response
399
400 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
401 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
402 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
403 try:
404 sock.bind((listeningAddr, port))
405 except socket.error as e:
406 print("Error binding in the TCP responder: %s" % str(e))
407 sys.exit(1)
408
409 sock.listen(100)
410 sock.settimeout(0.5)
411 if tlsContext:
412 sock = tlsContext.wrap_socket(sock, server_side=True)
413
414 while True:
415 try:
416 (conn, _) = sock.accept()
417 except ssl.SSLError:
418 continue
419 except ConnectionResetError:
420 continue
421 except socket.timeout:
422 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
423 del cls._backgroundThreads[threading.get_native_id()]
424 break
425 else:
426 continue
427
428 conn.settimeout(5.0)
429 if multipleConnections:
430 thread = threading.Thread(name='TCP Connection Handler',
431 target=cls.handleTCPConnection,
432 args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite])
433 thread.daemon = True
434 thread.start()
435 else:
436 cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite)
437
438 sock.close()
439
440 @classmethod
441 def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
442 ignoreTrailing = trailingDataResponse is True
443 try:
444 h2conn = h2.connection.H2Connection(config=config)
445 h2conn.initiate_connection()
446 conn.sendall(h2conn.data_to_send())
447 except ssl.SSLEOFError as e:
448 print("Unexpected EOF: %s" % (e))
449 return
450 except Exception as err:
451 print(f'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
452 return
453
454 dnsData = {}
455
456 if useProxyProtocol:
457 # try to read the entire Proxy Protocol header
458 proxy = ProxyProtocol()
459 header = conn.recv(proxy.HEADER_SIZE)
460 if not header:
461 print('unable to get header')
462 conn.close()
463 return
464
465 if not proxy.parseHeader(header):
466 print('unable to parse header')
467 print(header)
468 conn.close()
469 return
470
471 proxyContent = conn.recv(proxy.contentLen)
472 if not proxyContent:
473 print('unable to get content')
474 conn.close()
475 return
476
477 payload = header + proxyContent
478 toQueue.put(payload, True, cls._queueTimeout)
479
480 # be careful, HTTP/2 headers and data might be in different recv() results
481 requestHeaders = None
482 while True:
483 try:
484 data = conn.recv(65535)
485 except Exception as err:
486 data = None
487 print(f'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
488 if not data:
489 break
490
491 events = h2conn.receive_data(data)
492 for event in events:
493 if isinstance(event, h2.events.RequestReceived):
494 requestHeaders = event.headers
495 if isinstance(event, h2.events.DataReceived):
496 h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
497 if not event.stream_id in dnsData:
498 dnsData[event.stream_id] = b''
499 dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
500 if event.stream_ended:
501 forceRcode = None
502 status = 200
503 try:
504 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
505 except dns.message.TrailingJunk as e:
506 if trailingDataResponse is False or forceRcode is True:
507 raise
508 print("DOH query with trailing data, synthesizing response")
509 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
510 forceRcode = trailingDataResponse
511
512 if callback:
513 status, wire = callback(request, requestHeaders, fromQueue, toQueue)
514 else:
515 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
516 if response:
517 wire = response.to_wire(max_size=65535)
518
519 if not wire:
520 conn.close()
521 conn = None
522 break
523
524 headers = [
525 (':status', str(status)),
526 ('content-length', str(len(wire))),
527 ('content-type', 'application/dns-message'),
528 ]
529 h2conn.send_headers(stream_id=event.stream_id, headers=headers)
530 h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
531
532 data_to_send = h2conn.data_to_send()
533 if data_to_send:
534 conn.sendall(data_to_send)
535
536 if conn is None:
537 break
538
539 if conn is not None:
540 conn.close()
541
542 @classmethod
543 def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
544 cls._backgroundThreads[threading.get_native_id()] = True
545 # trailingDataResponse=True means "ignore trailing data".
546 # Other values are either False (meaning "raise an exception")
547 # or are interpreted as a response RCODE for queries with trailing data.
548 # callback is invoked for every -even healthcheck ones- query and should return a raw response
549
550 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
551 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
552 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
553 try:
554 sock.bind(("127.0.0.1", port))
555 except socket.error as e:
556 print("Error binding in the TCP responder: %s" % str(e))
557 sys.exit(1)
558
559 sock.listen(100)
560 sock.settimeout(0.5)
561 if tlsContext:
562 sock = tlsContext.wrap_socket(sock, server_side=True)
563
564 config = h2.config.H2Configuration(client_side=False)
565
566 while True:
567 try:
568 (conn, _) = sock.accept()
569 except ssl.SSLError:
570 continue
571 except ConnectionResetError:
572 continue
573 except socket.timeout:
574 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
575 del cls._backgroundThreads[threading.get_native_id()]
576 break
577 else:
578 continue
579
580 conn.settimeout(5.0)
581 thread = threading.Thread(name='DoH Connection Handler',
582 target=cls.handleDoHConnection,
583 args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
584 thread.daemon = True
585 thread.start()
586
587 sock.close()
588
589 @classmethod
590 def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
591 if useQueue and response is not None:
592 cls._toResponderQueue.put(response, True, timeout)
593
594 if timeout:
595 cls._sock.settimeout(timeout)
596
597 try:
598 if not rawQuery:
599 query = query.to_wire()
600 cls._sock.send(query)
601 data = cls._sock.recv(4096)
602 except socket.timeout:
603 data = None
604 finally:
605 if timeout:
606 cls._sock.settimeout(None)
607
608 receivedQuery = None
609 message = None
610 if useQueue and not cls._fromResponderQueue.empty():
611 receivedQuery = cls._fromResponderQueue.get(True, timeout)
612 if data:
613 message = dns.message.from_wire(data)
614 return (receivedQuery, message)
615
616 @classmethod
617 def openTCPConnection(cls, timeout=None, port=None):
618 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
619 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
620 if timeout:
621 sock.settimeout(timeout)
622
623 if not port:
624 port = cls._dnsDistPort
625
626 sock.connect(("127.0.0.1", port))
627 return sock
628
629 @classmethod
630 def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]):
631 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
632 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
633 if timeout:
634 sock.settimeout(timeout)
635
636 # 2.7.9+
637 if hasattr(ssl, 'create_default_context'):
638 sslctx = ssl.create_default_context(cafile=caCert)
639 if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
640 sslctx.set_alpn_protocols(alpn)
641 sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
642 else:
643 sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
644
645 sslsock.connect(("127.0.0.1", port))
646 return sslsock
647
648 @classmethod
649 def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
650 if not rawQuery:
651 wire = query.to_wire()
652 else:
653 wire = query
654
655 if response:
656 cls._toResponderQueue.put(response, True, timeout)
657
658 sock.send(struct.pack("!H", len(wire)))
659 sock.send(wire)
660
661 @classmethod
662 def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
663 message = None
664 data = sock.recv(2)
665 if data:
666 (datalen,) = struct.unpack("!H", data)
667 print(datalen)
668 data = sock.recv(datalen)
669 if data:
670 print(data)
671 message = dns.message.from_wire(data)
672
673 print(useQueue)
674 if useQueue and not cls._fromResponderQueue.empty():
675 receivedQuery = cls._fromResponderQueue.get(True, timeout)
676 print(receivedQuery)
677 return (receivedQuery, message)
678 else:
679 print("queue empty")
680 return message
681
682 @classmethod
683 def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True):
684 conn = cls.openTLSConnection(port, serverName, caFile)
685 cls.sendTCPQueryOverConnection(conn, query, response=response)
686 if useQueue:
687 return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
688 return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
689
690 @classmethod
691 def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
692 message = None
693 if useQueue:
694 cls._toResponderQueue.put(response, True, timeout)
695
696 sock = cls.openTCPConnection(timeout)
697
698 try:
699 cls.sendTCPQueryOverConnection(sock, query, rawQuery)
700 message = cls.recvTCPResponseOverConnection(sock)
701 except socket.timeout as e:
702 print("Timeout while sending or receiving TCP data: %s" % (str(e)))
703 except socket.error as e:
704 print("Network error: %s" % (str(e)))
705 finally:
706 sock.close()
707
708 receivedQuery = None
709 print(useQueue)
710 if useQueue and not cls._fromResponderQueue.empty():
711 print(receivedQuery)
712 receivedQuery = cls._fromResponderQueue.get(True, timeout)
713 else:
714 print("queue is empty")
715
716 return (receivedQuery, message)
717
718 @classmethod
719 def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
720 if useQueue:
721 for response in responses:
722 cls._toResponderQueue.put(response, True, timeout)
723 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
724 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
725 if timeout:
726 sock.settimeout(timeout)
727
728 sock.connect(("127.0.0.1", cls._dnsDistPort))
729 messages = []
730
731 try:
732 if not rawQuery:
733 wire = query.to_wire()
734 else:
735 wire = query
736
737 sock.send(struct.pack("!H", len(wire)))
738 sock.send(wire)
739 while True:
740 data = sock.recv(2)
741 if not data:
742 break
743 (datalen,) = struct.unpack("!H", data)
744 data = sock.recv(datalen)
745 messages.append(dns.message.from_wire(data))
746
747 except socket.timeout as e:
748 print("Timeout while receiving multiple TCP responses: %s" % (str(e)))
749 except socket.error as e:
750 print("Network error: %s" % (str(e)))
751 finally:
752 sock.close()
753
754 receivedQuery = None
755 if useQueue and not cls._fromResponderQueue.empty():
756 receivedQuery = cls._fromResponderQueue.get(True, timeout)
757 return (receivedQuery, messages)
758
759 def setUp(self):
760 # This function is called before every test
761
762 # Clear the responses counters
763 self._responsesCounter.clear()
764
765 self._healthCheckCounter = 0
766
767 # Make sure the queues are empty, in case
768 # a previous test failed
769 self.clearResponderQueues()
770
771 super(DNSDistTest, self).setUp()
772
773 @classmethod
774 def clearToResponderQueue(cls):
775 while not cls._toResponderQueue.empty():
776 cls._toResponderQueue.get(False)
777
778 @classmethod
779 def clearFromResponderQueue(cls):
780 while not cls._fromResponderQueue.empty():
781 cls._fromResponderQueue.get(False)
782
783 @classmethod
784 def clearResponderQueues(cls):
785 cls.clearToResponderQueue()
786 cls.clearFromResponderQueue()
787
788 @staticmethod
789 def generateConsoleKey():
790 return libnacl.utils.salsa_key()
791
792 @classmethod
793 def _encryptConsole(cls, command, nonce):
794 command = command.encode('UTF-8')
795 if cls._consoleKey is None:
796 return command
797 return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)
798
799 @classmethod
800 def _decryptConsole(cls, command, nonce):
801 if cls._consoleKey is None:
802 result = command
803 else:
804 result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
805 return result.decode('UTF-8')
806
807 @classmethod
808 def sendConsoleCommand(cls, command, timeout=5.0):
809 ourNonce = libnacl.utils.rand_nonce()
810 theirNonce = None
811 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
812 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
813 if timeout:
814 sock.settimeout(timeout)
815
816 sock.connect(("127.0.0.1", cls._consolePort))
817 sock.send(ourNonce)
818 theirNonce = sock.recv(len(ourNonce))
819 if len(theirNonce) != len(ourNonce):
820 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
821 if len(theirNonce) == 0:
822 raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
823 return None
824
825 halfNonceSize = int(len(ourNonce) / 2)
826 readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
827 writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
828 msg = cls._encryptConsole(command, writingNonce)
829 sock.send(struct.pack("!I", len(msg)))
830 sock.send(msg)
831 data = sock.recv(4)
832 if not data:
833 raise socket.error("Got EOF while reading the response size")
834
835 (responseLen,) = struct.unpack("!I", data)
836 data = sock.recv(responseLen)
837 response = cls._decryptConsole(data, readingNonce)
838 sock.close()
839 return response
840
841 def compareOptions(self, a, b):
842 self.assertEqual(len(a), len(b))
843 for idx in range(len(a)):
844 self.assertEqual(a[idx], b[idx])
845
846 def checkMessageNoEDNS(self, expected, received):
847 self.assertEqual(expected, received)
848 self.assertEqual(received.edns, -1)
849 self.assertEqual(len(received.options), 0)
850
851 def checkMessageEDNSWithoutOptions(self, expected, received):
852 self.assertEqual(expected, received)
853 self.assertEqual(received.edns, 0)
854 self.assertEqual(expected.payload, received.payload)
855
856 def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
857 self.assertEqual(expected, received)
858 self.assertEqual(received.edns, 0)
859 self.assertEqual(expected.payload, received.payload)
860 self.assertEqual(len(received.options), withCookies)
861 if withCookies:
862 for option in received.options:
863 self.assertEqual(option.otype, 10)
864 else:
865 for option in received.options:
866 self.assertNotEqual(option.otype, 10)
867
868 def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
869 self.assertEqual(expected, received)
870 self.assertEqual(received.edns, 0)
871 self.assertEqual(expected.payload, received.payload)
872 self.assertEqual(len(received.options), 1 + additionalOptions)
873 hasECS = False
874 for option in received.options:
875 if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
876 hasECS = True
877 else:
878 self.assertNotEqual(additionalOptions, 0)
879
880 self.compareOptions(expected.options, received.options)
881 self.assertTrue(hasECS)
882
883 def checkMessageEDNS(self, expected, received):
884 self.assertEqual(expected, received)
885 self.assertEqual(received.edns, 0)
886 self.assertEqual(expected.payload, received.payload)
887 self.assertEqual(len(expected.options), len(received.options))
888 self.compareOptions(expected.options, received.options)
889
890 def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
891 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
892
893 def checkQueryEDNS(self, expected, received):
894 self.checkMessageEDNS(expected, received)
895
896 def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
897 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
898
899 def checkQueryEDNSWithoutECS(self, expected, received):
900 self.checkMessageEDNSWithoutECS(expected, received)
901
902 def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
903 self.checkMessageEDNSWithoutECS(expected, received, withCookies)
904
905 def checkQueryNoEDNS(self, expected, received):
906 self.checkMessageNoEDNS(expected, received)
907
908 def checkResponseNoEDNS(self, expected, received):
909 self.checkMessageNoEDNS(expected, received)
910
911 @staticmethod
912 def generateNewCertificateAndKey(filePrefix):
913 # generate and sign a new cert
914 cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf']
915 output = None
916 try:
917 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
918 output = process.communicate(input='')
919 except subprocess.CalledProcessError as exc:
920 raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output))
921 cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
922 output = None
923 try:
924 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
925 output = process.communicate(input='')
926 except subprocess.CalledProcessError as exc:
927 raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output))
928
929 with open(filePrefix + '.chain', 'w') as outFile:
930 for inFileName in [filePrefix + '.pem', 'ca.pem']:
931 with open(inFileName) as inFile:
932 outFile.write(inFile.read())
933
934 cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12']
935 output = None
936 try:
937 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
938 output = process.communicate(input='')
939 except subprocess.CalledProcessError as exc:
940 raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc.returncode, exc.output))
941
def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=None, v6=False, sourcePort=None, destinationPort=None):
    """Assert that ``receivedProxyPayload`` starts with the expected proxy protocol v2 header.

    :param receivedProxyPayload: raw bytes received by the backend, header first
    :param source: expected source address
    :param destination: expected destination address
    :param isTCP: True if the header must advertise protocol 0x01, False for 0x02
    :param values: expected additional (TLV) values, compared order-insensitively;
        None or empty means "no values". The caller's list is not modified.
    :param v6: True if the address family must be 0x02, False for 0x01
    :param sourcePort: expected source port; only checked when truthy
    :param destinationPort: expected destination port; when falsy,
        ``self._dnsDistPort`` is expected instead
    """
    proxy = ProxyProtocol()
    self.assertTrue(proxy.parseHeader(receivedProxyPayload))
    self.assertEqual(proxy.version, 0x02)
    self.assertEqual(proxy.command, 0x01)
    self.assertEqual(proxy.family, 0x02 if v6 else 0x01)
    self.assertEqual(proxy.protocol, 0x01 if isTCP else 0x02)
    self.assertGreater(proxy.contentLen, 0)

    self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
    self.assertEqual(proxy.source, source)
    self.assertEqual(proxy.destination, destination)
    if sourcePort:
        self.assertEqual(proxy.sourcePort, sourcePort)
    if destinationPort:
        self.assertEqual(proxy.destinationPort, destinationPort)
    else:
        self.assertEqual(proxy.destinationPort, self._dnsDistPort)

    self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
    # Compare sorted copies: value order is not significant, and sorting the
    # caller's (or a shared default) list in place would be a hidden side effect.
    expectedValues = sorted(values) if values else []
    self.assertEqual(sorted(proxy.values), expectedValues)
971
@classmethod
def getDOHGetURL(cls, baseurl, query, rawQuery=False):
    """Build the DoH GET URL for ``query``.

    The wire-format query is base64url-encoded with its trailing ``=``
    padding removed and appended as the ``dns`` query parameter.

    :param baseurl: DoH endpoint URL
    :param query: object with ``to_wire()``, or wire bytes when ``rawQuery`` is True
    :param rawQuery: treat ``query`` as already-encoded wire bytes
    :returns: ``baseurl`` followed by ``?dns=<encoded query>``
    """
    wireData = query if rawQuery else query.to_wire()
    encoded = base64.urlsafe_b64encode(wireData).decode('UTF8')
    return baseurl + "?dns=" + encoded.rstrip('=')
980
@classmethod
def openDOHConnection(cls, port, caFile, timeout=2.0):
    """Create a pycurl handle preconfigured for DNS-over-HTTPS.

    The handle requests HTTP/2 and sets default ``Content-type`` and
    ``Accept`` headers of ``application/dns-message``. Any later
    ``CURLOPT_HTTPHEADER`` set on the handle replaces that list.

    :param port: accepted for interface compatibility; not used here
    :param caFile: accepted for interface compatibility; not used here
    :param timeout: overall transfer timeout in seconds; None leaves
        libcurl's default (no timeout)
    :returns: a new ``pycurl.Curl`` handle; the caller owns it and should close it
    """
    conn = pycurl.Curl()
    conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)

    conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
                                    "Accept: application/dns-message"])
    # libcurl has no overall transfer timeout by default, so without this a
    # stuck server would make perform() block the test run indefinitely.
    if timeout is not None:
        conn.setopt(pycurl.TIMEOUT_MS, int(timeout * 1000))
    return conn
989
@classmethod
def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None):
    """Send ``query`` to dnsdist as a DoH GET request and collect the results.

    :param port: dnsdist DoH port; ``servername:port`` is resolved to 127.0.0.1
    :param servername: host name used for the request and certificate checks
    :param baseurl: DoH endpoint URL (the ``?dns=`` parameter is appended)
    :param query: DNS message, or wire bytes when ``rawQuery`` is True
    :param response: if set, queued for the responder before sending
    :param timeout: seconds for queue operations; also passed to
        ``openDOHConnection`` when a new handle is created
    :param caFile: CA bundle used to verify the server certificate
    :param useQueue: fetch the query seen by the responder, if one is queued
    :param rawQuery: ``query`` is already wire-format bytes
    :param rawResponse: return the raw HTTP body instead of a parsed message
    :param customHeaders: request headers; this replaces any header list
        already set on the handle
    :param useHTTPS: enable peer and host certificate verification
    :param fromQueue: queue to read the received query from instead of the class one
    :param toQueue: queue to put ``response`` on instead of the class one
    :param conn: existing pycurl handle to reuse; it is left open
    :returns: ``(receivedQuery, message)``. ``message`` is None for a non-200
        reply unless ``rawResponse`` is set.

    Side effects: sets ``cls._rcode`` to the HTTP status code and
    ``cls._response_headers`` to the raw response header bytes.
    """
    url = cls.getDOHGetURL(baseurl, query, rawQuery)

    # Only a handle created here is closed on exit; a caller-supplied one is
    # being reused across queries and must stay open.
    ownConnection = not conn
    if ownConnection:
        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)

    try:
        if useHTTPS:
            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
            if caFile:
                conn.setopt(pycurl.CAINFO, caFile)

        response_headers = BytesIO()
        # uncomment for libcurl debug output: conn.setopt(pycurl.VERBOSE, True)
        conn.setopt(pycurl.URL, url)
        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])

        conn.setopt(pycurl.HTTPHEADER, customHeaders)
        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)

        if response:
            if toQueue:
                toQueue.put(response, True, timeout)
            else:
                cls._toResponderQueue.put(response, True, timeout)

        receivedQuery = None
        message = None
        cls._response_headers = ''
        data = conn.perform_rb()
        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
        if cls._rcode == 200 and not rawResponse:
            message = dns.message.from_wire(data)
        elif rawResponse:
            message = data

        if useQueue:
            if fromQueue:
                if not fromQueue.empty():
                    receivedQuery = fromQueue.get(True, timeout)
            else:
                if not cls._fromResponderQueue.empty():
                    receivedQuery = cls._fromResponderQueue.get(True, timeout)

        cls._response_headers = response_headers.getvalue()
        return (receivedQuery, message)
    finally:
        if ownConnection:
            conn.close()
1039
@classmethod
def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
    """Send ``query`` to dnsdist as a DoH POST request and collect the results.

    A new pycurl handle is opened for the request and always closed before
    returning, including when the transfer raises.

    :param port: dnsdist DoH port; ``servername:port`` is resolved to 127.0.0.1
    :param servername: host name used for the request and certificate checks
    :param baseurl: DoH endpoint URL to POST to
    :param query: DNS message, or wire bytes when ``rawQuery`` is True
    :param response: if set, queued for the responder before sending
    :param timeout: seconds for queue operations; also passed to ``openDOHConnection``
    :param caFile: CA bundle used to verify the server certificate
    :param useQueue: fetch the query seen by the responder, if one is queued
    :param rawQuery: ``query`` is already wire-format bytes
    :param rawResponse: return the raw HTTP body instead of a parsed message
    :param customHeaders: request headers; this replaces the handle's default header list
    :param useHTTPS: enable peer and host certificate verification
    :returns: ``(receivedQuery, message)``. ``message`` is None for a non-200
        reply unless ``rawResponse`` is set.

    Side effects: sets ``cls._rcode`` to the HTTP status code and
    ``cls._response_headers`` to the raw response header bytes.
    """
    conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
    try:
        response_headers = BytesIO()
        # uncomment for libcurl debug output: conn.setopt(pycurl.VERBOSE, True)
        conn.setopt(pycurl.URL, baseurl)
        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
        if useHTTPS:
            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
            if caFile:
                conn.setopt(pycurl.CAINFO, caFile)

        conn.setopt(pycurl.HTTPHEADER, customHeaders)
        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
        conn.setopt(pycurl.POST, True)
        body = query if rawQuery else query.to_wire()
        conn.setopt(pycurl.POSTFIELDS, body)

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        receivedQuery = None
        message = None
        cls._response_headers = ''
        data = conn.perform_rb()
        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
        if cls._rcode == 200 and not rawResponse:
            message = dns.message.from_wire(data)
        elif rawResponse:
            message = data

        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        cls._response_headers = response_headers.getvalue()
        return (receivedQuery, message)
    finally:
        # the handle owns a socket; release it deterministically, even on error
        conn.close()
1083
def sendDOHQueryWrapper(self, query, response, useQueue=True):
    """Send ``query`` via ``sendDOHQuery`` using this test's default DoH listener settings."""
    port = self._dohServerPort
    host = self._serverName
    endpoint = self._dohBaseURL
    return self.sendDOHQuery(port, host, endpoint, query,
                             response=response,
                             caFile=self._caCert,
                             useQueue=useQueue)
1086
def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
    """Send ``query`` via ``sendDOHQuery`` to the listener described by
    ``_dohWithNGHTTP2ServerPort`` / ``_dohWithNGHTTP2BaseURL``."""
    return self.sendDOHQuery(
        self._dohWithNGHTTP2ServerPort,
        self._serverName,
        self._dohWithNGHTTP2BaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
1089
def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
    """Send ``query`` via ``sendDOHQuery`` to the listener described by
    ``_dohWithH2OServerPort`` / ``_dohWithH2OBaseURL``."""
    options = dict(response=response, caFile=self._caCert, useQueue=useQueue)
    return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName,
                             self._dohWithH2OBaseURL, query, **options)
1092
def sendDOTQueryWrapper(self, query, response, useQueue=True):
    """Send ``query`` via ``sendDOTQuery`` to this test's DNS-over-TLS port.

    Port, server name, query, response and CA file are passed positionally,
    in that order; ``useQueue`` is passed by keyword.
    """
    positional = (self._tlsServerPort, self._serverName, query, response, self._caCert)
    return self.sendDOTQuery(*positional, useQueue=useQueue)
1095
def sendDOQQueryWrapper(self, query, response, useQueue=True):
    """Send ``query`` via ``sendDOQQuery`` to this test's DNS-over-QUIC port."""
    doqPort = self._doqServerPort
    return self.sendDOQQuery(doqPort, query,
                             serverName=self._serverName,
                             caFile=self._caCert,
                             response=response,
                             useQueue=useQueue)
1098
def sendDOH3QueryWrapper(self, query, response, useQueue=True):
    """Send ``query`` via ``sendDOH3Query`` to ``_doh3ServerPort``, using ``_dohBaseURL`` as the endpoint."""
    doh3Port = self._doh3ServerPort
    endpoint = self._dohBaseURL
    return self.sendDOH3Query(doh3Port, endpoint, query,
                              serverName=self._serverName,
                              caFile=self._caCert,
                              response=response,
                              useQueue=useQueue)
@classmethod
def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
    """Open a dnspython synchronous QUIC connection to 127.0.0.1:``port``.

    :param port: dnsdist DoQ port
    :param caFile: passed to ``dns.quic.SyncQuicManager`` as ``verify_mode``
    :param source: optional local address to bind to
    :param source_port: optional local port to bind to
    :returns: the connection object returned by ``SyncQuicManager.connect``
    """
    # This file only imports `dns` and `dns.message`, which do not guarantee
    # that the dns.quic submodule is loaded; import it explicitly.
    import dns.quic

    manager = dns.quic.SyncQuicManager(verify_mode=caFile)
    return manager.connect('127.0.0.1', port, source, source_port)
1109
@classmethod
def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
    """Send ``query`` to dnsdist over DNS-over-QUIC and collect the results.

    If ``response`` is set, it is queued for the responder first (on
    ``toQueue`` when given, else the class queue). The query is then sent with
    ``quic_query`` to 127.0.0.1:``port``. ``rawQuery`` and ``connection`` are
    accepted but not used by this method.

    :returns: ``(receivedQuery, message)``. ``receivedQuery`` is the query the
        responder saw, or None if none was queued or ``useQueue`` is False.
        ``message`` is the first element returned by ``quic_query``.
    """
    if response:
        outbound = toQueue if toQueue else cls._toResponderQueue
        outbound.put(response, True, timeout)

    message, _ = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)

    receivedQuery = None
    if useQueue:
        inbound = fromQueue if fromQueue else cls._fromResponderQueue
        # only wait on the queue when the responder actually produced something
        if not inbound.empty():
            receivedQuery = inbound.get(True, timeout)

    return (receivedQuery, message)
1132
@classmethod
def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
    """Send ``query`` to dnsdist over DNS-over-HTTP/3 and collect the results.

    If ``response`` is set, it is queued for the responder first (on
    ``toQueue`` when given, else the class queue). The request is then made with
    ``doh3_query``, using GET unless ``post`` is set (forwarded as-is).
    ``rawQuery`` and ``connection`` are accepted but not used by this method.

    :returns: ``(receivedQuery, message)``. ``receivedQuery`` is the query the
        responder saw, or None if none was queued or ``useQueue`` is False.
        ``message`` is whatever ``doh3_query`` returned.
    """
    if response:
        outbound = toQueue if toQueue else cls._toResponderQueue
        outbound.put(response, True, timeout)

    message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)

    receivedQuery = None
    if useQueue:
        inbound = fromQueue if fromQueue else cls._fromResponderQueue
        # only wait on the queue when the responder actually produced something
        if not inbound.empty():
            receivedQuery = inbound.get(True, timeout)

    return (receivedQuery, message)