]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #13923 from rgacogne/ddist-xfr-response-chain
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import base64
4 import copy
5 import errno
6 import os
7 import socket
8 import ssl
9 import struct
10 import subprocess
11 import sys
12 import threading
13 import time
14 import unittest
15
16 import clientsubnetoption
17
18 import dns
19 import dns.message
20
21 import libnacl
22 import libnacl.utils
23
24 import h2.connection
25 import h2.events
26 import h2.config
27
28 import pycurl
29 from io import BytesIO
30
31 from doqclient import quic_query
32 from doh3client import doh3_query
33
34 from eqdnsmessage import AssertEqualDNSMessageMixin
35 from proxyprotocol import ProxyProtocol
36
37 # Python2/3 compatibility hacks
38 try:
39 from queue import Queue
40 except ImportError:
41 from Queue import Queue
42
# Python 2 only: rebind range to the lazy xrange. On Python 3 xrange does not
# exist, the NameError is swallowed and the builtin range is kept.
# NOTE(review): this module uses f-strings and threading.get_native_id(), so it
# only runs on Python 3 and this shim is effectively a no-op.
try:
    range = xrange
except NameError:
    pass
47
def getWorkerID():
    """Return the numeric pytest-xdist worker index, or 0 outside of xdist.

    The worker name is read from PYTEST_XDIST_WORKER; its first two
    characters are dropped and the rest is parsed as an integer (this
    assumes names of the form 'gwN').
    """
    workerName = os.environ.get('PYTEST_XDIST_WORKER')
    if workerName is None:
        return 0
    return int(workerName[2:])
53
# Last port handed out, per xdist worker ID.
workerPorts = {}

def pickAvailablePort():
    """Hand out the next port from the calling xdist worker's private range.

    Worker N starts at 11000 + 1000 * N; each later call returns the previous
    port plus one. Ports are not probed: "available" only means "not handed
    out before by this process".
    """
    workerID = getWorkerID()
    lastPort = workerPorts.get(workerID)
    port = 11000 + (workerID * 1000) if lastPort is None else lastPort + 1
    workerPorts[workerID] = port
    return port
65
66 class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
"""
Set up a dnsdist instance and responder threads.
Queries sent to dnsdist are relayed to the responder threads,
who reply with the response provided by the tests themselves
on a queue. Responder threads also queue the queries received
from dnsdist on a separate queue, allowing the tests to check
that the queries sent from dnsdist were as expected.

Subclasses customise the instance mostly through the class attributes
below, in particular _config_template, which is %-formatted with the
values of the attributes named in _config_params.
"""
# Address passed to dnsdist with -l, and probed to detect that it is up.
_dnsDistListeningAddr = "127.0.0.1"
# Responses queued by the tests, consumed by the responder threads.
_toResponderQueue = Queue()
# Queries received by the responder threads, consumed by the tests.
_fromResponderQueue = Queue()
# Timeout (seconds) used by the responders when putting to / getting from the queues.
_queueTimeout = 1
# subprocess.Popen handle of the running dnsdist, set by startDNSDist().
_dnsdist = None
# Thread name -> number of non-health-check queries handled by the responders.
_responsesCounter = {}
# dnsdist Lua configuration, %-formatted with the attributes listed in _config_params.
_config_template = """
"""
_config_params = ['_testServerPort']
# Netmasks passed to dnsdist with --acl.
_acl = ['127.0.0.1/32']
# Console encryption key; None means console messages are sent unencrypted.
_consoleKey = None
# Queries whose name ends with this suffix are answered directly as health checks.
_healthCheckName = 'a.root-servers.net.'
# Number of health-check queries answered by _getResponse().
_healthCheckCounter = 0
# When the test queued no response, answer SERVFAIL instead of nothing.
_answerUnexpected = True
# Expected --check-config output; None means the default "Configuration '...' OK!" line.
_checkConfigExpectedOutput = None
# Adds -v to the dnsdist command line and skips the --check-config output comparison.
_verboseMode = False
# Run dnsdist through sudo, preserving LD_LIBRARY_PATH and LLVM_PROFILE_FILE.
_sudoMode = False
# Do not pass -l on the command line (presumably the configuration declares its own listeners).
_skipListeningOnCL = False
# When both are set, start-up waits for this TCP address/port instead of the default one.
_alternateListeningAddr = None
_alternateListeningPort = None
# Native thread id -> keep-running flag of the responder threads; set to False on tear-down.
_backgroundThreads = {}
# Responder threads created by startResponders().
_UDPResponder = None
_TCPResponder = None
# Extra seconds to sleep once dnsdist is up.
_extraStartupSleep = 0
# NOTE: startResponders()/startDNSDist() pick these ports again at set-up time.
_dnsDistPort = pickAvailablePort()
_consolePort = pickAvailablePort()
_testServerPort = pickAvailablePort()
102
@classmethod
def waitForTCPSocket(cls, ipaddress, port):
    """Wait until ipaddress:port accepts TCP connections, for at most 20 attempts.

    Attempts are 0.1s apart, each with a 1s connect timeout. The method
    returns silently both on success and after the last attempt: callers
    accept that the target might never listen.
    """
    for try_number in range(0, 20):
        # The with-statement also closes the probe socket when connect()
        # fails; previously one socket leaked per failed attempt.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1.0)
            try:
                sock.connect((ipaddress, port))
                return
            except OSError as err:
                # "Connection refused" just means "not listening yet"; other
                # errors (timeouts, unreachable, ...) are worth reporting.
                # Only OSError carries .errno, hence the narrower except.
                if err.errno != errno.ECONNREFUSED:
                    print(f'Error occurred: {try_number} {err}', file=sys.stderr)
        time.sleep(0.1)
    # We assume the dnsdist instance does not listen. That's fine.
