]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
dnsdist: Reduce the timeout on Dynamic Block tests expected to fail
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import base64
4 import copy
5 import errno
6 import os
7 import socket
8 import ssl
9 import struct
10 import subprocess
11 import sys
12 import threading
13 import time
14 import unittest
15
16 import clientsubnetoption
17
18 import dns
19 import dns.message
20
21 import libnacl
22 import libnacl.utils
23
24 import h2.connection
25 import h2.events
26 import h2.config
27
28 import pycurl
29 from io import BytesIO
30
31 from doqclient import quic_query
32 from doh3client import doh3_query
33
34 from eqdnsmessage import AssertEqualDNSMessageMixin
35 from proxyprotocol import ProxyProtocol
36
37 # Python2/3 compatibility hacks
38 try:
39 from queue import Queue
40 except ImportError:
41 from Queue import Queue
42
# A Python 2 shim (`range = xrange`) used to live here. This module relies on
# f-strings, `{expr=}` format specifiers and threading.get_native_id(), so it
# only runs on Python 3.8+, where the builtin range() is already lazy; the
# shim was dead code and has been dropped.
47
def getWorkerID():
    """
    Return the numeric index of the current pytest-xdist worker, or 0 when the
    PYTEST_XDIST_WORKER environment variable is not set.

    The first two characters of the worker name are dropped (presumably the
    'gw' prefix of names such as 'gw3') and the rest is parsed as an int.
    """
    workerName = os.environ.get('PYTEST_XDIST_WORKER')
    if workerName is None:
        return 0
    return int(workerName[2:])


# worker ID -> last port handed out to that worker
workerPorts = {}


def pickAvailablePort():
    """
    Hand out the next port of this worker's range.

    Each worker starts at 11000 + 1000 * workerID and then counts upwards.
    This is a simple counter: the port is not probed for availability.
    """
    workerID = getWorkerID()
    lastPort = workerPorts.get(workerID)
    port = 11000 + (workerID * 1000) if lastPort is None else lastPort + 1
    workerPorts[workerID] = port
    return port
