]> git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/dnsdisttests.py
dnsdist: Add a regression test for DoH connection counters
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
CommitLineData
ca404e94
RG
1#!/usr/bin/env python2
2
fda32c1c 3import base64
95f0b802 4import copy
ffabdc3e 5import errno
ca404e94
RG
6import os
7import socket
a227f47d 8import ssl
ca404e94
RG
9import struct
10import subprocess
11import sys
12import threading
13import time
14import unittest
9d71a0cf 15
5df86a8a 16import clientsubnetoption
9d71a0cf 17
b1bec9f0
RG
18import dns
19import dns.message
9d71a0cf 20
1ea747c0
RG
21import libnacl
22import libnacl.utils
ca404e94 23
9d71a0cf
RG
24import h2.connection
25import h2.events
26import h2.config
27
fda32c1c
RG
28import pycurl
29from io import BytesIO
30
e7000cce 31from doqclient import quic_query
4f0b10a9 32from doh3client import doh3_query
e7000cce 33
6bd430bf 34from eqdnsmessage import AssertEqualDNSMessageMixin
0e6892c6 35from proxyprotocol import ProxyProtocol
6bd430bf 36
b4f23783 37# Python2/3 compatibility hacks
7a0ea291 38try:
39 from queue import Queue
40except ImportError:
b4f23783 41 from Queue import Queue
7a0ea291 42
43try:
b4f23783 44 range = xrange
7a0ea291 45except NameError:
46 pass
b4f23783 47
630eb526
RG
def getWorkerID():
    """Return the numeric ID of the current pytest-xdist worker.

    The worker name is read from PYTEST_XDIST_WORKER and the number is taken
    after its two-character prefix (presumably 'gw', as in 'gw3').
    Returns 0 when the variable is not set, i.e. when not running under xdist.
    """
    workerName = os.environ.get('PYTEST_XDIST_WORKER')
    if workerName is None:
        return 0
    return int(workerName[2:])


# Last port handed out by pickAvailablePort(), keyed by worker ID.
workerPorts = {}


def pickAvailablePort():
    """Return the next port number reserved for the current worker.

    Worker N gets ports starting at 11000 + 1000 * N, handed out sequentially.
    Nothing is probed on the network: ports are only "available" in the sense
    that no earlier call in this process returned them.
    """
    workerID = getWorkerID()
    previous = workerPorts.get(workerID)
    if previous is None:
        port = 11000 + (workerID * 1000)
    else:
        port = previous + 1
    workerPorts[workerID] = port
    return port
b4f23783 65
class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # --- Test <-> backend responder plumbing ---------------------------------
    # Responses the tests want the backend responders to send back.
    _toResponderQueue = Queue()
    # Queries the backend responders received from dnsdist.
    _fromResponderQueue = Queue()
    # Seconds the responders wait when putting to / getting from those queues.
    _queueTimeout = 1
    # Per-responder-thread count of answered non-health-check queries, keyed by thread name.
    _responsesCounter = {}
    # Queries whose name ends with this are treated as health checks and answered directly.
    _healthCheckName = 'a.root-servers.net.'
    # Incremented by the responders for every health-check query.
    _healthCheckCounter = 0
    # Answer SERVFAIL when no response was queued for a query.
    _answerUnexpected = True
    # Native thread ID -> "keep running" flag for each background responder thread.
    _backgroundThreads = {}
    _UDPResponder = None
    _TCPResponder = None

    # --- dnsdist process and configuration ------------------------------------
    # The running dnsdist subprocess, once started.
    _dnsdist = None
    _dnsDistListeningAddr = "127.0.0.1"
    # Configuration template, %-formatted with the values of the class
    # attributes named in _config_params.
    _config_template = """
    """
    _config_params = ['_testServerPort']
    # Each entry is passed to dnsdist as '--acl <entry>'.
    _acl = ['127.0.0.1/32']
    # Console key; None means console messages are sent and read unencrypted.
    _consoleKey = None
    # Expected '--check-config' output; None means the default "Configuration ... OK!" line.
    _checkConfigExpectedOutput = None
    # Pass '-v' to dnsdist (and skip the --check-config output comparison).
    _verboseMode = False
    # Run dnsdist through sudo.
    _sudoMode = False
    # Do not pass '-l <addr>:<port>' on the command line
    # (presumably because the configuration declares its own listeners).
    _skipListeningOnCL = False
    # When both are set, startup waits on this address/port instead of the default one.
    _alternateListeningAddr = None
    _alternateListeningPort = None
    # Extra seconds to sleep once dnsdist has been started.
    _extraStartupSleep = 0

    # Initial port allocations; startResponders()/startDNSDist() pick fresh ones.
    # Keep this order: pickAvailablePort() hands out ports sequentially.
    _dnsDistPort = pickAvailablePort()
    _consolePort = pickAvailablePort()
    _testServerPort = pickAvailablePort()