114 print(f'Error occurred: {try_number} {err}', file=sys.stderr)
115 time.sleep(0.1)
116 # We assume the dnsdist instance does not listen. That's fine.
117
118 @classmethod
119 def startResponders(cls):
120 print("Launching responders..")
121 cls._testServerPort = pickAvailablePort()
122
123 cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
124 cls._UDPResponder.daemon = True
125 cls._UDPResponder.start()
126 cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
127 cls._TCPResponder.daemon = True
128 cls._TCPResponder.start()
129 cls.waitForTCPSocket("127.0.0.1", cls._testServerPort);
130
131 @classmethod
132 def startDNSDist(cls):
133 cls._dnsDistPort = pickAvailablePort()
134 cls._consolePort = pickAvailablePort()
135
136 print("Launching dnsdist..")
137 confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
138 params = tuple([getattr(cls, param) for param in cls._config_params])
139 print(params)
140 with open(confFile, 'w') as conf:
141 conf.write("-- Autogenerated by dnsdisttests.py\n")
142 conf.write(f"-- dnsdist will listen on {cls._dnsDistPort}")
143 conf.write(cls._config_template % params)
144 conf.write("setSecurityPollSuffix('')")
145
146 if cls._skipListeningOnCL:
147 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile ]
148 else:
149 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
150 '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
151
152 if cls._verboseMode:
153 dnsdistcmd.append('-v')
154 if cls._sudoMode:
155 preserve_env_values = ['LD_LIBRARY_PATH', 'LLVM_PROFILE_FILE']
156 for value in preserve_env_values:
157 if value in os.environ:
158 dnsdistcmd.insert(0, value + '=' + os.environ[value])
159 dnsdistcmd.insert(0, 'sudo')
160
161 for acl in cls._acl:
162 dnsdistcmd.extend(['--acl', acl])
163 print(' '.join(dnsdistcmd))
164
165 # validate config with --check-config, which sets client=true, possibly exposing bugs.
166 testcmd = dnsdistcmd + ['--check-config']
167 try:
168 output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
169 except subprocess.CalledProcessError as exc:
170 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
171 if cls._checkConfigExpectedOutput is not None:
172 expectedOutput = cls._checkConfigExpectedOutput
173 else:
174 expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
175 if not cls._verboseMode and output != expectedOutput:
176 raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output, expectedOutput))
177
178 logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
179 with open(logFile, 'w') as fdLog:
180 cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)
181
182 if cls._alternateListeningAddr and cls._alternateListeningPort:
183 cls.waitForTCPSocket(cls._alternateListeningAddr, cls._alternateListeningPort)
184 else:
185 cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort)
186
187 if cls._dnsdist.poll() is not None:
188 print(f"\n*** startDNSDist log for {logFile} ***")
189 with open(logFile, 'r') as fdLog:
190 print(fdLog.read())
191 print(f"*** End startDNSDist log for {logFile} ***")
192 raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode))
193 time.sleep(cls._extraStartupSleep)
194
195 @classmethod
196 def setUpSockets(cls):
197 print("Setting up UDP socket..")
198 cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
199 cls._sock.settimeout(2.0)
200 cls._sock.connect(("127.0.0.1", cls._dnsDistPort))
201
202 @classmethod
203 def killProcess(cls, p):
204 # Don't try to kill it if it's already dead
205 if p.poll() is not None:
206 return
207 try:
208 p.terminate()
209 for count in range(20):
210 x = p.poll()
211 if x is not None:
212 break
213 time.sleep(0.1)
214 if x is None:
215 print("kill...", p, file=sys.stderr)
216 p.kill()
217 p.wait()
218 except OSError as e:
219 # There is a race-condition with the poll() and
220 # kill() statements, when the process is dead on the
221 # kill(), this is fine
222 if e.errno != errno.ESRCH:
223 raise
224
225 @classmethod
226 def setUpClass(cls):
227
228 cls.startResponders()
229 cls.startDNSDist()
230 cls.setUpSockets()
231
232 print("Launching tests..")
233
234 @classmethod
235 def tearDownClass(cls):
236 cls._sock.close()
237 # tell the background threads to stop, if any
238 for backgroundThread in cls._backgroundThreads:
239 cls._backgroundThreads[backgroundThread] = False
240 cls.killProcess(cls._dnsdist)
241
242 @classmethod
243 def _ResponderIncrementCounter(cls):
244 if threading.current_thread().name in cls._responsesCounter:
245 cls._responsesCounter[threading.current_thread().name] += 1
246 else:
247 cls._responsesCounter[threading.current_thread().name] = 1
248
249 @classmethod
250 def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
251 response = None
252 if len(request.question) != 1:
253 print("Skipping query with question count %d" % (len(request.question)))
254 return None
255 healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
256 if healthCheck:
257 cls._healthCheckCounter += 1
258 response = dns.message.make_response(request)
259 else:
260 cls._ResponderIncrementCounter()
261 if not fromQueue.empty():
262 toQueue.put(request, True, cls._queueTimeout)
263 response = fromQueue.get(True, cls._queueTimeout)
264 if response:
265 response = copy.copy(response)
266 response.id = request.id
267
268 if synthesize is not None:
269 response = dns.message.make_response(request)
270 response.set_rcode(synthesize)
271
272 if not response:
273 if cls._answerUnexpected:
274 response = dns.message.make_response(request)
275 response.set_rcode(dns.rcode.SERVFAIL)
276
277 return response
278
279 @classmethod
280 def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
281 cls._backgroundThreads[threading.get_native_id()] = True
282 # trailingDataResponse=True means "ignore trailing data".
283 # Other values are either False (meaning "raise an exception")
284 # or are interpreted as a response RCODE for queries with trailing data.
285 # callback is invoked for every -even healthcheck ones- query and should return a raw response
286 ignoreTrailing = trailingDataResponse is True
287
288 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
289 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
290 sock.bind(("127.0.0.1", port))
291 sock.settimeout(0.5)
292 while True:
293 try:
294 data, addr = sock.recvfrom(4096)
295 except socket.timeout:
296 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
297 del cls._backgroundThreads[threading.get_native_id()]
298 break
299 else:
300 continue
301
302 forceRcode = None
303 try:
304 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
305 except dns.message.TrailingJunk as e:
306 print('trailing data exception in UDPResponder')
307 if trailingDataResponse is False or forceRcode is True:
308 raise
309 print("UDP query with trailing data, synthesizing response")
310 request = dns.message.from_wire(data, ignore_trailing=True)
311 forceRcode = trailingDataResponse
312
313 wire = None
314 if callback:
315 wire = callback(request)
316 else:
317 if request.edns > 1:
318 forceRcode = dns.rcode.BADVERS
319 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
320 if response:
321 wire = response.to_wire()
322
323 if not wire:
324 continue
325
326 sock.sendto(wire, addr)
327
328 sock.close()
329
330 @classmethod
331 def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, partialWrite=False):
332 ignoreTrailing = trailingDataResponse is True
333 try:
334 data = conn.recv(2)
335 except Exception as err:
336 data = None
337 print(f'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
338 if not data:
339 conn.close()
340 return
341
342 (datalen,) = struct.unpack("!H", data)
343 data = conn.recv(datalen)
344 forceRcode = None
345 try:
346 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
347 except dns.message.TrailingJunk as e:
348 if trailingDataResponse is False or forceRcode is True:
349 raise
350 print("TCP query with trailing data, synthesizing response")
351 request = dns.message.from_wire(data, ignore_trailing=True)
352 forceRcode = trailingDataResponse
353
354 if callback:
355 wire = callback(request)
356 else:
357 if request.edns > 1:
358 forceRcode = dns.rcode.BADVERS
359 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
360 if response:
361 wire = response.to_wire(max_size=65535)
362
363 if not wire:
364 conn.close()
365 return
366
367 wireLen = struct.pack("!H", len(wire))
368 if partialWrite:
369 for b in wireLen:
370 conn.send(bytes([b]))
371 time.sleep(0.5)
372 else:
373 conn.send(wireLen)
374 conn.send(wire)
375
376 while multipleResponses:
377 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
378 # otherwise we might read responses intended for the next connection
379 if fromQueue.empty():
380 break
381
382 response = fromQueue.get(False)
383 if not response:
384 break
385
386 response = copy.copy(response)
387 response.id = request.id
388 wire = response.to_wire(max_size=65535)
389 try:
390 conn.send(struct.pack("!H", len(wire)))
391 conn.send(wire)
392 except socket.error as e:
393 # some of the tests are going to close
394 # the connection on us, just deal with it
395 break
396
397 conn.close()
398
399 @classmethod
400 def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1', partialWrite=False):
401 cls._backgroundThreads[threading.get_native_id()] = True
402 # trailingDataResponse=True means "ignore trailing data".
403 # Other values are either False (meaning "raise an exception")
404 # or are interpreted as a response RCODE for queries with trailing data.
405 # callback is invoked for every -even healthcheck ones- query and should return a raw response
406
407 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
408 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
409 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
410 try:
411 sock.bind((listeningAddr, port))
412 except socket.error as e:
413 print("Error binding in the TCP responder: %s" % str(e))
414 sys.exit(1)
415
416 sock.listen(100)
417 sock.settimeout(0.5)
418 if tlsContext:
419 sock = tlsContext.wrap_socket(sock, server_side=True)
420
421 while True:
422 try:
423 (conn, _) = sock.accept()
424 except ssl.SSLError:
425 continue
426 except ConnectionResetError:
427 continue
428 except socket.timeout:
429 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
430 del cls._backgroundThreads[threading.get_native_id()]
431 break
432 else:
433 continue
434
435 conn.settimeout(5.0)
436 if multipleConnections:
437 thread = threading.Thread(name='TCP Connection Handler',
438 target=cls.handleTCPConnection,
439 args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite])
440 thread.daemon = True
441 thread.start()
442 else:
443 cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite)
444
445 sock.close()
446
447 @classmethod
448 def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
449 ignoreTrailing = trailingDataResponse is True
450 try:
451 h2conn = h2.connection.H2Connection(config=config)
452 h2conn.initiate_connection()
453 conn.sendall(h2conn.data_to_send())
454 except ssl.SSLEOFError as e:
455 print("Unexpected EOF: %s" % (e))
456 return
457 except Exception as err:
458 print(f'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
459 return
460
461 dnsData = {}
462
463 if useProxyProtocol:
464 # try to read the entire Proxy Protocol header
465 proxy = ProxyProtocol()
466 header = conn.recv(proxy.HEADER_SIZE)
467 if not header:
468 print('unable to get header')
469 conn.close()
470 return
471
472 if not proxy.parseHeader(header):
473 print('unable to parse header')
474 print(header)
475 conn.close()
476 return
477
478 proxyContent = conn.recv(proxy.contentLen)
479 if not proxyContent:
480 print('unable to get content')
481 conn.close()
482 return
483
484 payload = header + proxyContent
485 toQueue.put(payload, True, cls._queueTimeout)
486
487 # be careful, HTTP/2 headers and data might be in different recv() results
488 requestHeaders = None
489 while True:
490 try:
491 data = conn.recv(65535)
492 except Exception as err:
493 data = None
494 print(f'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
495 if not data:
496 break
497
498 events = h2conn.receive_data(data)
499 for event in events:
500 if isinstance(event, h2.events.RequestReceived):
501 requestHeaders = event.headers
502 if isinstance(event, h2.events.DataReceived):
503 h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
504 if not event.stream_id in dnsData:
505 dnsData[event.stream_id] = b''
506 dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
507 if event.stream_ended:
508 forceRcode = None
509 status = 200
510 try:
511 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
512 except dns.message.TrailingJunk as e:
513 if trailingDataResponse is False or forceRcode is True:
514 raise
515 print("DOH query with trailing data, synthesizing response")
516 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
517 forceRcode = trailingDataResponse
518
519 if callback:
520 status, wire = callback(request, requestHeaders, fromQueue, toQueue)
521 else:
522 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
523 if response:
524 wire = response.to_wire(max_size=65535)
525
526 if not wire:
527 conn.close()
528 conn = None
529 break
530
531 headers = [
532 (':status', str(status)),
533 ('content-length', str(len(wire))),
534 ('content-type', 'application/dns-message'),
535 ]
536 h2conn.send_headers(stream_id=event.stream_id, headers=headers)
537 h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
538
539 data_to_send = h2conn.data_to_send()
540 if data_to_send:
541 conn.sendall(data_to_send)
542
543 if conn is None:
544 break
545
546 if conn is not None:
547 conn.close()
548
549 @classmethod
550 def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
551 cls._backgroundThreads[threading.get_native_id()] = True
552 # trailingDataResponse=True means "ignore trailing data".
553 # Other values are either False (meaning "raise an exception")
554 # or are interpreted as a response RCODE for queries with trailing data.
555 # callback is invoked for every -even healthcheck ones- query and should return a raw response
556
557 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
558 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
559 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
560 try:
561 sock.bind(("127.0.0.1", port))
562 except socket.error as e:
563 print("Error binding in the TCP responder: %s" % str(e))
564 sys.exit(1)
565
566 sock.listen(100)
567 sock.settimeout(0.5)
568 if tlsContext:
569 sock = tlsContext.wrap_socket(sock, server_side=True)
570
571 config = h2.config.H2Configuration(client_side=False)
572
573 while True:
574 try:
575 (conn, _) = sock.accept()
576 except ssl.SSLError:
577 continue
578 except ConnectionResetError:
579 continue
580 except socket.timeout:
581 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
582 del cls._backgroundThreads[threading.get_native_id()]
583 break
584 else:
585 continue
586
587 conn.settimeout(5.0)
588 thread = threading.Thread(name='DoH Connection Handler',
589 target=cls.handleDoHConnection,
590 args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
591 thread.daemon = True
592 thread.start()
593
594 sock.close()
595
596 @classmethod
597 def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
598 if useQueue and response is not None:
599 cls._toResponderQueue.put(response, True, timeout)
600
601 if timeout:
602 cls._sock.settimeout(timeout)
603
604 try:
605 if not rawQuery:
606 query = query.to_wire()
607 cls._sock.send(query)
608 data = cls._sock.recv(4096)
609 except socket.timeout:
610 data = None
611 finally:
612 if timeout:
613 cls._sock.settimeout(None)
614
615 receivedQuery = None
616 message = None
617 if useQueue and not cls._fromResponderQueue.empty():
618 receivedQuery = cls._fromResponderQueue.get(True, timeout)
619 if data:
620 message = dns.message.from_wire(data)
621 return (receivedQuery, message)
622
623 @classmethod
624 def openTCPConnection(cls, timeout=None, port=None):
625 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
626 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
627 if timeout:
628 sock.settimeout(timeout)
629
630 if not port:
631 port = cls._dnsDistPort
632
633 sock.connect(("127.0.0.1", port))
634 return sock
635
636 @classmethod
637 def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]):
638 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
639 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
640 if timeout:
641 sock.settimeout(timeout)
642
643 # 2.7.9+
644 if hasattr(ssl, 'create_default_context'):
645 sslctx = ssl.create_default_context(cafile=caCert)
646 if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
647 sslctx.set_alpn_protocols(alpn)
648 sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
649 else:
650 sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
651
652 sslsock.connect(("127.0.0.1", port))
653 return sslsock
654
655 @classmethod
656 def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
657 if not rawQuery:
658 wire = query.to_wire()
659 else:
660 wire = query
661
662 if response:
663 cls._toResponderQueue.put(response, True, timeout)
664
665 sock.send(struct.pack("!H", len(wire)))
666 sock.send(wire)
667
668 @classmethod
669 def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
670 message = None
671 data = sock.recv(2)
672 if data:
673 (datalen,) = struct.unpack("!H", data)
674 print(datalen)
675 data = sock.recv(datalen)
676 if data:
677 print(data)
678 message = dns.message.from_wire(data)
679
680 print(useQueue)
681 if useQueue and not cls._fromResponderQueue.empty():
682 receivedQuery = cls._fromResponderQueue.get(True, timeout)
683 print(receivedQuery)
684 return (receivedQuery, message)
685 else:
686 print("queue empty")
687 return message
688
689 @classmethod
690 def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True):
691 conn = cls.openTLSConnection(port, serverName, caFile)
692 cls.sendTCPQueryOverConnection(conn, query, response=response)
693 if useQueue:
694 return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
695 return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
696
697 @classmethod
698 def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
699 message = None
700 if useQueue:
701 cls._toResponderQueue.put(response, True, timeout)
702
703 try:
704 sock = cls.openTCPConnection(timeout)
705 except socket.timeout as e:
706 print("Timeout while opening TCP connection: %s" % (str(e)))
707 return (None, None)
708
709 try:
710 cls.sendTCPQueryOverConnection(sock, query, rawQuery, timeout=timeout)
711 message = cls.recvTCPResponseOverConnection(sock, timeout=timeout)
712 except socket.timeout as e:
713 print("Timeout while sending or receiving TCP data: %s" % (str(e)))
714 except socket.error as e:
715 print("Network error: %s" % (str(e)))
716 finally:
717 sock.close()
718
719 receivedQuery = None
720 print(useQueue)
721 if useQueue and not cls._fromResponderQueue.empty():
722 print(receivedQuery)
723 receivedQuery = cls._fromResponderQueue.get(True, timeout)
724 else:
725 print("queue is empty")
726
727 return (receivedQuery, message)
728
729 @classmethod
730 def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
731 if useQueue:
732 for response in responses:
733 cls._toResponderQueue.put(response, True, timeout)
734 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
735 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
736 if timeout:
737 sock.settimeout(timeout)
738
739 sock.connect(("127.0.0.1", cls._dnsDistPort))
740 messages = []
741
742 try:
743 if not rawQuery:
744 wire = query.to_wire()
745 else:
746 wire = query
747
748 sock.send(struct.pack("!H", len(wire)))
749 sock.send(wire)
750 while True:
751 data = sock.recv(2)
752 if not data:
753 break
754 (datalen,) = struct.unpack("!H", data)
755 data = sock.recv(datalen)
756 messages.append(dns.message.from_wire(data))
757
758 except socket.timeout as e:
759 print("Timeout while receiving multiple TCP responses: %s" % (str(e)))
760 except socket.error as e:
761 print("Network error: %s" % (str(e)))
762 finally:
763 sock.close()
764
765 receivedQuery = None
766 if useQueue and not cls._fromResponderQueue.empty():
767 receivedQuery = cls._fromResponderQueue.get(True, timeout)
768 return (receivedQuery, messages)
769
770 def setUp(self):
771 # This function is called before every test
772
773 # Clear the responses counters
774 self._responsesCounter.clear()
775
776 self._healthCheckCounter = 0
777
778 # Make sure the queues are empty, in case
779 # a previous test failed
780 self.clearResponderQueues()
781
782 super(DNSDistTest, self).setUp()
783
784 @classmethod
785 def clearToResponderQueue(cls):
786 while not cls._toResponderQueue.empty():
787 cls._toResponderQueue.get(False)
788
789 @classmethod
790 def clearFromResponderQueue(cls):
791 while not cls._fromResponderQueue.empty():
792 cls._fromResponderQueue.get(False)
793
794 @classmethod
795 def clearResponderQueues(cls):
796 cls.clearToResponderQueue()
797 cls.clearFromResponderQueue()
798
799 @staticmethod
800 def generateConsoleKey():
801 return libnacl.utils.salsa_key()
802
803 @classmethod
804 def _encryptConsole(cls, command, nonce):
805 command = command.encode('UTF-8')
806 if cls._consoleKey is None:
807 return command
808 return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)
809
810 @classmethod
811 def _decryptConsole(cls, command, nonce):
812 if cls._consoleKey is None:
813 result = command
814 else:
815 result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
816 return result.decode('UTF-8')
817
818 @classmethod
819 def sendConsoleCommand(cls, command, timeout=5.0, IPv6=False):
820 ourNonce = libnacl.utils.rand_nonce()
821 theirNonce = None
822 sock = socket.socket(socket.AF_INET if not IPv6 else socket.AF_INET6, socket.SOCK_STREAM)
823 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
824 if timeout:
825 sock.settimeout(timeout)
826
827 sock.connect(("::1", cls._consolePort, 0, 0) if IPv6 else ("127.0.0.1", cls._consolePort))
828 sock.send(ourNonce)
829 theirNonce = sock.recv(len(ourNonce))
830 if len(theirNonce) != len(ourNonce):
831 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
832 if len(theirNonce) == 0:
833 raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
834 return None
835
836 halfNonceSize = int(len(ourNonce) / 2)
837 readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
838 writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
839 msg = cls._encryptConsole(command, writingNonce)
840 sock.send(struct.pack("!I", len(msg)))
841 sock.send(msg)
842 data = sock.recv(4)
843 if not data:
844 raise socket.error("Got EOF while reading the response size")
845
846 (responseLen,) = struct.unpack("!I", data)
847 data = sock.recv(responseLen)
848 response = cls._decryptConsole(data, readingNonce)
849 sock.close()
850 return response
851
852 def compareOptions(self, a, b):
853 self.assertEqual(len(a), len(b))
854 for idx in range(len(a)):
855 self.assertEqual(a[idx], b[idx])
856
857 def checkMessageNoEDNS(self, expected, received):
858 self.assertEqual(expected, received)
859 self.assertEqual(received.edns, -1)
860 self.assertEqual(len(received.options), 0)
861
862 def checkMessageEDNSWithoutOptions(self, expected, received):
863 self.assertEqual(expected, received)
864 self.assertEqual(received.edns, 0)
865 self.assertEqual(expected.payload, received.payload)
866
867 def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
868 self.assertEqual(expected, received)
869 self.assertEqual(received.edns, 0)
870 self.assertEqual(expected.payload, received.payload)
871 self.assertEqual(len(received.options), withCookies)
872 if withCookies:
873 for option in received.options:
874 self.assertEqual(option.otype, 10)
875 else:
876 for option in received.options:
877 self.assertNotEqual(option.otype, 10)
878
879 def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
880 self.assertEqual(expected, received)
881 self.assertEqual(received.edns, 0)
882 self.assertEqual(expected.payload, received.payload)
883 self.assertEqual(len(received.options), 1 + additionalOptions)
884 hasECS = False
885 for option in received.options:
886 if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
887 hasECS = True
888 else:
889 self.assertNotEqual(additionalOptions, 0)
890
891 self.compareOptions(expected.options, received.options)
892 self.assertTrue(hasECS)
893
894 def checkMessageEDNS(self, expected, received):
895 self.assertEqual(expected, received)
896 self.assertEqual(received.edns, 0)
897 self.assertEqual(expected.payload, received.payload)
898 self.assertEqual(len(expected.options), len(received.options))
899 self.compareOptions(expected.options, received.options)
900
901 def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
902 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
903
904 def checkQueryEDNS(self, expected, received):
905 self.checkMessageEDNS(expected, received)
906
907 def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
908 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
909
910 def checkQueryEDNSWithoutECS(self, expected, received):
911 self.checkMessageEDNSWithoutECS(expected, received)
912
913 def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
914 self.checkMessageEDNSWithoutECS(expected, received, withCookies)
915
916 def checkQueryNoEDNS(self, expected, received):
917 self.checkMessageNoEDNS(expected, received)
918
919 def checkResponseNoEDNS(self, expected, received):
920 self.checkMessageNoEDNS(expected, received)
921
922 @staticmethod
923 def generateNewCertificateAndKey(filePrefix):
924 # generate and sign a new cert
925 cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf']
926 output = None
927 try:
928 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
929 output = process.communicate(input='')
930 except subprocess.CalledProcessError as exc:
931 raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output))
932 cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
933 output = None
934 try:
935 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
936 output = process.communicate(input='')
937 except subprocess.CalledProcessError as exc:
938 raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output))
939
940 with open(filePrefix + '.chain', 'w') as outFile:
941 for inFileName in [filePrefix + '.pem', 'ca.pem']:
942 with open(inFileName) as inFile:
943 outFile.write(inFile.read())
944
945 cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12']
946 output = None
947 try:
948 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
949 output = process.communicate(input='')
950 except subprocess.CalledProcessError as exc:
951 raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc.returncode, exc.output))
952
def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=None, v6=False, sourcePort=None, destinationPort=None):
    """Assert that *receivedProxyPayload* carries the expected PROXY protocol v2 header.

    :param receivedProxyPayload: raw payload handed to ``ProxyProtocol`` for parsing.
    :param source: expected source address.
    :param destination: expected destination address.
    :param isTCP: expect protocol byte 0x01 (TCP) instead of 0x02 (UDP).
    :param values: expected additional values, compared without regard to order.
        Defaults to none. The caller's sequence is not modified.
    :param v6: expect family byte 0x02 (IPv6) instead of 0x01 (IPv4).
    :param sourcePort: if truthy, the expected source port.
    :param destinationPort: if truthy, the expected destination port; otherwise
        the destination port must equal ``self._dnsDistPort``.
    """
    if values is None:
        values = []

    proxy = ProxyProtocol()
    self.assertTrue(proxy.parseHeader(receivedProxyPayload))
    self.assertEqual(proxy.version, 0x02)
    self.assertEqual(proxy.command, 0x01)
    self.assertEqual(proxy.family, 0x02 if v6 else 0x01)
    self.assertEqual(proxy.protocol, 0x01 if isTCP else 0x02)
    self.assertGreater(proxy.contentLen, 0)

    self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
    self.assertEqual(proxy.source, source)
    self.assertEqual(proxy.destination, destination)
    if sourcePort:
        self.assertEqual(proxy.sourcePort, sourcePort)
    if destinationPort:
        self.assertEqual(proxy.destinationPort, destinationPort)
    else:
        self.assertEqual(proxy.destinationPort, self._dnsDistPort)

    self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
    # Compare sorted copies: sorting in place would mutate the caller's list
    # (and would fail outright for tuples).
    self.assertEqual(sorted(proxy.values), sorted(values))