65
class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    which reply with the responses provided by the tests themselves
    through a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent by dnsdist were as expected.

    NOTE: the mutable class attributes below (the two queues,
    _responsesCounter and _backgroundThreads) are mutated in place, so they
    are shared by every subclass that does not redefine them.
    """
    # Address passed to dnsdist with `-l` (unless _skipListeningOnCL is set).
    _dnsDistListeningAddr = "127.0.0.1"
    # Responses queued by the tests; the responders pop them to answer dnsdist.
    _toResponderQueue = Queue()
    # Queries received by the responders, queued for the tests to inspect.
    _fromResponderQueue = Queue()
    # Timeout (seconds) used by the responders for queue put/get.
    _queueTimeout = 1
    # Popen handle of the running dnsdist process (set by startDNSDist).
    _dnsdist = None
    # Responder thread name -> number of non-health-check queries it handled.
    _responsesCounter = {}
    # Lua configuration written by startDNSDist(); it is %-formatted with the
    # values of the class attributes named in _config_params.
    _config_template = """
    """
    # Names of the class attributes substituted, in order, into _config_template.
    _config_params = ['_testServerPort']
    # Each entry is passed to dnsdist as `--acl <entry>`.
    _acl = ['127.0.0.1/32']
    # Console key used to encrypt console traffic; None sends it in clear.
    _consoleKey = None
    # Queries whose name ends with this are counted and answered as health checks.
    _healthCheckName = 'a.root-servers.net.'
    # Number of health-check queries answered (incremented in _getResponse).
    _healthCheckCounter = 0
    # When True, queries without a queued response get a SERVFAIL answer.
    _answerUnexpected = True
    # Expected `--check-config` output; None means "Configuration '<file>' OK!".
    _checkConfigExpectedOutput = None
    # Adds `-v` to the command line and skips the --check-config output check.
    _verboseMode = False
    # Run dnsdist through sudo (forwarding LD_LIBRARY_PATH when set).
    _sudoMode = False
    # Do not pass `-l` on the command line; the config declares the listeners.
    _skipListeningOnCL = False
    # When both are set, startup waits on this address/port instead.
    _alternateListeningAddr = None
    _alternateListeningPort = None
    # Native thread id -> keep-running flag of the responder threads;
    # tearDownClass() sets every flag to False.
    _backgroundThreads = {}
    # Default responder threads created by startResponders().
    _UDPResponder = None
    _TCPResponder = None
    # Extra seconds to sleep once dnsdist has been started.
    _extraStartupSleep = 0
    # Initial picks at class-definition time; startDNSDist() re-picks
    # _dnsDistPort and _consolePort, startResponders() re-picks _testServerPort.
    _dnsDistPort = pickAvailablePort()
    _consolePort = pickAvailablePort()
    _testServerPort = pickAvailablePort()
102
103 @classmethod
104 def waitForTCPSocket(cls, ipaddress, port):
105 for try_number in range(0, 20):
106 try:
107 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
108 sock.settimeout(1.0)
109 sock.connect((ipaddress, port))
110 sock.close()
111 return
112 except Exception as err:
113 if err.errno != errno.ECONNREFUSED:
114 print(f'Error occurred: {try_number} {err}', file=sys.stderr)
115 time.sleep(0.1)
116 # We assume the dnsdist instance does not listen. That's fine.
117
118 @classmethod
119 def startResponders(cls):
120 print("Launching responders..")
121 cls._testServerPort = pickAvailablePort()
122
123 cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
124 cls._UDPResponder.daemon = True
125 cls._UDPResponder.start()
126 cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
127 cls._TCPResponder.daemon = True
128 cls._TCPResponder.start()
129 cls.waitForTCPSocket("127.0.0.1", cls._testServerPort);
130
131 @classmethod
132 def startDNSDist(cls):
133 cls._dnsDistPort = pickAvailablePort()
134 cls._consolePort = pickAvailablePort()
135
136 print("Launching dnsdist..")
137 confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
138 params = tuple([getattr(cls, param) for param in cls._config_params])
139 print(params)
140 with open(confFile, 'w') as conf:
141 conf.write("-- Autogenerated by dnsdisttests.py\n")
142 conf.write(f"-- dnsdist will listen on {cls._dnsDistPort}")
143 conf.write(cls._config_template % params)
144 conf.write("setSecurityPollSuffix('')")
145
146 if cls._skipListeningOnCL:
147 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile ]
148 else:
149 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
150 '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
151
152 if cls._verboseMode:
153 dnsdistcmd.append('-v')
154 if cls._sudoMode:
155 if 'LD_LIBRARY_PATH' in os.environ:
156 dnsdistcmd.insert(0, 'LD_LIBRARY_PATH=' + os.environ['LD_LIBRARY_PATH'])
157 dnsdistcmd.insert(0, 'sudo')
158
159 for acl in cls._acl:
160 dnsdistcmd.extend(['--acl', acl])
161 print(' '.join(dnsdistcmd))
162
163 # validate config with --check-config, which sets client=true, possibly exposing bugs.
164 testcmd = dnsdistcmd + ['--check-config']
165 try:
166 output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
167 except subprocess.CalledProcessError as exc:
168 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
169 if cls._checkConfigExpectedOutput is not None:
170 expectedOutput = cls._checkConfigExpectedOutput
171 else:
172 expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
173 if not cls._verboseMode and output != expectedOutput:
174 raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output, expectedOutput))
175
176 logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
177 with open(logFile, 'w') as fdLog:
178 cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)
179
180 if cls._alternateListeningAddr and cls._alternateListeningPort:
181 cls.waitForTCPSocket(cls._alternateListeningAddr, cls._alternateListeningPort)
182 else:
183 cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort)
184
185 if cls._dnsdist.poll() is not None:
186 print(f"\n*** startDNSDist log for {logFile} ***")
187 with open(logFile, 'r') as fdLog:
188 print(fdLog.read())
189 print(f"*** End startDNSDist log for {logFile} ***")
190 raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode))
191 time.sleep(cls._extraStartupSleep)
192
193 @classmethod
194 def setUpSockets(cls):
195 print("Setting up UDP socket..")
196 cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
197 cls._sock.settimeout(2.0)
198 cls._sock.connect(("127.0.0.1", cls._dnsDistPort))
199
200 @classmethod
201 def killProcess(cls, p):
202 # Don't try to kill it if it's already dead
203 if p.poll() is not None:
204 return
205 try:
206 p.terminate()
207 for count in range(20):
208 x = p.poll()
209 if x is not None:
210 break
211 time.sleep(0.1)
212 if x is None:
213 print("kill...", p, file=sys.stderr)
214 p.kill()
215 p.wait()
216 except OSError as e:
217 # There is a race-condition with the poll() and
218 # kill() statements, when the process is dead on the
219 # kill(), this is fine
220 if e.errno != errno.ESRCH:
221 raise
222
223 @classmethod
224 def setUpClass(cls):
225
226 cls.startResponders()
227 cls.startDNSDist()
228 cls.setUpSockets()
229
230 print("Launching tests..")
231
232 @classmethod
233 def tearDownClass(cls):
234 cls._sock.close()
235 # tell the background threads to stop, if any
236 for backgroundThread in cls._backgroundThreads:
237 cls._backgroundThreads[backgroundThread] = False
238 cls.killProcess(cls._dnsdist)
239
240 @classmethod
241 def _ResponderIncrementCounter(cls):
242 if threading.current_thread().name in cls._responsesCounter:
243 cls._responsesCounter[threading.current_thread().name] += 1
244 else:
245 cls._responsesCounter[threading.current_thread().name] = 1
246
247 @classmethod
248 def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
249 response = None
250 if len(request.question) != 1:
251 print("Skipping query with question count %d" % (len(request.question)))
252 return None
253 healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
254 if healthCheck:
255 cls._healthCheckCounter += 1
256 response = dns.message.make_response(request)
257 else:
258 cls._ResponderIncrementCounter()
259 if not fromQueue.empty():
260 toQueue.put(request, True, cls._queueTimeout)
261 response = fromQueue.get(True, cls._queueTimeout)
262 if response:
263 response = copy.copy(response)
264 response.id = request.id
265
266 if synthesize is not None:
267 response = dns.message.make_response(request)
268 response.set_rcode(synthesize)
269
270 if not response:
271 if cls._answerUnexpected:
272 response = dns.message.make_response(request)
273 response.set_rcode(dns.rcode.SERVFAIL)
274
275 return response
276
277 @classmethod
278 def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
279 cls._backgroundThreads[threading.get_native_id()] = True
280 # trailingDataResponse=True means "ignore trailing data".
281 # Other values are either False (meaning "raise an exception")
282 # or are interpreted as a response RCODE for queries with trailing data.
283 # callback is invoked for every -even healthcheck ones- query and should return a raw response
284 ignoreTrailing = trailingDataResponse is True
285
286 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
287 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
288 sock.bind(("127.0.0.1", port))
289 sock.settimeout(0.5)
290 while True:
291 try:
292 data, addr = sock.recvfrom(4096)
293 except socket.timeout:
294 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
295 del cls._backgroundThreads[threading.get_native_id()]
296 break
297 else:
298 continue
299
300 forceRcode = None
301 try:
302 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
303 except dns.message.TrailingJunk as e:
304 print('trailing data exception in UDPResponder')
305 if trailingDataResponse is False or forceRcode is True:
306 raise
307 print("UDP query with trailing data, synthesizing response")
308 request = dns.message.from_wire(data, ignore_trailing=True)
309 forceRcode = trailingDataResponse
310
311 wire = None
312 if callback:
313 wire = callback(request)
314 else:
315 if request.edns > 1:
316 forceRcode = dns.rcode.BADVERS
317 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
318 if response:
319 wire = response.to_wire()
320
321 if not wire:
322 continue
323
324 sock.sendto(wire, addr)
325
326 sock.close()
327
328 @classmethod
329 def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, partialWrite=False):
330 ignoreTrailing = trailingDataResponse is True
331 try:
332 data = conn.recv(2)
333 except Exception as err:
334 data = None
335 print(f'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
336 if not data:
337 conn.close()
338 return
339
340 (datalen,) = struct.unpack("!H", data)
341 data = conn.recv(datalen)
342 forceRcode = None
343 try:
344 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
345 except dns.message.TrailingJunk as e:
346 if trailingDataResponse is False or forceRcode is True:
347 raise
348 print("TCP query with trailing data, synthesizing response")
349 request = dns.message.from_wire(data, ignore_trailing=True)
350 forceRcode = trailingDataResponse
351
352 if callback:
353 wire = callback(request)
354 else:
355 if request.edns > 1:
356 forceRcode = dns.rcode.BADVERS
357 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
358 if response:
359 wire = response.to_wire(max_size=65535)
360
361 if not wire:
362 conn.close()
363 return
364
365 wireLen = struct.pack("!H", len(wire))
366 if partialWrite:
367 for b in wireLen:
368 conn.send(bytes([b]))
369 time.sleep(0.5)
370 else:
371 conn.send(wireLen)
372 conn.send(wire)
373
374 while multipleResponses:
375 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
376 # otherwise we might read responses intended for the next connection
377 if fromQueue.empty():
378 break
379
380 response = fromQueue.get(False)
381 if not response:
382 break
383
384 response = copy.copy(response)
385 response.id = request.id
386 wire = response.to_wire(max_size=65535)
387 try:
388 conn.send(struct.pack("!H", len(wire)))
389 conn.send(wire)
390 except socket.error as e:
391 # some of the tests are going to close
392 # the connection on us, just deal with it
393 break
394
395 conn.close()
396
397 @classmethod
398 def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1', partialWrite=False):
399 cls._backgroundThreads[threading.get_native_id()] = True
400 # trailingDataResponse=True means "ignore trailing data".
401 # Other values are either False (meaning "raise an exception")
402 # or are interpreted as a response RCODE for queries with trailing data.
403 # callback is invoked for every -even healthcheck ones- query and should return a raw response
404
405 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
406 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
407 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
408 try:
409 sock.bind((listeningAddr, port))
410 except socket.error as e:
411 print("Error binding in the TCP responder: %s" % str(e))
412 sys.exit(1)
413
414 sock.listen(100)
415 sock.settimeout(0.5)
416 if tlsContext:
417 sock = tlsContext.wrap_socket(sock, server_side=True)
418
419 while True:
420 try:
421 (conn, _) = sock.accept()
422 except ssl.SSLError:
423 continue
424 except ConnectionResetError:
425 continue
426 except socket.timeout:
427 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
428 del cls._backgroundThreads[threading.get_native_id()]
429 break
430 else:
431 continue
432
433 conn.settimeout(5.0)
434 if multipleConnections:
435 thread = threading.Thread(name='TCP Connection Handler',
436 target=cls.handleTCPConnection,
437 args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite])
438 thread.daemon = True
439 thread.start()
440 else:
441 cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite)
442
443 sock.close()
444
445 @classmethod
446 def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
447 ignoreTrailing = trailingDataResponse is True
448 try:
449 h2conn = h2.connection.H2Connection(config=config)
450 h2conn.initiate_connection()
451 conn.sendall(h2conn.data_to_send())
452 except ssl.SSLEOFError as e:
453 print("Unexpected EOF: %s" % (e))
454 return
455 except Exception as err:
456 print(f'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
457 return
458
459 dnsData = {}
460
461 if useProxyProtocol:
462 # try to read the entire Proxy Protocol header
463 proxy = ProxyProtocol()
464 header = conn.recv(proxy.HEADER_SIZE)
465 if not header:
466 print('unable to get header')
467 conn.close()
468 return
469
470 if not proxy.parseHeader(header):
471 print('unable to parse header')
472 print(header)
473 conn.close()
474 return
475
476 proxyContent = conn.recv(proxy.contentLen)
477 if not proxyContent:
478 print('unable to get content')
479 conn.close()
480 return
481
482 payload = header + proxyContent
483 toQueue.put(payload, True, cls._queueTimeout)
484
485 # be careful, HTTP/2 headers and data might be in different recv() results
486 requestHeaders = None
487 while True:
488 try:
489 data = conn.recv(65535)
490 except Exception as err:
491 data = None
492 print(f'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
493 if not data:
494 break
495
496 events = h2conn.receive_data(data)
497 for event in events:
498 if isinstance(event, h2.events.RequestReceived):
499 requestHeaders = event.headers
500 if isinstance(event, h2.events.DataReceived):
501 h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
502 if not event.stream_id in dnsData:
503 dnsData[event.stream_id] = b''
504 dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
505 if event.stream_ended:
506 forceRcode = None
507 status = 200
508 try:
509 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
510 except dns.message.TrailingJunk as e:
511 if trailingDataResponse is False or forceRcode is True:
512 raise
513 print("DOH query with trailing data, synthesizing response")
514 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
515 forceRcode = trailingDataResponse
516
517 if callback:
518 status, wire = callback(request, requestHeaders, fromQueue, toQueue)
519 else:
520 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
521 if response:
522 wire = response.to_wire(max_size=65535)
523
524 if not wire:
525 conn.close()
526 conn = None
527 break
528
529 headers = [
530 (':status', str(status)),
531 ('content-length', str(len(wire))),
532 ('content-type', 'application/dns-message'),
533 ]
534 h2conn.send_headers(stream_id=event.stream_id, headers=headers)
535 h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
536
537 data_to_send = h2conn.data_to_send()
538 if data_to_send:
539 conn.sendall(data_to_send)
540
541 if conn is None:
542 break
543
544 if conn is not None:
545 conn.close()
546
547 @classmethod
548 def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
549 cls._backgroundThreads[threading.get_native_id()] = True
550 # trailingDataResponse=True means "ignore trailing data".
551 # Other values are either False (meaning "raise an exception")
552 # or are interpreted as a response RCODE for queries with trailing data.
553 # callback is invoked for every -even healthcheck ones- query and should return a raw response
554
555 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
556 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
557 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
558 try:
559 sock.bind(("127.0.0.1", port))
560 except socket.error as e:
561 print("Error binding in the TCP responder: %s" % str(e))
562 sys.exit(1)
563
564 sock.listen(100)
565 sock.settimeout(0.5)
566 if tlsContext:
567 sock = tlsContext.wrap_socket(sock, server_side=True)
568
569 config = h2.config.H2Configuration(client_side=False)
570
571 while True:
572 try:
573 (conn, _) = sock.accept()
574 except ssl.SSLError:
575 continue
576 except ConnectionResetError:
577 continue
578 except socket.timeout:
579 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
580 del cls._backgroundThreads[threading.get_native_id()]
581 break
582 else:
583 continue
584
585 conn.settimeout(5.0)
586 thread = threading.Thread(name='DoH Connection Handler',
587 target=cls.handleDoHConnection,
588 args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
589 thread.daemon = True
590 thread.start()
591
592 sock.close()
593
594 @classmethod
595 def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
596 if useQueue and response is not None:
597 cls._toResponderQueue.put(response, True, timeout)
598
599 if timeout:
600 cls._sock.settimeout(timeout)
601
602 try:
603 if not rawQuery:
604 query = query.to_wire()
605 cls._sock.send(query)
606 data = cls._sock.recv(4096)
607 except socket.timeout:
608 data = None
609 finally:
610 if timeout:
611 cls._sock.settimeout(None)
612
613 receivedQuery = None
614 message = None
615 if useQueue and not cls._fromResponderQueue.empty():
616 receivedQuery = cls._fromResponderQueue.get(True, timeout)
617 if data:
618 message = dns.message.from_wire(data)
619 return (receivedQuery, message)
620
621 @classmethod
622 def openTCPConnection(cls, timeout=None, port=None):
623 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
624 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
625 if timeout:
626 sock.settimeout(timeout)
627
628 if not port:
629 port = cls._dnsDistPort
630
631 sock.connect(("127.0.0.1", port))
632 return sock
633
634 @classmethod
635 def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]):
636 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
637 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
638 if timeout:
639 sock.settimeout(timeout)
640
641 # 2.7.9+
642 if hasattr(ssl, 'create_default_context'):
643 sslctx = ssl.create_default_context(cafile=caCert)
644 if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
645 sslctx.set_alpn_protocols(alpn)
646 sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
647 else:
648 sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
649
650 sslsock.connect(("127.0.0.1", port))
651 return sslsock
652
653 @classmethod
654 def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
655 if not rawQuery:
656 wire = query.to_wire()
657 else:
658 wire = query
659
660 if response:
661 cls._toResponderQueue.put(response, True, timeout)
662
663 sock.send(struct.pack("!H", len(wire)))
664 sock.send(wire)
665
666 @classmethod
667 def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
668 message = None
669 data = sock.recv(2)
670 if data:
671 (datalen,) = struct.unpack("!H", data)
672 print(datalen)
673 data = sock.recv(datalen)
674 if data:
675 print(data)
676 message = dns.message.from_wire(data)
677
678 print(useQueue)
679 if useQueue and not cls._fromResponderQueue.empty():
680 receivedQuery = cls._fromResponderQueue.get(True, timeout)
681 print(receivedQuery)
682 return (receivedQuery, message)
683 else:
684 print("queue empty")
685 return message
686
687 @classmethod
688 def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True):
689 conn = cls.openTLSConnection(port, serverName, caFile)
690 cls.sendTCPQueryOverConnection(conn, query, response=response)
691 if useQueue:
692 return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
693 return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
694
695 @classmethod
696 def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
697 message = None
698 if useQueue:
699 cls._toResponderQueue.put(response, True, timeout)
700
701 try:
702 sock = cls.openTCPConnection(timeout)
703 except socket.timeout as e:
704 print("Timeout while opening TCP connection: %s" % (str(e)))
705 return (None, None)
706
707 try:
708 cls.sendTCPQueryOverConnection(sock, query, rawQuery, timeout=timeout)
709 message = cls.recvTCPResponseOverConnection(sock, timeout=timeout)
710 except socket.timeout as e:
711 print("Timeout while sending or receiving TCP data: %s" % (str(e)))
712 except socket.error as e:
713 print("Network error: %s" % (str(e)))
714 finally:
715 sock.close()
716
717 receivedQuery = None
718 print(useQueue)
719 if useQueue and not cls._fromResponderQueue.empty():
720 print(receivedQuery)
721 receivedQuery = cls._fromResponderQueue.get(True, timeout)
722 else:
723 print("queue is empty")
724
725 return (receivedQuery, message)
726
727 @classmethod
728 def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
729 if useQueue:
730 for response in responses:
731 cls._toResponderQueue.put(response, True, timeout)
732 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
733 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
734 if timeout:
735 sock.settimeout(timeout)
736
737 sock.connect(("127.0.0.1", cls._dnsDistPort))
738 messages = []
739
740 try:
741 if not rawQuery:
742 wire = query.to_wire()
743 else:
744 wire = query
745
746 sock.send(struct.pack("!H", len(wire)))
747 sock.send(wire)
748 while True:
749 data = sock.recv(2)
750 if not data:
751 break
752 (datalen,) = struct.unpack("!H", data)
753 data = sock.recv(datalen)
754 messages.append(dns.message.from_wire(data))
755
756 except socket.timeout as e:
757 print("Timeout while receiving multiple TCP responses: %s" % (str(e)))
758 except socket.error as e:
759 print("Network error: %s" % (str(e)))
760 finally:
761 sock.close()
762
763 receivedQuery = None
764 if useQueue and not cls._fromResponderQueue.empty():
765 receivedQuery = cls._fromResponderQueue.get(True, timeout)
766 return (receivedQuery, messages)
767
768 def setUp(self):
769 # This function is called before every test
770
771 # Clear the responses counters
772 self._responsesCounter.clear()
773
774 self._healthCheckCounter = 0
775
776 # Make sure the queues are empty, in case
777 # a previous test failed
778 self.clearResponderQueues()
779
780 super(DNSDistTest, self).setUp()
781
782 @classmethod
783 def clearToResponderQueue(cls):
784 while not cls._toResponderQueue.empty():
785 cls._toResponderQueue.get(False)
786
787 @classmethod
788 def clearFromResponderQueue(cls):
789 while not cls._fromResponderQueue.empty():
790 cls._fromResponderQueue.get(False)
791
792 @classmethod
793 def clearResponderQueues(cls):
794 cls.clearToResponderQueue()
795 cls.clearFromResponderQueue()
796
797 @staticmethod
798 def generateConsoleKey():
799 return libnacl.utils.salsa_key()
800
801 @classmethod
802 def _encryptConsole(cls, command, nonce):
803 command = command.encode('UTF-8')
804 if cls._consoleKey is None:
805 return command
806 return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)
807
808 @classmethod
809 def _decryptConsole(cls, command, nonce):
810 if cls._consoleKey is None:
811 result = command
812 else:
813 result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
814 return result.decode('UTF-8')
815
816 @classmethod
817 def sendConsoleCommand(cls, command, timeout=5.0):
818 ourNonce = libnacl.utils.rand_nonce()
819 theirNonce = None
820 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
821 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
822 if timeout:
823 sock.settimeout(timeout)
824
825 sock.connect(("127.0.0.1", cls._consolePort))
826 sock.send(ourNonce)
827 theirNonce = sock.recv(len(ourNonce))
828 if len(theirNonce) != len(ourNonce):
829 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
830 if len(theirNonce) == 0:
831 raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
832 return None
833
834 halfNonceSize = int(len(ourNonce) / 2)
835 readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
836 writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
837 msg = cls._encryptConsole(command, writingNonce)
838 sock.send(struct.pack("!I", len(msg)))
839 sock.send(msg)
840 data = sock.recv(4)
841 if not data:
842 raise socket.error("Got EOF while reading the response size")
843
844 (responseLen,) = struct.unpack("!I", data)
845 data = sock.recv(responseLen)
846 response = cls._decryptConsole(data, readingNonce)
847 sock.close()
848 return response
849
850 def compareOptions(self, a, b):
851 self.assertEqual(len(a), len(b))
852 for idx in range(len(a)):
853 self.assertEqual(a[idx], b[idx])
854
855 def checkMessageNoEDNS(self, expected, received):
856 self.assertEqual(expected, received)
857 self.assertEqual(received.edns, -1)
858 self.assertEqual(len(received.options), 0)
859
860 def checkMessageEDNSWithoutOptions(self, expected, received):
861 self.assertEqual(expected, received)
862 self.assertEqual(received.edns, 0)
863 self.assertEqual(expected.payload, received.payload)
864
865 def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
866 self.assertEqual(expected, received)
867 self.assertEqual(received.edns, 0)
868 self.assertEqual(expected.payload, received.payload)
869 self.assertEqual(len(received.options), withCookies)
870 if withCookies:
871 for option in received.options:
872 self.assertEqual(option.otype, 10)
873 else:
874 for option in received.options:
875 self.assertNotEqual(option.otype, 10)
876
877 def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
878 self.assertEqual(expected, received)
879 self.assertEqual(received.edns, 0)
880 self.assertEqual(expected.payload, received.payload)
881 self.assertEqual(len(received.options), 1 + additionalOptions)
882 hasECS = False
883 for option in received.options:
884 if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
885 hasECS = True
886 else:
887 self.assertNotEqual(additionalOptions, 0)
888
889 self.compareOptions(expected.options, received.options)
890 self.assertTrue(hasECS)
891
892 def checkMessageEDNS(self, expected, received):
893 self.assertEqual(expected, received)
894 self.assertEqual(received.edns, 0)
895 self.assertEqual(expected.payload, received.payload)
896 self.assertEqual(len(expected.options), len(received.options))
897 self.compareOptions(expected.options, received.options)
898
899 def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
900 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
901
902 def checkQueryEDNS(self, expected, received):
903 self.checkMessageEDNS(expected, received)
904
905 def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
906 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
907
908 def checkQueryEDNSWithoutECS(self, expected, received):
909 self.checkMessageEDNSWithoutECS(expected, received)
910
911 def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
912 self.checkMessageEDNSWithoutECS(expected, received, withCookies)
913
914 def checkQueryNoEDNS(self, expected, received):
915 self.checkMessageNoEDNS(expected, received)
916
917 def checkResponseNoEDNS(self, expected, received):
918 self.checkMessageNoEDNS(expected, received)
919
@staticmethod
def generateNewCertificateAndKey(filePrefix):
    """Generate a fresh key and a test-CA-signed certificate for a server.

    Runs openssl in the current directory, using 'configServer.conf' and the
    test CA ('ca.pem' / 'ca.key'), and produces:
      - <filePrefix>.key / <filePrefix>.csr : new RSA 2048 key and its CSR
      - <filePrefix>.pem   : certificate signed by the CA, valid for 1 day
      - <filePrefix>.chain : the certificate followed by 'ca.pem'
      - <filePrefix>.p12   : PKCS#12 bundle protected with 'passw0rd'

    Raises:
        AssertionError: if any openssl invocation exits with a non-zero
            status; the message includes the exit code and openssl's
            combined stdout/stderr.
    """
    def runOpenSSL(step, cmd):
        # Popen()/communicate() never raise CalledProcessError, so the exit
        # status must be checked explicitly; otherwise a failing step would
        # only surface later as a confusing TLS error in the test itself.
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
        # Empty stdin so openssl never waits for interactive input.
        output = process.communicate(input=b'')[0]
        if process.returncode != 0:
            raise AssertionError('openssl %s failed (%d): %s' % (step, process.returncode, output))

    # generate and sign a new cert
    runOpenSSL('req', ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf'])
    runOpenSSL('x509', ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req'])

    # Full chain file: leaf certificate first, then the CA certificate.
    with open(filePrefix + '.chain', 'w') as outFile:
        for inFileName in [filePrefix + '.pem', 'ca.pem']:
            with open(inFileName) as inFile:
                outFile.write(inFile.read())

    runOpenSSL('pkcs12', ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12'])
950
def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=None, v6=False, sourcePort=None, destinationPort=None):
    """Assert that *receivedProxyPayload* is the expected PROXY protocol v2 header.

    Args:
        receivedProxyPayload: raw bytes received by the backend.
        source, destination: expected source/destination addresses.
        isTCP: True if the proxied query came over TCP, False for UDP.
        values: expected additional values, compared order-insensitively.
            Defaults to no values. The caller's list is not modified.
        v6: True if the addresses are expected to be IPv6.
        sourcePort: if given, the expected source port.
        destinationPort: if given, the expected destination port. Otherwise
            the destination port must equal self._dnsDistPort.
    """
    # None instead of a shared mutable [] default; the previous code also
    # sorted the caller's list (and that shared default) in place.
    if values is None:
        values = []

    proxy = ProxyProtocol()
    self.assertTrue(proxy.parseHeader(receivedProxyPayload))
    self.assertEqual(proxy.version, 0x02)
    self.assertEqual(proxy.command, 0x01)
    # Address family: 0x02 for IPv6, 0x01 for IPv4.
    self.assertEqual(proxy.family, 0x02 if v6 else 0x01)
    # Transport protocol: 0x01 for TCP, 0x02 for UDP.
    self.assertEqual(proxy.protocol, 0x01 if isTCP else 0x02)
    self.assertGreater(proxy.contentLen, 0)

    self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
    self.assertEqual(proxy.source, source)
    self.assertEqual(proxy.destination, destination)
    if sourcePort:
        self.assertEqual(proxy.sourcePort, sourcePort)
    if destinationPort:
        self.assertEqual(proxy.destinationPort, destinationPort)
    else:
        self.assertEqual(proxy.destinationPort, self._dnsDistPort)

    self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
    self.assertEqual(sorted(proxy.values), sorted(values))
980
@classmethod
def getDOHGetURL(cls, baseurl, query, rawQuery=False):
    """Build the DoH GET URL carrying *query*.

    The wire-format message is taken as-is when *rawQuery* is true, and
    produced with query.to_wire() otherwise. It is base64url-encoded, its
    trailing '=' padding is stripped, and it is appended to *baseurl* as
    '?dns=<encoded>'.
    """
    wire = query if rawQuery else query.to_wire()
    encoded = base64.urlsafe_b64encode(wire).decode('UTF8')
    return baseurl + "?dns=" + encoded.rstrip('=')
989
@classmethod
def openDOHConnection(cls, port, caFile, timeout=2.0):
    """Create a pycurl handle preconfigured for DNS-over-HTTPS requests.

    The handle requests HTTP/2 and sends DNS-message Content-type/Accept
    headers. The send helpers in this class set the URL, name resolution
    and TLS verification themselves. *port* and *caFile* are accepted for
    call-site compatibility but are not used here.

    Args:
        timeout: maximum duration in seconds of the whole transfer
            (CURLOPT_TIMEOUT_MS). None leaves libcurl's default (no limit).

    Returns:
        A new pycurl.Curl handle.
    """
    conn = pycurl.Curl()
    conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)

    conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
                                     "Accept: application/dns-message"])
    # The timeout argument used to be ignored, so perform() could hang
    # forever on an unresponsive frontend; bound the entire transfer.
    if timeout is not None:
        conn.setopt(pycurl.TIMEOUT_MS, int(timeout * 1000))
    return conn
998
@classmethod
def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None):
    """Send *query* as a DNS-over-HTTPS GET request.

    Args:
        port: local port of the DoH frontend. *servername* is pinned to
            127.0.0.1:<port> through CURLOPT_RESOLVE.
        baseurl: DoH endpoint; the query is appended as '?dns=...'.
        response: if set, queued for the responder before the request.
        timeout: seconds to wait on queue operations; also passed to
            openDOHConnection when a new handle is created.
        caFile: CA bundle used to verify the server certificate.
        useQueue: when true, collect the query the responder received, if any.
        rawQuery: *query* is already wire-format bytes.
        rawResponse: return the raw response body instead of parsing it.
        customHeaders: request header list installed via CURLOPT_HTTPHEADER.
        useHTTPS: enable peer and host certificate verification.
        fromQueue / toQueue: use these queues instead of the class-wide ones.
        conn: an existing pycurl handle to reuse; a new one is opened otherwise.

    Returns:
        (receivedQuery, message). receivedQuery is the query seen by the
        responder, or None. message is the parsed response, the raw body
        with *rawResponse*, or None when the HTTP status is not 200.

    Also sets cls._rcode (HTTP status code) and cls._response_headers
    (raw response header bytes).
    """
    url = cls.getDOHGetURL(baseurl, query, rawQuery)

    if not conn:
        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
        # Speak HTTP/2 right away instead of HTTP/1.1 with Upgrade headers.
        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)

    if useHTTPS:
        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
        if caFile:
            conn.setopt(pycurl.CAINFO, caFile)

    headerSink = BytesIO()
    conn.setopt(pycurl.URL, url)
    conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
    conn.setopt(pycurl.HTTPHEADER, customHeaders)
    conn.setopt(pycurl.HEADERFUNCTION, headerSink.write)

    if response:
        targetQueue = toQueue if toQueue else cls._toResponderQueue
        targetQueue.put(response, True, timeout)

    receivedQuery = None
    message = None
    cls._response_headers = ''
    body = conn.perform_rb()
    cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
    if rawResponse:
        message = body
    elif cls._rcode == 200:
        message = dns.message.from_wire(body)

    if useQueue:
        sourceQueue = fromQueue if fromQueue else cls._fromResponderQueue
        if not sourceQueue.empty():
            receivedQuery = sourceQueue.get(True, timeout)

    cls._response_headers = headerSink.getvalue()
    return (receivedQuery, message)
1048
@classmethod
def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
    """Send *query* as a DNS-over-HTTPS POST request to *baseurl*.

    Parameters mirror sendDOHQuery. A fresh pycurl handle is always created,
    and the wire-format query is sent as the POST body. *fromQueue* and
    *toQueue* (new, optional) replace the class-wide responder queues, as
    they do for sendDOHQuery.

    Returns:
        (receivedQuery, message). receivedQuery is the query seen by the
        responder, or None. message is the parsed response, the raw body
        with *rawResponse*, or None when the HTTP status is not 200.

    Also sets cls._rcode (HTTP status code) and cls._response_headers
    (raw response header bytes).
    """
    conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
    response_headers = BytesIO()
    conn.setopt(pycurl.URL, baseurl)
    conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
    # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
    conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
    if useHTTPS:
        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
        if caFile:
            conn.setopt(pycurl.CAINFO, caFile)

    conn.setopt(pycurl.HTTPHEADER, customHeaders)
    conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
    conn.setopt(pycurl.POST, True)
    payload = query if rawQuery else query.to_wire()
    conn.setopt(pycurl.POSTFIELDS, payload)

    if response:
        targetQueue = toQueue if toQueue else cls._toResponderQueue
        targetQueue.put(response, True, timeout)

    receivedQuery = None
    message = None
    cls._response_headers = ''
    body = conn.perform_rb()
    cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
    if cls._rcode == 200 and not rawResponse:
        message = dns.message.from_wire(body)
    elif rawResponse:
        message = body

    if useQueue:
        sourceQueue = fromQueue if fromQueue else cls._fromResponderQueue
        if not sourceQueue.empty():
            receivedQuery = sourceQueue.get(True, timeout)

    cls._response_headers = response_headers.getvalue()
    return (receivedQuery, message)
1092
def sendDOHQueryWrapper(self, query, response, useQueue=True):
    """Send *query* via sendDOHQuery to the frontend at self._dohServerPort.

    Uses self._serverName, self._dohBaseURL and self._caCert. Returns
    sendDOHQuery's (receivedQuery, message) tuple.
    """
    return self.sendDOHQuery(
        self._dohServerPort,
        self._serverName,
        self._dohBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
1095
def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
    """Send *query* via sendDOHQuery to the frontend at self._dohWithNGHTTP2ServerPort.

    Uses self._serverName, self._dohWithNGHTTP2BaseURL and self._caCert.
    Returns sendDOHQuery's (receivedQuery, message) tuple.
    """
    return self.sendDOHQuery(
        self._dohWithNGHTTP2ServerPort,
        self._serverName,
        self._dohWithNGHTTP2BaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
1098
def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
    """Send *query* via sendDOHQuery to the frontend at self._dohWithH2OServerPort.

    Uses self._serverName, self._dohWithH2OBaseURL and self._caCert.
    Returns sendDOHQuery's (receivedQuery, message) tuple.
    """
    return self.sendDOHQuery(
        self._dohWithH2OServerPort,
        self._serverName,
        self._dohWithH2OBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
    )
1101
def sendDOTQueryWrapper(self, query, response, useQueue=True):
    """Send *query* via sendDOTQuery (defined elsewhere in this class).

    Targets self._tlsServerPort with self._serverName and self._caCert, and
    returns whatever sendDOTQuery returns.
    """
    return self.sendDOTQuery(
        self._tlsServerPort,
        self._serverName,
        query,
        response,
        self._caCert,
        useQueue=useQueue,
    )
1104
def sendDOQQueryWrapper(self, query, response, useQueue=True):
    """Send *query* via sendDOQQuery to the frontend at self._doqServerPort.

    Uses self._caCert and self._serverName. Returns sendDOQQuery's
    (receivedQuery, message) tuple.
    """
    return self.sendDOQQuery(
        self._doqServerPort,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
        serverName=self._serverName,
    )
1107
def sendDOH3QueryWrapper(self, query, response, useQueue=True):
    """Send *query* via sendDOH3Query to the frontend at self._doh3ServerPort.

    Uses self._dohBaseURL, self._caCert and self._serverName. Returns
    sendDOH3Query's (receivedQuery, message) tuple.
    NOTE(review): this reuses the DoH base URL rather than a DoH3-specific
    one; confirm that is intended.
    """
    return self.sendDOH3Query(
        self._doh3ServerPort,
        self._dohBaseURL,
        query,
        response=response,
        caFile=self._caCert,
        useQueue=useQueue,
        serverName=self._serverName,
    )

@classmethod
def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
    """Open a DNS-over-QUIC connection to 127.0.0.1:*port*.

    Uses dnspython's SyncQuicManager.

    Args:
        caFile: forwarded as SyncQuicManager's verify_mode (presumably a CA
            file path; confirm against the installed dnspython version).
        source, source_port: local address/port, forwarded to connect().

    Returns:
        The connection object returned by manager.connect().
    """
    # Only 'dns' and 'dns.message' are imported at module level, so
    # 'dns.quic' was only reachable if some other module happened to import
    # it first; load the submodule explicitly before dereferencing it.
    import dns.quic

    manager = dns.quic.SyncQuicManager(
        verify_mode=caFile
    )

    return manager.connect('127.0.0.1', port, source, source_port)
1118
@classmethod
def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
    """Send *query* over DNS-over-QUIC to 127.0.0.1:*port* using quic_query.

    Args:
        response: if set, queued for the responder (on *toQueue*, or the
            class-wide queue) before the query is sent.
        timeout: forwarded to quic_query and used for queue operations.
        caFile: forwarded to quic_query as verify.
        serverName: forwarded to quic_query as server_hostname.
        useQueue: when true, collect the query the responder received (from
            *fromQueue*, or the class-wide queue), if any.
        rawQuery, connection: accepted but not used by this implementation.

    Returns:
        (receivedQuery, message). receivedQuery is None if nothing was queued.
    """
    if response:
        targetQueue = toQueue if toQueue else cls._toResponderQueue
        targetQueue.put(response, True, timeout)

    message = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)

    receivedQuery = None
    if useQueue:
        sourceQueue = fromQueue if fromQueue else cls._fromResponderQueue
        if not sourceQueue.empty():
            receivedQuery = sourceQueue.get(True, timeout)

    return (receivedQuery, message)
1141
@classmethod
def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
    """Send *query* over DNS-over-HTTP/3 to *baseurl* on *port* using doh3_query.

    Args:
        response: if set, queued for the responder (on *toQueue*, or the
            class-wide queue) before the query is sent.
        timeout: forwarded to doh3_query and used for queue operations.
        caFile: forwarded to doh3_query as verify.
        serverName: forwarded to doh3_query as server_hostname.
        post: forwarded to doh3_query.
        useQueue: when true, collect the query the responder received (from
            *fromQueue*, or the class-wide queue), if any.
        rawQuery, connection: accepted but not used by this implementation.

    Returns:
        (receivedQuery, message). receivedQuery is None if nothing was queued.
    """
    if response:
        targetQueue = toQueue if toQueue else cls._toResponderQueue
        targetQueue.put(response, True, timeout)

    message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)

    receivedQuery = None
    if useQueue:
        sourceQueue = fromQueue if fromQueue else cls._fromResponderQueue
        if not sourceQueue.empty():
            receivedQuery = sourceQueue.get(True, timeout)

    return (receivedQuery, message)