ca404e94 102
ffabdc3e
OM
103 @classmethod
104 def waitForTCPSocket(cls, ipaddress, port):
105 for try_number in range(0, 20):
106 try:
107 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
108 sock.settimeout(1.0)
109 sock.connect((ipaddress, port))
110 sock.close()
111 return
112 except Exception as err:
113 if err.errno != errno.ECONNREFUSED:
114 print(f'Error occurred: {try_number} {err}', file=sys.stderr)
115 time.sleep(0.1)
116 # We assume the dnsdist instance does not listen. That's fine.
117
ca404e94
RG
118 @classmethod
119 def startResponders(cls):
120 print("Launching responders..")
630eb526 121 cls._testServerPort = pickAvailablePort()
ec5f5c6b 122
5df86a8a 123 cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
630eb526 124 cls._UDPResponder.daemon = True
ca404e94 125 cls._UDPResponder.start()
5df86a8a 126 cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
630eb526 127 cls._TCPResponder.daemon = True
ca404e94 128 cls._TCPResponder.start()
d53e77f2 129 cls.waitForTCPSocket("127.0.0.1", cls._testServerPort);
ca404e94
RG
130
131 @classmethod
aac2124b 132 def startDNSDist(cls):
630eb526
RG
133 cls._dnsDistPort = pickAvailablePort()
134 cls._consolePort = pickAvailablePort()
135
ca404e94 136 print("Launching dnsdist..")
aac2124b 137 confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
18a0e7c6
CH
138 params = tuple([getattr(cls, param) for param in cls._config_params])
139 print(params)
aac2124b 140 with open(confFile, 'w') as conf:
18a0e7c6 141 conf.write("-- Autogenerated by dnsdisttests.py\n")
630eb526 142 conf.write(f"-- dnsdist will listen on {cls._dnsDistPort}")
18a0e7c6 143 conf.write(cls._config_template % params)
f3853a40 144 conf.write("setSecurityPollSuffix('')")
18a0e7c6 145
db7acdaf
RG
146 if cls._skipListeningOnCL:
147 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile ]
148 else:
149 dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
150 '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ]
151
2a3cafcd
RG
152 if cls._verboseMode:
153 dnsdistcmd.append('-v')
9bf515e1 154 if cls._sudoMode:
e4b04b6e
RG
155 preserve_env_values = ['LD_LIBRARY_PATH', 'LLVM_PROFILE_FILE']
156 for value in preserve_env_values:
157 if value in os.environ:
158 dnsdistcmd.insert(0, value + '=' + os.environ[value])
9bf515e1 159 dnsdistcmd.insert(0, 'sudo')
2a3cafcd 160
18a0e7c6
CH
161 for acl in cls._acl:
162 dnsdistcmd.extend(['--acl', acl])
163 print(' '.join(dnsdistcmd))
164
6b44773a
CH
165 # validate config with --check-config, which sets client=true, possibly exposing bugs.
166 testcmd = dnsdistcmd + ['--check-config']
ff0bc6a6
JS
167 try:
168 output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
169 except subprocess.CalledProcessError as exc:
170 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
f73ce0e3
RG
171 if cls._checkConfigExpectedOutput is not None:
172 expectedOutput = cls._checkConfigExpectedOutput
173 else:
174 expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
2a3cafcd 175 if not cls._verboseMode and output != expectedOutput:
630eb526 176 raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output, expectedOutput))
6b44773a 177
aac2124b
RG
178 logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
179 with open(logFile, 'w') as fdLog:
180 cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)
ca404e94 181
1953ab6c
RG
182 if cls._alternateListeningAddr and cls._alternateListeningPort:
183 cls.waitForTCPSocket(cls._alternateListeningAddr, cls._alternateListeningPort)
184 else:
185 cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort)
ca404e94
RG
186
187 if cls._dnsdist.poll() is not None:
ffabdc3e
OM
188 print(f"\n*** startDNSDist log for {logFile} ***")
189 with open(logFile, 'r') as fdLog:
190 print(fdLog.read())
191 print(f"*** End startDNSDist log for {logFile} ***")
192 raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode))
b38aaeb5 193 time.sleep(cls._extraStartupSleep)
ca404e94
RG
194
195 @classmethod
196 def setUpSockets(cls):
197 print("Setting up UDP socket..")
198 cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1ade83b2 199 cls._sock.settimeout(2.0)
ca404e94
RG
200 cls._sock.connect(("127.0.0.1", cls._dnsDistPort))
201
e4284d05
OM
202 @classmethod
203 def killProcess(cls, p):
204 # Don't try to kill it if it's already dead
205 if p.poll() is not None:
206 return
207 try:
208 p.terminate()
6a51a279 209 for count in range(20):
e4284d05
OM
210 x = p.poll()
211 if x is not None:
212 break
213 time.sleep(0.1)
214 if x is None:
215 print("kill...", p, file=sys.stderr)
216 p.kill()
217 p.wait()
218 except OSError as e:
219 # There is a race-condition with the poll() and
220 # kill() statements, when the process is dead on the
221 # kill(), this is fine
222 if e.errno != errno.ESRCH:
223 raise
224
ca404e94
RG
225 @classmethod
226 def setUpClass(cls):
227
228 cls.startResponders()
aac2124b 229 cls.startDNSDist()
ca404e94
RG
230 cls.setUpSockets()
231
232 print("Launching tests..")
233
234 @classmethod
235 def tearDownClass(cls):
ffabdc3e
OM
236 cls._sock.close()
237 # tell the background threads to stop, if any
238 for backgroundThread in cls._backgroundThreads:
239 cls._backgroundThreads[backgroundThread] = False
e4284d05 240 cls.killProcess(cls._dnsdist)
7373e3a6 241
ca404e94 242 @classmethod
fe1c60f2 243 def _ResponderIncrementCounter(cls):
630eb526
RG
244 if threading.current_thread().name in cls._responsesCounter:
245 cls._responsesCounter[threading.current_thread().name] += 1
ec5f5c6b 246 else:
630eb526 247 cls._responsesCounter[threading.current_thread().name] = 1
ec5f5c6b 248
fe1c60f2 249 @classmethod
4aa08b62 250 def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
fe1c60f2
RG
251 response = None
252 if len(request.question) != 1:
253 print("Skipping query with question count %d" % (len(request.question)))
254 return None
98650fde
RG
255 healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
256 if healthCheck:
257 cls._healthCheckCounter += 1
4aa08b62 258 response = dns.message.make_response(request)
98650fde 259 else:
fe1c60f2 260 cls._ResponderIncrementCounter()
5df86a8a 261 if not fromQueue.empty():
4aa08b62 262 toQueue.put(request, True, cls._queueTimeout)
90186270
RG
263 response = fromQueue.get(True, cls._queueTimeout)
264 if response:
265 response = copy.copy(response)
266 response.id = request.id
267
268 if synthesize is not None:
269 response = dns.message.make_response(request)
270 response.set_rcode(synthesize)
fe1c60f2 271
e44df0f1 272 if not response:
90186270 273 if cls._answerUnexpected:
e44df0f1
RG
274 response = dns.message.make_response(request)
275 response.set_rcode(dns.rcode.SERVFAIL)
fe1c60f2
RG
276
277 return response
278
ec5f5c6b 279 @classmethod
a620f197 280 def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
7373e3a6 281 cls._backgroundThreads[threading.get_native_id()] = True
3ef7ab0d
RG
282 # trailingDataResponse=True means "ignore trailing data".
283 # Other values are either False (meaning "raise an exception")
284 # or are interpreted as a response RCODE for queries with trailing data.
a620f197 285 # callback is invoked for every -even healthcheck ones- query and should return a raw response
4aa08b62 286 ignoreTrailing = trailingDataResponse is True
3ef7ab0d 287
ca404e94
RG
288 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
289 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
ec5f5c6b 290 sock.bind(("127.0.0.1", port))
28d4c42d 291 sock.settimeout(0.5)
ca404e94 292 while True:
7373e3a6
RG
293 try:
294 data, addr = sock.recvfrom(4096)
295 except socket.timeout:
296 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
297 del cls._backgroundThreads[threading.get_native_id()]
298 break
299 else:
300 continue
301
4aa08b62
RG
302 forceRcode = None
303 try:
304 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
305 except dns.message.TrailingJunk as e:
51f07ad4 306 print('trailing data exception in UDPResponder')
3ef7ab0d 307 if trailingDataResponse is False or forceRcode is True:
4aa08b62
RG
308 raise
309 print("UDP query with trailing data, synthesizing response")
310 request = dns.message.from_wire(data, ignore_trailing=True)
311 forceRcode = trailingDataResponse
312
f3913dd2 313 wire = None
a620f197
RG
314 if callback:
315 wire = callback(request)
316 else:
f8662974
RG
317 if request.edns > 1:
318 forceRcode = dns.rcode.BADVERS
a620f197
RG
319 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
320 if response:
321 wire = response.to_wire()
87c605c4 322
f3913dd2
RG
323 if not wire:
324 continue
325
a620f197 326 sock.sendto(wire, addr)
7373e3a6 327
ca404e94
RG
328 sock.close()
329
330 @classmethod
e572dbf5 331 def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, partialWrite=False):
645a1ca4 332 ignoreTrailing = trailingDataResponse is True
fabe8e3a
RG
333 try:
334 data = conn.recv(2)
335 except Exception as err:
336 data = None
337 print(f'Error while reading query size in TCP responder thread {err=}, {type(err)=}')
645a1ca4
RG
338 if not data:
339 conn.close()
340 return
341
342 (datalen,) = struct.unpack("!H", data)
343 data = conn.recv(datalen)
344 forceRcode = None
345 try:
346 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
347 except dns.message.TrailingJunk as e:
348 if trailingDataResponse is False or forceRcode is True:
349 raise
350 print("TCP query with trailing data, synthesizing response")
351 request = dns.message.from_wire(data, ignore_trailing=True)
352 forceRcode = trailingDataResponse
353
354 if callback:
355 wire = callback(request)
356 else:
f8662974
RG
357 if request.edns > 1:
358 forceRcode = dns.rcode.BADVERS
645a1ca4
RG
359 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
360 if response:
361 wire = response.to_wire(max_size=65535)
362
363 if not wire:
364 conn.close()
365 return
366
e572dbf5
RG
367 wireLen = struct.pack("!H", len(wire))
368 if partialWrite:
369 for b in wireLen:
370 conn.send(bytes([b]))
371 time.sleep(0.5)
372 else:
373 conn.send(wireLen)
645a1ca4
RG
374 conn.send(wire)
375
376 while multipleResponses:
936dd73c
RG
377 # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done
378 # otherwise we might read responses intended for the next connection
645a1ca4
RG
379 if fromQueue.empty():
380 break
381
936dd73c 382 response = fromQueue.get(False)
645a1ca4
RG
383 if not response:
384 break
385
386 response = copy.copy(response)
387 response.id = request.id
388 wire = response.to_wire(max_size=65535)
389 try:
390 conn.send(struct.pack("!H", len(wire)))
391 conn.send(wire)
392 except socket.error as e:
393 # some of the tests are going to close
394 # the connection on us, just deal with it
395 break
396
397 conn.close()
398
399 @classmethod
e572dbf5 400 def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1', partialWrite=False):
7373e3a6 401 cls._backgroundThreads[threading.get_native_id()] = True
3ef7ab0d
RG
402 # trailingDataResponse=True means "ignore trailing data".
403 # Other values are either False (meaning "raise an exception")
404 # or are interpreted as a response RCODE for queries with trailing data.
a620f197 405 # callback is invoked for every -even healthcheck ones- query and should return a raw response
3ef7ab0d 406
ca404e94 407 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
501af9ae 408 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
ca404e94
RG
409 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
410 try:
144eebeb 411 sock.bind((listeningAddr, port))
ca404e94
RG
412 except socket.error as e:
413 print("Error binding in the TCP responder: %s" % str(e))
414 sys.exit(1)
415
416 sock.listen(100)
28d4c42d 417 sock.settimeout(0.5)
7d2856a6
RG
418 if tlsContext:
419 sock = tlsContext.wrap_socket(sock, server_side=True)
420
ca404e94 421 while True:
7d2856a6
RG
422 try:
423 (conn, _) = sock.accept()
424 except ssl.SSLError:
425 continue
ae3b96d9
RG
426 except ConnectionResetError:
427 continue
7373e3a6
RG
428 except socket.timeout:
429 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
430 del cls._backgroundThreads[threading.get_native_id()]
431 break
432 else:
433 continue
ae3b96d9 434
6ac8517d 435 conn.settimeout(5.0)
645a1ca4
RG
436 if multipleConnections:
437 thread = threading.Thread(name='TCP Connection Handler',
438 target=cls.handleTCPConnection,
e572dbf5 439 args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite])
630eb526 440 thread.daemon = True
645a1ca4 441 thread.start()
a620f197 442 else:
e572dbf5 443 cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite)
548c8b66 444
ca404e94
RG
445 sock.close()
446
c4c72a2c
RG
447 @classmethod
448 def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol):
449 ignoreTrailing = trailingDataResponse is True
144eebeb
RG
450 try:
451 h2conn = h2.connection.H2Connection(config=config)
452 h2conn.initiate_connection()
453 conn.sendall(h2conn.data_to_send())
454 except ssl.SSLEOFError as e:
455 print("Unexpected EOF: %s" % (e))
456 return
fabe8e3a
RG
457 except Exception as err:
458 print(f'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}')
459 return
144eebeb 460
c4c72a2c
RG
461 dnsData = {}
462
463 if useProxyProtocol:
464 # try to read the entire Proxy Protocol header
465 proxy = ProxyProtocol()
466 header = conn.recv(proxy.HEADER_SIZE)
467 if not header:
468 print('unable to get header')
469 conn.close()
470 return
471
472 if not proxy.parseHeader(header):
473 print('unable to parse header')
474 print(header)
475 conn.close()
476 return
477
478 proxyContent = conn.recv(proxy.contentLen)
479 if not proxyContent:
480 print('unable to get content')
481 conn.close()
482 return
483
484 payload = header + proxyContent
485 toQueue.put(payload, True, cls._queueTimeout)
486
487 # be careful, HTTP/2 headers and data might be in different recv() results
488 requestHeaders = None
489 while True:
fabe8e3a
RG
490 try:
491 data = conn.recv(65535)
492 except Exception as err:
493 data = None
494 print(f'Unexpected exception in DoH responder thread {err=}, {type(err)=}')
c4c72a2c
RG
495 if not data:
496 break
497
498 events = h2conn.receive_data(data)
499 for event in events:
500 if isinstance(event, h2.events.RequestReceived):
501 requestHeaders = event.headers
502 if isinstance(event, h2.events.DataReceived):
503 h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
504 if not event.stream_id in dnsData:
505 dnsData[event.stream_id] = b''
506 dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data)
507 if event.stream_ended:
508 forceRcode = None
509 status = 200
510 try:
511 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing)
512 except dns.message.TrailingJunk as e:
513 if trailingDataResponse is False or forceRcode is True:
514 raise
515 print("DOH query with trailing data, synthesizing response")
516 request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True)
517 forceRcode = trailingDataResponse
518
519 if callback:
520 status, wire = callback(request, requestHeaders, fromQueue, toQueue)
521 else:
522 response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
523 if response:
524 wire = response.to_wire(max_size=65535)
525
526 if not wire:
527 conn.close()
528 conn = None
529 break
530
531 headers = [
532 (':status', str(status)),
533 ('content-length', str(len(wire))),
534 ('content-type', 'application/dns-message'),
535 ]
536 h2conn.send_headers(stream_id=event.stream_id, headers=headers)
537 h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True)
538
539 data_to_send = h2conn.data_to_send()
540 if data_to_send:
541 conn.sendall(data_to_send)
542
543 if conn is None:
544 break
545
546 if conn is not None:
547 conn.close()
548
9d71a0cf 549 @classmethod
0e6892c6 550 def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
7373e3a6 551 cls._backgroundThreads[threading.get_native_id()] = True
9d71a0cf
RG
552 # trailingDataResponse=True means "ignore trailing data".
553 # Other values are either False (meaning "raise an exception")
554 # or are interpreted as a response RCODE for queries with trailing data.
555 # callback is invoked for every -even healthcheck ones- query and should return a raw response
9d71a0cf
RG
556
557 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
558 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
559 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
560 try:
561 sock.bind(("127.0.0.1", port))
562 except socket.error as e:
563 print("Error binding in the TCP responder: %s" % str(e))
564 sys.exit(1)
565
566 sock.listen(100)
28d4c42d 567 sock.settimeout(0.5)
9d71a0cf
RG
568 if tlsContext:
569 sock = tlsContext.wrap_socket(sock, server_side=True)
570
571 config = h2.config.H2Configuration(client_side=False)
572
573 while True:
574 try:
575 (conn, _) = sock.accept()
576 except ssl.SSLError:
577 continue
578 except ConnectionResetError:
579 continue
7373e3a6
RG
580 except socket.timeout:
581 if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
582 del cls._backgroundThreads[threading.get_native_id()]
583 break
584 else:
585 continue
ae3b96d9 586
9d71a0cf 587 conn.settimeout(5.0)
c4c72a2c
RG
588 thread = threading.Thread(name='DoH Connection Handler',
589 target=cls.handleDoHConnection,
590 args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol])
630eb526 591 thread.daemon = True
c4c72a2c 592 thread.start()
9d71a0cf
RG
593
594 sock.close()
595
ca404e94 596 @classmethod
55baa1f2 597 def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
90186270 598 if useQueue and response is not None:
617dfe22 599 cls._toResponderQueue.put(response, True, timeout)
ca404e94
RG
600
601 if timeout:
602 cls._sock.settimeout(timeout)
603
604 try:
55baa1f2
RG
605 if not rawQuery:
606 query = query.to_wire()
607 cls._sock.send(query)
ca404e94 608 data = cls._sock.recv(4096)
b1bec9f0 609 except socket.timeout:
ca404e94
RG
610 data = None
611 finally:
612 if timeout:
613 cls._sock.settimeout(None)
614
615 receivedQuery = None
616 message = None
617 if useQueue and not cls._fromResponderQueue.empty():
617dfe22 618 receivedQuery = cls._fromResponderQueue.get(True, timeout)
ca404e94
RG
619 if data:
620 message = dns.message.from_wire(data)
621 return (receivedQuery, message)
622
623 @classmethod
db7acdaf 624 def openTCPConnection(cls, timeout=None, port=None):
ca404e94 625 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
501af9ae 626 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
ca404e94
RG
627 if timeout:
628 sock.settimeout(timeout)
629
db7acdaf
RG
630 if not port:
631 port = cls._dnsDistPort
632
633 sock.connect(("127.0.0.1", port))
9396d955 634 return sock
0a2087eb 635
9396d955 636 @classmethod
7e8a05fa 637 def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]):
a227f47d 638 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
501af9ae 639 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
a227f47d
RG
640 if timeout:
641 sock.settimeout(timeout)
642
643 # 2.7.9+
644 if hasattr(ssl, 'create_default_context'):
645 sslctx = ssl.create_default_context(cafile=caCert)
7e8a05fa
RG
646 if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
647 sslctx.set_alpn_protocols(alpn)
a227f47d
RG
648 sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
649 else:
650 sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
651
652 sslsock.connect(("127.0.0.1", port))
653 return sslsock
654
655 @classmethod
656 def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
9396d955
RG
657 if not rawQuery:
658 wire = query.to_wire()
659 else:
660 wire = query
55baa1f2 661
a227f47d
RG
662 if response:
663 cls._toResponderQueue.put(response, True, timeout)
664
9396d955
RG
665 sock.send(struct.pack("!H", len(wire)))
666 sock.send(wire)
667
668 @classmethod
a227f47d 669 def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
9396d955
RG
670 message = None
671 data = sock.recv(2)
672 if data:
673 (datalen,) = struct.unpack("!H", data)
f05cd66c 674 print(datalen)
9396d955 675 data = sock.recv(datalen)
ca404e94 676 if data:
f05cd66c 677 print(data)
9396d955 678 message = dns.message.from_wire(data)
a227f47d 679
f05cd66c 680 print(useQueue)
a227f47d
RG
681 if useQueue and not cls._fromResponderQueue.empty():
682 receivedQuery = cls._fromResponderQueue.get(True, timeout)
f05cd66c 683 print(receivedQuery)
a227f47d
RG
684 return (receivedQuery, message)
685 else:
f05cd66c 686 print("queue empty")
a227f47d 687 return message
9396d955 688
fda32c1c
RG
689 @classmethod
690 def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True):
691 conn = cls.openTLSConnection(port, serverName, caFile)
692 cls.sendTCPQueryOverConnection(conn, query, response=response)
693 if useQueue:
694 return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
695 return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue)
696
9396d955
RG
697 @classmethod
698 def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
699 message = None
700 if useQueue:
701 cls._toResponderQueue.put(response, True, timeout)
702
9bf515e1
RG
703 try:
704 sock = cls.openTCPConnection(timeout)
705 except socket.timeout as e:
706 print("Timeout while opening TCP connection: %s" % (str(e)))
707 return (None, None)
9396d955
RG
708
709 try:
110624de
RG
710 cls.sendTCPQueryOverConnection(sock, query, rawQuery, timeout=timeout)
711 message = cls.recvTCPResponseOverConnection(sock, timeout=timeout)
ca404e94 712 except socket.timeout as e:
0e6892c6 713 print("Timeout while sending or receiving TCP data: %s" % (str(e)))
ca404e94
RG
714 except socket.error as e:
715 print("Network error: %s" % (str(e)))
ca404e94
RG
716 finally:
717 sock.close()
718
719 receivedQuery = None
f05cd66c 720 print(useQueue)
ca404e94 721 if useQueue and not cls._fromResponderQueue.empty():
f05cd66c 722 print(receivedQuery)
617dfe22 723 receivedQuery = cls._fromResponderQueue.get(True, timeout)
f05cd66c
RG
724 else:
725 print("queue is empty")
9396d955 726
ca404e94 727 return (receivedQuery, message)
617dfe22 728
548c8b66
RG
729 @classmethod
730 def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
731 if useQueue:
732 for response in responses:
733 cls._toResponderQueue.put(response, True, timeout)
734 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
501af9ae 735 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
548c8b66
RG
736 if timeout:
737 sock.settimeout(timeout)
738
739 sock.connect(("127.0.0.1", cls._dnsDistPort))
740 messages = []
741
742 try:
743 if not rawQuery:
744 wire = query.to_wire()
745 else:
746 wire = query
747
748 sock.send(struct.pack("!H", len(wire)))
749 sock.send(wire)
750 while True:
751 data = sock.recv(2)
752 if not data:
753 break
754 (datalen,) = struct.unpack("!H", data)
755 data = sock.recv(datalen)
756 messages.append(dns.message.from_wire(data))
757
758 except socket.timeout as e:
0e6892c6 759 print("Timeout while receiving multiple TCP responses: %s" % (str(e)))
548c8b66
RG
760 except socket.error as e:
761 print("Network error: %s" % (str(e)))
762 finally:
763 sock.close()
764
765 receivedQuery = None
766 if useQueue and not cls._fromResponderQueue.empty():
767 receivedQuery = cls._fromResponderQueue.get(True, timeout)
768 return (receivedQuery, messages)
769
617dfe22 770 def setUp(self):
936dd73c 771 # This function is called before every test
617dfe22
RG
772
773 # Clear the responses counters
fd71df4e 774 self._responsesCounter.clear()
617dfe22 775
98650fde
RG
776 self._healthCheckCounter = 0
777
617dfe22
RG
778 # Make sure the queues are empty, in case
779 # a previous test failed
936dd73c 780 self.clearResponderQueues()
1ea747c0 781
6bd430bf
PD
782 super(DNSDistTest, self).setUp()
783
3bef39c3
RG
784 @classmethod
785 def clearToResponderQueue(cls):
786 while not cls._toResponderQueue.empty():
787 cls._toResponderQueue.get(False)
788
789 @classmethod
790 def clearFromResponderQueue(cls):
791 while not cls._fromResponderQueue.empty():
792 cls._fromResponderQueue.get(False)
793
794 @classmethod
795 def clearResponderQueues(cls):
796 cls.clearToResponderQueue()
797 cls.clearFromResponderQueue()
798
1ea747c0
RG
799 @staticmethod
800 def generateConsoleKey():
801 return libnacl.utils.salsa_key()
802
803 @classmethod
804 def _encryptConsole(cls, command, nonce):
b4f23783 805 command = command.encode('UTF-8')
1ea747c0
RG
806 if cls._consoleKey is None:
807 return command
808 return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)
809
810 @classmethod
811 def _decryptConsole(cls, command, nonce):
812 if cls._consoleKey is None:
b4f23783
CH
813 result = command
814 else:
815 result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
816 return result.decode('UTF-8')
1ea747c0
RG
817
818 @classmethod
8be2b867 819 def sendConsoleCommand(cls, command, timeout=5.0, IPv6=False):
1ea747c0
RG
820 ourNonce = libnacl.utils.rand_nonce()
821 theirNonce = None
8be2b867 822 sock = socket.socket(socket.AF_INET if not IPv6 else socket.AF_INET6, socket.SOCK_STREAM)
501af9ae 823 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
1ea747c0
RG
824 if timeout:
825 sock.settimeout(timeout)
826
8be2b867 827 sock.connect(("::1", cls._consolePort, 0, 0) if IPv6 else ("127.0.0.1", cls._consolePort))
1ea747c0
RG
828 sock.send(ourNonce)
829 theirNonce = sock.recv(len(ourNonce))
7b925432 830 if len(theirNonce) != len(ourNonce):
05a5b575 831 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
bdfa6902
RG
832 if len(theirNonce) == 0:
833 raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
7b925432 834 return None
1ea747c0 835
b4f23783 836 halfNonceSize = int(len(ourNonce) / 2)
333ea16e
RG
837 readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
838 writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
333ea16e 839 msg = cls._encryptConsole(command, writingNonce)
1ea747c0
RG
840 sock.send(struct.pack("!I", len(msg)))
841 sock.send(msg)
842 data = sock.recv(4)
9c9b4998
RG
843 if not data:
844 raise socket.error("Got EOF while reading the response size")
845
1ea747c0
RG
846 (responseLen,) = struct.unpack("!I", data)
847 data = sock.recv(responseLen)
333ea16e 848 response = cls._decryptConsole(data, readingNonce)
75b536de 849 sock.close()
1ea747c0 850 return response
5df86a8a
RG
851
852 def compareOptions(self, a, b):
4bfebc93 853 self.assertEqual(len(a), len(b))
b4f23783 854 for idx in range(len(a)):
4bfebc93 855 self.assertEqual(a[idx], b[idx])
5df86a8a
RG
856
857 def checkMessageNoEDNS(self, expected, received):
4bfebc93
CH
858 self.assertEqual(expected, received)
859 self.assertEqual(received.edns, -1)
860 self.assertEqual(len(received.options), 0)
5df86a8a 861
e7c732b8 862 def checkMessageEDNSWithoutOptions(self, expected, received):
4bfebc93
CH
863 self.assertEqual(expected, received)
864 self.assertEqual(received.edns, 0)
865 self.assertEqual(expected.payload, received.payload)
e7c732b8 866
5df86a8a 867 def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
4bfebc93
CH
868 self.assertEqual(expected, received)
869 self.assertEqual(received.edns, 0)
870 self.assertEqual(expected.payload, received.payload)
871 self.assertEqual(len(received.options), withCookies)
5df86a8a
RG
872 if withCookies:
873 for option in received.options:
4bfebc93 874 self.assertEqual(option.otype, 10)
be90d6bd
RG
875 else:
876 for option in received.options:
4bfebc93 877 self.assertNotEqual(option.otype, 10)
5df86a8a 878
cbf4e13a 879 def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
4bfebc93
CH
880 self.assertEqual(expected, received)
881 self.assertEqual(received.edns, 0)
882 self.assertEqual(expected.payload, received.payload)
883 self.assertEqual(len(received.options), 1 + additionalOptions)
cbf4e13a
RG
884 hasECS = False
885 for option in received.options:
886 if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
887 hasECS = True
888 else:
4bfebc93 889 self.assertNotEqual(additionalOptions, 0)
cbf4e13a 890
5df86a8a 891 self.compareOptions(expected.options, received.options)
cbf4e13a 892 self.assertTrue(hasECS)
5df86a8a 893
346410cd
CHB
894 def checkMessageEDNS(self, expected, received):
895 self.assertEqual(expected, received)
896 self.assertEqual(received.edns, 0)
897 self.assertEqual(expected.payload, received.payload)
898 self.assertEqual(len(expected.options), len(received.options))
899 self.compareOptions(expected.options, received.options)
900
cbf4e13a
RG
901 def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
902 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
5df86a8a 903
346410cd
CHB
904 def checkQueryEDNS(self, expected, received):
905 self.checkMessageEDNS(expected, received)
906
cbf4e13a
RG
907 def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
908 self.checkMessageEDNSWithECS(expected, received, additionalOptions)
5df86a8a
RG
909
910 def checkQueryEDNSWithoutECS(self, expected, received):
911 self.checkMessageEDNSWithoutECS(expected, received)
912
913 def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
914 self.checkMessageEDNSWithoutECS(expected, received, withCookies)
915
916 def checkQueryNoEDNS(self, expected, received):
917 self.checkMessageNoEDNS(expected, received)
918
919 def checkResponseNoEDNS(self, expected, received):
920 self.checkMessageNoEDNS(expected, received)
b0d08f82 921
58ae5410
RG
922 @staticmethod
923 def generateNewCertificateAndKey(filePrefix):
1d896c34 924 # generate and sign a new cert
58ae5410 925 cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf']
1d896c34
RG
926 output = None
927 try:
928 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
929 output = process.communicate(input='')
930 except subprocess.CalledProcessError as exc:
931 raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output))
58ae5410 932 cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req']
1d896c34
RG
933 output = None
934 try:
935 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
936 output = process.communicate(input='')
937 except subprocess.CalledProcessError as exc:
938 raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output))
939
58ae5410
RG
940 with open(filePrefix + '.chain', 'w') as outFile:
941 for inFileName in [filePrefix + '.pem', 'ca.pem']:
1d896c34
RG
942 with open(inFileName) as inFile:
943 outFile.write(inFile.read())
0e6892c6 944
58ae5410 945 cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12']
5ac11505
CHB
946 output = None
947 try:
948 process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
949 output = process.communicate(input='')
950 except subprocess.CalledProcessError as exc:
951 raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc.returncode, exc.output))
952
0e6892c6
RG
953 def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None):
954 proxy = ProxyProtocol()
955 self.assertTrue(proxy.parseHeader(receivedProxyPayload))
956 self.assertEqual(proxy.version, 0x02)
957 self.assertEqual(proxy.command, 0x01)
958 if v6:
959 self.assertEqual(proxy.family, 0x02)
960 else:
961 self.assertEqual(proxy.family, 0x01)
962 if not isTCP:
963 self.assertEqual(proxy.protocol, 0x02)
964 else:
965 self.assertEqual(proxy.protocol, 0x01)
966 self.assertGreater(proxy.contentLen, 0)
967
968 self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
969 self.assertEqual(proxy.source, source)
970 self.assertEqual(proxy.destination, destination)
971 if sourcePort:
972 self.assertEqual(proxy.sourcePort, sourcePort)
973 if destinationPort:
974 self.assertEqual(proxy.destinationPort, destinationPort)
975 else:
976 self.assertEqual(proxy.destinationPort, self._dnsDistPort)
977
978 self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
979 proxy.values.sort()
980 values.sort()
981 self.assertEqual(proxy.values, values)
fda32c1c
RG
982
983 @classmethod
984 def getDOHGetURL(cls, baseurl, query, rawQuery=False):
985 if rawQuery:
986 wire = query
987 else:
988 wire = query.to_wire()
989 param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
990 return baseurl + "?dns=" + param
991
992 @classmethod
993 def openDOHConnection(cls, port, caFile, timeout=2.0):
994 conn = pycurl.Curl()
995 conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
996
997 conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
998 "Accept: application/dns-message"])
999 return conn
1000
1001 @classmethod
41f36765 1002 def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None):
fda32c1c 1003 url = cls.getDOHGetURL(baseurl, query, rawQuery)
c02b7e13
RG
1004
1005 if not conn:
c02b7e13
RG
1006 conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
1007 # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
1008 conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
1009
fda32c1c
RG
1010 if useHTTPS:
1011 conn.setopt(pycurl.SSL_VERIFYPEER, 1)
1012 conn.setopt(pycurl.SSL_VERIFYHOST, 2)
1013 if caFile:
1014 conn.setopt(pycurl.CAINFO, caFile)
1015
c02b7e13
RG
1016 response_headers = BytesIO()
1017 #conn.setopt(pycurl.VERBOSE, True)
1018 conn.setopt(pycurl.URL, url)
1019 conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
1020
fda32c1c
RG
1021 conn.setopt(pycurl.HTTPHEADER, customHeaders)
1022 conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
1023
1024 if response:
1025 if toQueue:
1026 toQueue.put(response, True, timeout)
1027 else:
1028 cls._toResponderQueue.put(response, True, timeout)
1029
1030 receivedQuery = None
1031 message = None
1032 cls._response_headers = ''
1033 data = conn.perform_rb()
1034 cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
1035 if cls._rcode == 200 and not rawResponse:
1036 message = dns.message.from_wire(data)
1037 elif rawResponse:
1038 message = data
1039
1040 if useQueue:
1041 if fromQueue:
1042 if not fromQueue.empty():
1043 receivedQuery = fromQueue.get(True, timeout)
1044 else:
1045 if not cls._fromResponderQueue.empty():
1046 receivedQuery = cls._fromResponderQueue.get(True, timeout)
1047
1048 cls._response_headers = response_headers.getvalue()
1049 return (receivedQuery, message)
1050
1051 @classmethod
1052 def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
1053 url = baseurl
1054 conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
1055 response_headers = BytesIO()
1056 #conn.setopt(pycurl.VERBOSE, True)
1057 conn.setopt(pycurl.URL, url)
1058 conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
7e8a05fa
RG
1059 # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
1060 conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
fda32c1c
RG
1061 if useHTTPS:
1062 conn.setopt(pycurl.SSL_VERIFYPEER, 1)
1063 conn.setopt(pycurl.SSL_VERIFYHOST, 2)
1064 if caFile:
1065 conn.setopt(pycurl.CAINFO, caFile)
1066
1067 conn.setopt(pycurl.HTTPHEADER, customHeaders)
1068 conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
1069 conn.setopt(pycurl.POST, True)
1070 data = query
1071 if not rawQuery:
1072 data = data.to_wire()
1073
1074 conn.setopt(pycurl.POSTFIELDS, data)
1075
1076 if response:
1077 cls._toResponderQueue.put(response, True, timeout)
1078
1079 receivedQuery = None
1080 message = None
1081 cls._response_headers = ''
1082 data = conn.perform_rb()
1083 cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
1084 if cls._rcode == 200 and not rawResponse:
1085 message = dns.message.from_wire(data)
1086 elif rawResponse:
1087 message = data
1088
1089 if useQueue and not cls._fromResponderQueue.empty():
1090 receivedQuery = cls._fromResponderQueue.get(True, timeout)
1091
1092 cls._response_headers = response_headers.getvalue()
1093 return (receivedQuery, message)
334e3549 1094
41f36765
RG
1095 def sendDOHQueryWrapper(self, query, response, useQueue=True):
1096 return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
334e3549 1097
2f4ac048
RG
1098 def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
1099 return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
1100
1101 def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
1102 return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
1103
334e3549
RG
1104 def sendDOTQueryWrapper(self, query, response, useQueue=True):
1105 return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue)
3dc49a89 1106
0a6676a4
RG
1107 def sendDOQQueryWrapper(self, query, response, useQueue=True):
1108 return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
1109
655fe34d
CHB
1110 def sendDOH3QueryWrapper(self, query, response, useQueue=True):
1111 return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName)
3dc49a89 1112 @classmethod
0a6676a4 1113 def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
3dc49a89
CHB
1114
1115 manager = dns.quic.SyncQuicManager(
1116 verify_mode=caFile
1117 )
1118
0a6676a4 1119 return manager.connect('127.0.0.1', port, source, source_port)
3dc49a89
CHB
1120
1121 @classmethod
0a6676a4 1122 def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
3dc49a89
CHB
1123
1124 if response:
1125 if toQueue:
1126 toQueue.put(response, True, timeout)
1127 else:
1128 cls._toResponderQueue.put(response, True, timeout)
1129
9ec97c74 1130 (message, _) = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)
3dc49a89
CHB
1131
1132 receivedQuery = None
1133
1134 if useQueue:
1135 if fromQueue:
1136 if not fromQueue.empty():
1137 receivedQuery = fromQueue.get(True, timeout)
1138 else:
1139 if not cls._fromResponderQueue.empty():
1140 receivedQuery = cls._fromResponderQueue.get(True, timeout)
1141
1142 return (receivedQuery, message)
4f0b10a9
CHB
1143
1144 @classmethod
d0439b42 1145 def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
4f0b10a9
CHB
1146
1147 if response:
1148 if toQueue:
1149 toQueue.put(response, True, timeout)
1150 else:
1151 cls._toResponderQueue.put(response, True, timeout)
1152
d0439b42 1153 message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)
4f0b10a9
CHB
1154
1155 receivedQuery = None
1156
1157 if useQueue:
1158 if fromQueue:
1159 if not fromQueue.empty():
1160 receivedQuery = fromQueue.get(True, timeout)
1161 else:
1162 if not cls._fromResponderQueue.empty():
1163 receivedQuery = cls._fromResponderQueue.get(True, timeout)
1164
1165 return (receivedQuery, message)