982
@classmethod
def getDOHGetURL(cls, baseurl, query, rawQuery=False):
    """Return the DNS-over-HTTPS GET URL for *query* under *baseurl*.

    The message bytes are *query* itself when *rawQuery* is true, otherwise
    ``query.to_wire()``. They are base64url-encoded with the trailing ``=``
    padding stripped and appended as ``?dns=<encoded>``.
    """
    wire = query if rawQuery else query.to_wire()
    encoded = base64.urlsafe_b64encode(wire).decode('UTF8')
    return baseurl + "?dns=" + encoded.rstrip('=')
991
@classmethod
def openDOHConnection(cls, port, caFile, timeout=2.0):
    """Create a pycurl handle preconfigured for DNS-over-HTTPS over HTTP/2.

    :param port: accepted for interface compatibility; not used here.
    :param caFile: accepted for interface compatibility; not used here.
    :param timeout: overall transfer timeout in seconds; a falsy value leaves
        the handle without a timeout.
    :returns: a new ``pycurl.Curl`` handle; the caller is responsible for closing it.
    """
    conn = pycurl.Curl()
    conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)

    conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
                                    "Accept: application/dns-message"])
    if timeout:
        # Bound the whole transfer so an unresponsive server fails the test
        # instead of hanging it.
        conn.setopt(pycurl.TIMEOUT_MS, int(timeout * 1000))
    return conn
1000
@classmethod
def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=None, useHTTPS=True, fromQueue=None, toQueue=None, conn=None):
    """Send *query* to dnsdist as a DNS-over-HTTPS GET request.

    :param port: port dnsdist listens on; *servername*:*port* is resolved to
        127.0.0.1 through ``CURLOPT_RESOLVE``.
    :param baseurl: DoH endpoint URL; the query is appended as ``?dns=...``.
    :param response: if set, queued for the responder (``toQueue`` or
        ``cls._toResponderQueue``) before the request is sent.
    :param rawQuery: *query* is already wire-format bytes.
    :param rawResponse: return the HTTP body unparsed, whatever the status code.
    :param customHeaders: request header list set on the handle (default: empty).
    :param conn: existing pycurl handle to reuse; it is left open for the caller.
        When omitted, a new handle is created and closed before returning.
    :returns: ``(receivedQuery, message)``. receivedQuery is what the responder
        queued, or None. message is the parsed DNS response on HTTP 200, the raw
        body if *rawResponse*, otherwise None.

    Side effects: stores the HTTP status in ``cls._rcode`` and the raw response
    headers in ``cls._response_headers``.
    """
    if customHeaders is None:
        customHeaders = []
    url = cls.getDOHGetURL(baseurl, query, rawQuery)

    # Only close handles created here; a caller-supplied one may be reused.
    ownsConn = not conn
    if ownsConn:
        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)

    try:
        if useHTTPS:
            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
            if caFile:
                conn.setopt(pycurl.CAINFO, caFile)

        response_headers = BytesIO()
        #conn.setopt(pycurl.VERBOSE, True)
        conn.setopt(pycurl.URL, url)
        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])

        conn.setopt(pycurl.HTTPHEADER, customHeaders)
        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)

        if response:
            responderQueue = toQueue if toQueue else cls._toResponderQueue
            responderQueue.put(response, True, timeout)

        cls._response_headers = ''
        data = conn.perform_rb()
        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
    finally:
        # Close deterministically, even on error; otherwise a traceback that
        # references this frame would keep the handle and its socket alive.
        if ownsConn:
            conn.close()

    message = None
    if cls._rcode == 200 and not rawResponse:
        message = dns.message.from_wire(data)
    elif rawResponse:
        message = data

    receivedQuery = None
    if useQueue:
        queriesQueue = fromQueue if fromQueue else cls._fromResponderQueue
        if not queriesQueue.empty():
            receivedQuery = queriesQueue.get(True, timeout)

    cls._response_headers = response_headers.getvalue()
    return (receivedQuery, message)
1050
@classmethod
def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=None, useHTTPS=True, fromQueue=None, toQueue=None):
    """Send *query* to dnsdist as a DNS-over-HTTPS POST request.

    Parameters match ``sendDOHQuery``. The query wire format (or *query* itself
    when *rawQuery*) is sent as the POST body to *baseurl*. *fromQueue* and
    *toQueue* optionally replace the class-level responder queues.

    :returns: ``(receivedQuery, message)``. receivedQuery is what the responder
        queued, or None. message is the parsed DNS response on HTTP 200, the raw
        body if *rawResponse*, otherwise None.

    Side effects: stores the HTTP status in ``cls._rcode`` and the raw response
    headers in ``cls._response_headers``.
    """
    if customHeaders is None:
        customHeaders = []
    url = baseurl
    conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
    try:
        response_headers = BytesIO()
        #conn.setopt(pycurl.VERBOSE, True)
        conn.setopt(pycurl.URL, url)
        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
        if useHTTPS:
            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
            if caFile:
                conn.setopt(pycurl.CAINFO, caFile)

        conn.setopt(pycurl.HTTPHEADER, customHeaders)
        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
        conn.setopt(pycurl.POST, True)
        body = query if rawQuery else query.to_wire()
        conn.setopt(pycurl.POSTFIELDS, body)

        if response:
            responderQueue = toQueue if toQueue else cls._toResponderQueue
            responderQueue.put(response, True, timeout)

        cls._response_headers = ''
        data = conn.perform_rb()
        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
    finally:
        # The handle is always created here, so always release it, even on error.
        conn.close()

    message = None
    if cls._rcode == 200 and not rawResponse:
        message = dns.message.from_wire(data)
    elif rawResponse:
        message = data

    receivedQuery = None
    if useQueue:
        queriesQueue = fromQueue if fromQueue else cls._fromResponderQueue
        if not queriesQueue.empty():
            receivedQuery = queriesQueue.get(True, timeout)

    cls._response_headers = response_headers.getvalue()
    return (receivedQuery, message)
1094
def sendDOHQueryWrapper(self, query, response, useQueue=True):
    """Send *query* over DoH to this test's default DoH frontend.

    Uses ``_dohServerPort``, ``_serverName``, ``_dohBaseURL`` and ``_caCert``,
    and returns whatever ``sendDOHQuery`` returns.
    """
    send = self.sendDOHQuery
    port, name, url = self._dohServerPort, self._serverName, self._dohBaseURL
    return send(port, name, url, query, response=response, caFile=self._caCert, useQueue=useQueue)
1097
def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
    """Send *query* over DoH to this test's nghttp2-based frontend.

    Uses ``_dohWithNGHTTP2ServerPort`` and ``_dohWithNGHTTP2BaseURL``, and
    returns whatever ``sendDOHQuery`` returns.
    """
    send = self.sendDOHQuery
    port, name, url = self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL
    return send(port, name, url, query, response=response, caFile=self._caCert, useQueue=useQueue)
1100
def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
    """Send *query* over DoH to this test's h2o-based frontend.

    Uses ``_dohWithH2OServerPort`` and ``_dohWithH2OBaseURL``, and returns
    whatever ``sendDOHQuery`` returns.
    """
    send = self.sendDOHQuery
    port, name, url = self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL
    return send(port, name, url, query, response=response, caFile=self._caCert, useQueue=useQueue)
1103
def sendDOTQueryWrapper(self, query, response, useQueue=True):
    """Send *query* over DNS-over-TLS to this test's TLS frontend.

    Uses ``_tlsServerPort``, ``_serverName`` and ``_caCert``, and returns
    whatever ``sendDOTQuery`` returns.
    """
    send = self.sendDOTQuery
    port = self._tlsServerPort
    return send(port, self._serverName, query, response, self._caCert, useQueue=useQueue)
1106
def sendDOQQueryWrapper(self, query, response, useQueue=True):
    """Send *query* over DNS-over-QUIC to this test's DoQ frontend.

    Uses ``_doqServerPort``, ``_caCert`` and ``_serverName``, and returns
    whatever ``sendDOQQuery`` returns.
    """
    send = self.sendDOQQuery
    port = self._doqServerPort
    return send(port, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
1109
def sendDOH3QueryWrapper(self, query, response, useQueue=True):
    """Send *query* over DNS-over-HTTP/3 to this test's DoH3 frontend.

    Uses ``_doh3ServerPort`` together with ``_dohBaseURL``, ``_caCert`` and
    ``_serverName``, and returns whatever ``sendDOH3Query`` returns.
    """
    send = self.sendDOH3Query
    port, url = self._doh3ServerPort, self._dohBaseURL
    return send(port, url, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
@classmethod
def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
    """Open a synchronous DNS-over-QUIC connection to 127.0.0.1:*port*.

    :param caFile: passed to ``dns.quic.SyncQuicManager`` as ``verify_mode``
        (per dnspython, a bool or a CA file/directory path).
    :param source: optional local address to bind; passed to ``connect``.
    :param source_port: optional local port to bind; passed to ``connect``.
    :returns: the connection object returned by ``SyncQuicManager.connect``.
    """
    # The module-level imports only bring in dns and dns.message. Import
    # dns.quic explicitly rather than relying on another module having loaded
    # it as a side effect.
    import dns.quic

    manager = dns.quic.SyncQuicManager(
        verify_mode=caFile
    )

    return manager.connect('127.0.0.1', port, source, source_port)
1120
@classmethod
def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
    """Send *query* to dnsdist on 127.0.0.1:*port* over DNS-over-QUIC.

    If *response* is set, it is queued for the responder first (on *toQueue*,
    or ``cls._toResponderQueue``). *rawQuery* and *connection* are accepted but
    not used by this implementation.

    :returns: ``(receivedQuery, message)``. receivedQuery is the query taken
        from *fromQueue* (or ``cls._fromResponderQueue``) when *useQueue* is
        true and that queue is non-empty, otherwise None. message is the first
        element returned by ``quic_query``.
    """
    if response:
        outbound = toQueue if toQueue else cls._toResponderQueue
        outbound.put(response, True, timeout)

    message, _unused = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)

    if not useQueue:
        return (None, message)

    inbound = fromQueue if fromQueue else cls._fromResponderQueue
    receivedQuery = None if inbound.empty() else inbound.get(True, timeout)
    return (receivedQuery, message)
1143
@classmethod
def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
    """Send *query* to dnsdist over DNS-over-HTTP/3 via ``doh3_query``.

    If *response* is set, it is queued for the responder first (on *toQueue*,
    or ``cls._toResponderQueue``). *post* is forwarded to ``doh3_query``.
    *rawQuery* and *connection* are accepted but not used by this implementation.

    :returns: ``(receivedQuery, message)``. receivedQuery is the query taken
        from *fromQueue* (or ``cls._fromResponderQueue``) when *useQueue* is
        true and that queue is non-empty, otherwise None. message is the value
        returned by ``doh3_query``.
    """
    if response:
        outbound = toQueue if toQueue else cls._toResponderQueue
        outbound.put(response, True, timeout)

    message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)

    if not useQueue:
        return (None, message)

    inbound = fromQueue if fromQueue else cls._fromResponderQueue
    receivedQuery = None if inbound.empty() else inbound.get(True, timeout)
    return (receivedQuery, message)