]>
Commit | Line | Data |
---|---|---|
ca404e94 RG |
1 | #!/usr/bin/env python3 |
2 | ||
fda32c1c | 3 | import base64 |
95f0b802 | 4 | import copy |
ffabdc3e | 5 | import errno |
ca404e94 RG |
6 | import os |
7 | import socket | |
a227f47d | 8 | import ssl |
ca404e94 RG |
9 | import struct |
10 | import subprocess | |
11 | import sys | |
12 | import threading | |
13 | import time | |
14 | import unittest | |
9d71a0cf | 15 | |
5df86a8a | 16 | import clientsubnetoption |
9d71a0cf | 17 | |
b1bec9f0 RG |
18 | import dns |
19 | import dns.message | |
9d71a0cf | 20 | |
1ea747c0 RG |
21 | import libnacl |
22 | import libnacl.utils | |
ca404e94 | 23 | |
9d71a0cf RG |
24 | import h2.connection |
25 | import h2.events | |
26 | import h2.config | |
27 | ||
fda32c1c RG |
28 | import pycurl |
29 | from io import BytesIO | |
30 | ||
e7000cce | 31 | from doqclient import quic_query |
4f0b10a9 | 32 | from doh3client import doh3_query |
e7000cce | 33 | |
6bd430bf | 34 | from eqdnsmessage import AssertEqualDNSMessageMixin |
0e6892c6 | 35 | from proxyprotocol import ProxyProtocol |
6bd430bf | 36 | |
# This module is Python 3 only (it uses f-strings, including the 3.8+ `{expr=}`
# form), so the former Python 2 fallbacks (the `Queue` module and `xrange`)
# could never run and have been dropped.
from queue import Queue

def getWorkerID():
    """Return the numeric pytest-xdist worker index, or 0 outside of xdist.

    The first two characters of PYTEST_XDIST_WORKER are dropped and the rest
    is parsed as an int (presumably names of the form 'gw<N>').
    """
    workerName = os.environ.get('PYTEST_XDIST_WORKER')
    if workerName is None:
        return 0
    return int(workerName[2:])

# Last port handed out, keyed by worker ID.
workerPorts = {}

def pickAvailablePort():
    """Return the next port for the current worker.

    Worker N starts at 11000 + 1000 * N; every later call returns the
    previously handed-out port + 1 (no upper bound is enforced).
    """
    workerID = getWorkerID()
    lastPort = workerPorts.get(workerID)
    port = 11000 + (workerID * 1000) if lastPort is None else lastPort + 1
    workerPorts[workerID] = port
    return port

class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Base class for dnsdist regression tests.

    It launches a dnsdist instance plus UDP and TCP responder threads acting
    as its backend. A test enqueues the response the backend should send on
    one queue; the responders push every query they get from dnsdist onto a
    second queue, so the test can check what dnsdist actually forwarded.
    """
    # dnsdist process and where it listens
    _dnsDistListeningAddr = "127.0.0.1"
    _dnsdist = None
    _alternateListeningAddr = None
    _alternateListeningPort = None
    _skipListeningOnCL = False
    _verboseMode = False
    _sudoMode = False
    _extraStartupSleep = 0

    # configuration: _config_template is %-formatted with the values of the
    # class attributes named in _config_params
    _config_template = """
    """
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']
    _consoleKey = None
    _checkConfigExpectedOutput = None

    # communication between the tests and the backend responders
    _toResponderQueue = Queue()
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    _responsesCounter = {}
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    _answerUnexpected = True

    # responder threads and their keep-running flags (keyed by native thread id)
    _backgroundThreads = {}
    _UDPResponder = None
    _TCPResponder = None

    # default ports; call order matters since each call advances the
    # per-worker port counter
    _dnsDistPort = pickAvailablePort()
    _consolePort = pickAvailablePort()
    _testServerPort = pickAvailablePort()

103 | @classmethod |
104 | def waitForTCPSocket(cls, ipaddress, port): | |
105 | for try_number in range(0, 20): | |
106 | try: | |
107 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
108 | sock.settimeout(1.0) | |
109 | sock.connect((ipaddress, port)) | |
110 | sock.close() | |
111 | return | |
112 | except Exception as err: | |
113 | if err.errno != errno.ECONNREFUSED: | |
114 | print(f'Error occurred: {try_number} {err}', file=sys.stderr) | |
115 | time.sleep(0.1) | |
116 | # We assume the dnsdist instance does not listen. That's fine. | |
117 | ||
ca404e94 RG |
118 | @classmethod |
119 | def startResponders(cls): | |
120 | print("Launching responders..") | |
630eb526 | 121 | cls._testServerPort = pickAvailablePort() |
ec5f5c6b | 122 | |
5df86a8a | 123 | cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]) |
630eb526 | 124 | cls._UDPResponder.daemon = True |
ca404e94 | 125 | cls._UDPResponder.start() |
5df86a8a | 126 | cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]) |
630eb526 | 127 | cls._TCPResponder.daemon = True |
ca404e94 | 128 | cls._TCPResponder.start() |
d53e77f2 | 129 | cls.waitForTCPSocket("127.0.0.1", cls._testServerPort); |
ca404e94 RG |
130 | |
131 | @classmethod | |
aac2124b | 132 | def startDNSDist(cls): |
630eb526 RG |
133 | cls._dnsDistPort = pickAvailablePort() |
134 | cls._consolePort = pickAvailablePort() | |
135 | ||
ca404e94 | 136 | print("Launching dnsdist..") |
aac2124b | 137 | confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__)) |
18a0e7c6 CH |
138 | params = tuple([getattr(cls, param) for param in cls._config_params]) |
139 | print(params) | |
aac2124b | 140 | with open(confFile, 'w') as conf: |
18a0e7c6 | 141 | conf.write("-- Autogenerated by dnsdisttests.py\n") |
630eb526 | 142 | conf.write(f"-- dnsdist will listen on {cls._dnsDistPort}") |
18a0e7c6 | 143 | conf.write(cls._config_template % params) |
f3853a40 | 144 | conf.write("setSecurityPollSuffix('')") |
18a0e7c6 | 145 | |
db7acdaf RG |
146 | if cls._skipListeningOnCL: |
147 | dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile ] | |
148 | else: | |
149 | dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile, | |
150 | '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ] | |
151 | ||
2a3cafcd RG |
152 | if cls._verboseMode: |
153 | dnsdistcmd.append('-v') | |
9bf515e1 | 154 | if cls._sudoMode: |
e4b04b6e RG |
155 | preserve_env_values = ['LD_LIBRARY_PATH', 'LLVM_PROFILE_FILE'] |
156 | for value in preserve_env_values: | |
157 | if value in os.environ: | |
158 | dnsdistcmd.insert(0, value + '=' + os.environ[value]) | |
9bf515e1 | 159 | dnsdistcmd.insert(0, 'sudo') |
2a3cafcd | 160 | |
18a0e7c6 CH |
161 | for acl in cls._acl: |
162 | dnsdistcmd.extend(['--acl', acl]) | |
163 | print(' '.join(dnsdistcmd)) | |
164 | ||
6b44773a CH |
165 | # validate config with --check-config, which sets client=true, possibly exposing bugs. |
166 | testcmd = dnsdistcmd + ['--check-config'] | |
ff0bc6a6 JS |
167 | try: |
168 | output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True) | |
169 | except subprocess.CalledProcessError as exc: | |
170 | raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output)) | |
f73ce0e3 RG |
171 | if cls._checkConfigExpectedOutput is not None: |
172 | expectedOutput = cls._checkConfigExpectedOutput | |
173 | else: | |
174 | expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode() | |
2a3cafcd | 175 | if not cls._verboseMode and output != expectedOutput: |
630eb526 | 176 | raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output, expectedOutput)) |
6b44773a | 177 | |
aac2124b RG |
178 | logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__)) |
179 | with open(logFile, 'w') as fdLog: | |
180 | cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog) | |
ca404e94 | 181 | |
1953ab6c RG |
182 | if cls._alternateListeningAddr and cls._alternateListeningPort: |
183 | cls.waitForTCPSocket(cls._alternateListeningAddr, cls._alternateListeningPort) | |
184 | else: | |
185 | cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort) | |
ca404e94 RG |
186 | |
187 | if cls._dnsdist.poll() is not None: | |
ffabdc3e OM |
188 | print(f"\n*** startDNSDist log for {logFile} ***") |
189 | with open(logFile, 'r') as fdLog: | |
190 | print(fdLog.read()) | |
191 | print(f"*** End startDNSDist log for {logFile} ***") | |
192 | raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode)) | |
b38aaeb5 | 193 | time.sleep(cls._extraStartupSleep) |
ca404e94 RG |
194 | |
195 | @classmethod | |
196 | def setUpSockets(cls): | |
197 | print("Setting up UDP socket..") | |
198 | cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | |
1ade83b2 | 199 | cls._sock.settimeout(2.0) |
ca404e94 RG |
200 | cls._sock.connect(("127.0.0.1", cls._dnsDistPort)) |
201 | ||
e4284d05 OM |
202 | @classmethod |
203 | def killProcess(cls, p): | |
204 | # Don't try to kill it if it's already dead | |
205 | if p.poll() is not None: | |
206 | return | |
207 | try: | |
208 | p.terminate() | |
6a51a279 | 209 | for count in range(20): |
e4284d05 OM |
210 | x = p.poll() |
211 | if x is not None: | |
212 | break | |
213 | time.sleep(0.1) | |
214 | if x is None: | |
215 | print("kill...", p, file=sys.stderr) | |
216 | p.kill() | |
217 | p.wait() | |
218 | except OSError as e: | |
219 | # There is a race-condition with the poll() and | |
220 | # kill() statements, when the process is dead on the | |
221 | # kill(), this is fine | |
222 | if e.errno != errno.ESRCH: | |
223 | raise | |
224 | ||
ca404e94 RG |
225 | @classmethod |
226 | def setUpClass(cls): | |
227 | ||
228 | cls.startResponders() | |
aac2124b | 229 | cls.startDNSDist() |
ca404e94 RG |
230 | cls.setUpSockets() |
231 | ||
232 | print("Launching tests..") | |
233 | ||
234 | @classmethod | |
235 | def tearDownClass(cls): | |
ffabdc3e OM |
236 | cls._sock.close() |
237 | # tell the background threads to stop, if any | |
238 | for backgroundThread in cls._backgroundThreads: | |
239 | cls._backgroundThreads[backgroundThread] = False | |
e4284d05 | 240 | cls.killProcess(cls._dnsdist) |
7373e3a6 | 241 | |
ca404e94 | 242 | @classmethod |
fe1c60f2 | 243 | def _ResponderIncrementCounter(cls): |
630eb526 RG |
244 | if threading.current_thread().name in cls._responsesCounter: |
245 | cls._responsesCounter[threading.current_thread().name] += 1 | |
ec5f5c6b | 246 | else: |
630eb526 | 247 | cls._responsesCounter[threading.current_thread().name] = 1 |
ec5f5c6b | 248 | |
fe1c60f2 | 249 | @classmethod |
4aa08b62 | 250 | def _getResponse(cls, request, fromQueue, toQueue, synthesize=None): |
fe1c60f2 RG |
251 | response = None |
252 | if len(request.question) != 1: | |
253 | print("Skipping query with question count %d" % (len(request.question))) | |
254 | return None | |
98650fde RG |
255 | healthCheck = str(request.question[0].name).endswith(cls._healthCheckName) |
256 | if healthCheck: | |
257 | cls._healthCheckCounter += 1 | |
4aa08b62 | 258 | response = dns.message.make_response(request) |
98650fde | 259 | else: |
fe1c60f2 | 260 | cls._ResponderIncrementCounter() |
5df86a8a | 261 | if not fromQueue.empty(): |
4aa08b62 | 262 | toQueue.put(request, True, cls._queueTimeout) |
90186270 RG |
263 | response = fromQueue.get(True, cls._queueTimeout) |
264 | if response: | |
265 | response = copy.copy(response) | |
266 | response.id = request.id | |
267 | ||
268 | if synthesize is not None: | |
269 | response = dns.message.make_response(request) | |
270 | response.set_rcode(synthesize) | |
fe1c60f2 | 271 | |
e44df0f1 | 272 | if not response: |
90186270 | 273 | if cls._answerUnexpected: |
e44df0f1 RG |
274 | response = dns.message.make_response(request) |
275 | response.set_rcode(dns.rcode.SERVFAIL) | |
fe1c60f2 RG |
276 | |
277 | return response | |
278 | ||
ec5f5c6b | 279 | @classmethod |
a620f197 | 280 | def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None): |
7373e3a6 | 281 | cls._backgroundThreads[threading.get_native_id()] = True |
3ef7ab0d RG |
282 | # trailingDataResponse=True means "ignore trailing data". |
283 | # Other values are either False (meaning "raise an exception") | |
284 | # or are interpreted as a response RCODE for queries with trailing data. | |
a620f197 | 285 | # callback is invoked for every -even healthcheck ones- query and should return a raw response |
4aa08b62 | 286 | ignoreTrailing = trailingDataResponse is True |
3ef7ab0d | 287 | |
ca404e94 RG |
288 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
289 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | |
ec5f5c6b | 290 | sock.bind(("127.0.0.1", port)) |
28d4c42d | 291 | sock.settimeout(0.5) |
ca404e94 | 292 | while True: |
7373e3a6 RG |
293 | try: |
294 | data, addr = sock.recvfrom(4096) | |
295 | except socket.timeout: | |
296 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
297 | del cls._backgroundThreads[threading.get_native_id()] | |
298 | break | |
299 | else: | |
300 | continue | |
301 | ||
4aa08b62 RG |
302 | forceRcode = None |
303 | try: | |
304 | request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | |
305 | except dns.message.TrailingJunk as e: | |
51f07ad4 | 306 | print('trailing data exception in UDPResponder') |
3ef7ab0d | 307 | if trailingDataResponse is False or forceRcode is True: |
4aa08b62 RG |
308 | raise |
309 | print("UDP query with trailing data, synthesizing response") | |
310 | request = dns.message.from_wire(data, ignore_trailing=True) | |
311 | forceRcode = trailingDataResponse | |
312 | ||
f3913dd2 | 313 | wire = None |
a620f197 RG |
314 | if callback: |
315 | wire = callback(request) | |
316 | else: | |
f8662974 RG |
317 | if request.edns > 1: |
318 | forceRcode = dns.rcode.BADVERS | |
a620f197 RG |
319 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) |
320 | if response: | |
321 | wire = response.to_wire() | |
87c605c4 | 322 | |
f3913dd2 RG |
323 | if not wire: |
324 | continue | |
325 | ||
a620f197 | 326 | sock.sendto(wire, addr) |
7373e3a6 | 327 | |
ca404e94 RG |
328 | sock.close() |
329 | ||
330 | @classmethod | |
e572dbf5 | 331 | def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, partialWrite=False): |
645a1ca4 | 332 | ignoreTrailing = trailingDataResponse is True |
fabe8e3a RG |
333 | try: |
334 | data = conn.recv(2) | |
335 | except Exception as err: | |
336 | data = None | |
337 | print(f'Error while reading query size in TCP responder thread {err=}, {type(err)=}') | |
645a1ca4 RG |
338 | if not data: |
339 | conn.close() | |
340 | return | |
341 | ||
342 | (datalen,) = struct.unpack("!H", data) | |
343 | data = conn.recv(datalen) | |
344 | forceRcode = None | |
345 | try: | |
346 | request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | |
347 | except dns.message.TrailingJunk as e: | |
348 | if trailingDataResponse is False or forceRcode is True: | |
349 | raise | |
350 | print("TCP query with trailing data, synthesizing response") | |
351 | request = dns.message.from_wire(data, ignore_trailing=True) | |
352 | forceRcode = trailingDataResponse | |
353 | ||
354 | if callback: | |
355 | wire = callback(request) | |
356 | else: | |
f8662974 RG |
357 | if request.edns > 1: |
358 | forceRcode = dns.rcode.BADVERS | |
645a1ca4 RG |
359 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) |
360 | if response: | |
361 | wire = response.to_wire(max_size=65535) | |
362 | ||
363 | if not wire: | |
364 | conn.close() | |
365 | return | |
366 | ||
e572dbf5 RG |
367 | wireLen = struct.pack("!H", len(wire)) |
368 | if partialWrite: | |
369 | for b in wireLen: | |
370 | conn.send(bytes([b])) | |
371 | time.sleep(0.5) | |
372 | else: | |
373 | conn.send(wireLen) | |
645a1ca4 RG |
374 | conn.send(wire) |
375 | ||
376 | while multipleResponses: | |
936dd73c RG |
377 | # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done |
378 | # otherwise we might read responses intended for the next connection | |
645a1ca4 RG |
379 | if fromQueue.empty(): |
380 | break | |
381 | ||
936dd73c | 382 | response = fromQueue.get(False) |
645a1ca4 RG |
383 | if not response: |
384 | break | |
385 | ||
386 | response = copy.copy(response) | |
387 | response.id = request.id | |
388 | wire = response.to_wire(max_size=65535) | |
389 | try: | |
390 | conn.send(struct.pack("!H", len(wire))) | |
391 | conn.send(wire) | |
392 | except socket.error as e: | |
393 | # some of the tests are going to close | |
394 | # the connection on us, just deal with it | |
395 | break | |
396 | ||
397 | conn.close() | |
398 | ||
399 | @classmethod | |
e572dbf5 | 400 | def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1', partialWrite=False): |
7373e3a6 | 401 | cls._backgroundThreads[threading.get_native_id()] = True |
3ef7ab0d RG |
402 | # trailingDataResponse=True means "ignore trailing data". |
403 | # Other values are either False (meaning "raise an exception") | |
404 | # or are interpreted as a response RCODE for queries with trailing data. | |
a620f197 | 405 | # callback is invoked for every -even healthcheck ones- query and should return a raw response |
3ef7ab0d | 406 | |
ca404e94 | 407 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
501af9ae | 408 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
ca404e94 RG |
409 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) |
410 | try: | |
144eebeb | 411 | sock.bind((listeningAddr, port)) |
ca404e94 RG |
412 | except socket.error as e: |
413 | print("Error binding in the TCP responder: %s" % str(e)) | |
414 | sys.exit(1) | |
415 | ||
416 | sock.listen(100) | |
28d4c42d | 417 | sock.settimeout(0.5) |
7d2856a6 RG |
418 | if tlsContext: |
419 | sock = tlsContext.wrap_socket(sock, server_side=True) | |
420 | ||
ca404e94 | 421 | while True: |
7d2856a6 RG |
422 | try: |
423 | (conn, _) = sock.accept() | |
424 | except ssl.SSLError: | |
425 | continue | |
ae3b96d9 RG |
426 | except ConnectionResetError: |
427 | continue | |
7373e3a6 RG |
428 | except socket.timeout: |
429 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
430 | del cls._backgroundThreads[threading.get_native_id()] | |
431 | break | |
432 | else: | |
433 | continue | |
ae3b96d9 | 434 | |
6ac8517d | 435 | conn.settimeout(5.0) |
645a1ca4 RG |
436 | if multipleConnections: |
437 | thread = threading.Thread(name='TCP Connection Handler', | |
438 | target=cls.handleTCPConnection, | |
e572dbf5 | 439 | args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite]) |
630eb526 | 440 | thread.daemon = True |
645a1ca4 | 441 | thread.start() |
a620f197 | 442 | else: |
e572dbf5 | 443 | cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite) |
548c8b66 | 444 | |
ca404e94 RG |
445 | sock.close() |
446 | ||
c4c72a2c RG |
447 | @classmethod |
448 | def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol): | |
449 | ignoreTrailing = trailingDataResponse is True | |
144eebeb RG |
450 | try: |
451 | h2conn = h2.connection.H2Connection(config=config) | |
452 | h2conn.initiate_connection() | |
453 | conn.sendall(h2conn.data_to_send()) | |
454 | except ssl.SSLEOFError as e: | |
455 | print("Unexpected EOF: %s" % (e)) | |
456 | return | |
fabe8e3a RG |
457 | except Exception as err: |
458 | print(f'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}') | |
459 | return | |
144eebeb | 460 | |
c4c72a2c RG |
461 | dnsData = {} |
462 | ||
463 | if useProxyProtocol: | |
464 | # try to read the entire Proxy Protocol header | |
465 | proxy = ProxyProtocol() | |
466 | header = conn.recv(proxy.HEADER_SIZE) | |
467 | if not header: | |
468 | print('unable to get header') | |
469 | conn.close() | |
470 | return | |
471 | ||
472 | if not proxy.parseHeader(header): | |
473 | print('unable to parse header') | |
474 | print(header) | |
475 | conn.close() | |
476 | return | |
477 | ||
478 | proxyContent = conn.recv(proxy.contentLen) | |
479 | if not proxyContent: | |
480 | print('unable to get content') | |
481 | conn.close() | |
482 | return | |
483 | ||
484 | payload = header + proxyContent | |
485 | toQueue.put(payload, True, cls._queueTimeout) | |
486 | ||
487 | # be careful, HTTP/2 headers and data might be in different recv() results | |
488 | requestHeaders = None | |
489 | while True: | |
fabe8e3a RG |
490 | try: |
491 | data = conn.recv(65535) | |
492 | except Exception as err: | |
493 | data = None | |
494 | print(f'Unexpected exception in DoH responder thread {err=}, {type(err)=}') | |
c4c72a2c RG |
495 | if not data: |
496 | break | |
497 | ||
498 | events = h2conn.receive_data(data) | |
499 | for event in events: | |
500 | if isinstance(event, h2.events.RequestReceived): | |
501 | requestHeaders = event.headers | |
502 | if isinstance(event, h2.events.DataReceived): | |
503 | h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) | |
504 | if not event.stream_id in dnsData: | |
505 | dnsData[event.stream_id] = b'' | |
506 | dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data) | |
507 | if event.stream_ended: | |
508 | forceRcode = None | |
509 | status = 200 | |
510 | try: | |
511 | request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing) | |
512 | except dns.message.TrailingJunk as e: | |
513 | if trailingDataResponse is False or forceRcode is True: | |
514 | raise | |
515 | print("DOH query with trailing data, synthesizing response") | |
516 | request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True) | |
517 | forceRcode = trailingDataResponse | |
518 | ||
519 | if callback: | |
520 | status, wire = callback(request, requestHeaders, fromQueue, toQueue) | |
521 | else: | |
522 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) | |
523 | if response: | |
524 | wire = response.to_wire(max_size=65535) | |
525 | ||
526 | if not wire: | |
527 | conn.close() | |
528 | conn = None | |
529 | break | |
530 | ||
531 | headers = [ | |
532 | (':status', str(status)), | |
533 | ('content-length', str(len(wire))), | |
534 | ('content-type', 'application/dns-message'), | |
535 | ] | |
536 | h2conn.send_headers(stream_id=event.stream_id, headers=headers) | |
537 | h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True) | |
538 | ||
539 | data_to_send = h2conn.data_to_send() | |
540 | if data_to_send: | |
541 | conn.sendall(data_to_send) | |
542 | ||
543 | if conn is None: | |
544 | break | |
545 | ||
546 | if conn is not None: | |
547 | conn.close() | |
548 | ||
9d71a0cf | 549 | @classmethod |
0e6892c6 | 550 | def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False): |
7373e3a6 | 551 | cls._backgroundThreads[threading.get_native_id()] = True |
9d71a0cf RG |
552 | # trailingDataResponse=True means "ignore trailing data". |
553 | # Other values are either False (meaning "raise an exception") | |
554 | # or are interpreted as a response RCODE for queries with trailing data. | |
555 | # callback is invoked for every -even healthcheck ones- query and should return a raw response | |
9d71a0cf RG |
556 | |
557 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
558 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) | |
559 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | |
560 | try: | |
561 | sock.bind(("127.0.0.1", port)) | |
562 | except socket.error as e: | |
563 | print("Error binding in the TCP responder: %s" % str(e)) | |
564 | sys.exit(1) | |
565 | ||
566 | sock.listen(100) | |
28d4c42d | 567 | sock.settimeout(0.5) |
9d71a0cf RG |
568 | if tlsContext: |
569 | sock = tlsContext.wrap_socket(sock, server_side=True) | |
570 | ||
571 | config = h2.config.H2Configuration(client_side=False) | |
572 | ||
573 | while True: | |
574 | try: | |
575 | (conn, _) = sock.accept() | |
576 | except ssl.SSLError: | |
577 | continue | |
578 | except ConnectionResetError: | |
579 | continue | |
7373e3a6 RG |
580 | except socket.timeout: |
581 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
582 | del cls._backgroundThreads[threading.get_native_id()] | |
583 | break | |
584 | else: | |
585 | continue | |
ae3b96d9 | 586 | |
9d71a0cf | 587 | conn.settimeout(5.0) |
c4c72a2c RG |
588 | thread = threading.Thread(name='DoH Connection Handler', |
589 | target=cls.handleDoHConnection, | |
590 | args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol]) | |
630eb526 | 591 | thread.daemon = True |
c4c72a2c | 592 | thread.start() |
9d71a0cf RG |
593 | |
594 | sock.close() | |
595 | ||
ca404e94 | 596 | @classmethod |
55baa1f2 | 597 | def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): |
90186270 | 598 | if useQueue and response is not None: |
617dfe22 | 599 | cls._toResponderQueue.put(response, True, timeout) |
ca404e94 RG |
600 | |
601 | if timeout: | |
602 | cls._sock.settimeout(timeout) | |
603 | ||
604 | try: | |
55baa1f2 RG |
605 | if not rawQuery: |
606 | query = query.to_wire() | |
607 | cls._sock.send(query) | |
ca404e94 | 608 | data = cls._sock.recv(4096) |
b1bec9f0 | 609 | except socket.timeout: |
ca404e94 RG |
610 | data = None |
611 | finally: | |
612 | if timeout: | |
613 | cls._sock.settimeout(None) | |
614 | ||
615 | receivedQuery = None | |
616 | message = None | |
617 | if useQueue and not cls._fromResponderQueue.empty(): | |
617dfe22 | 618 | receivedQuery = cls._fromResponderQueue.get(True, timeout) |
ca404e94 RG |
619 | if data: |
620 | message = dns.message.from_wire(data) | |
621 | return (receivedQuery, message) | |
622 | ||
623 | @classmethod | |
db7acdaf | 624 | def openTCPConnection(cls, timeout=None, port=None): |
ca404e94 | 625 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
501af9ae | 626 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
ca404e94 RG |
627 | if timeout: |
628 | sock.settimeout(timeout) | |
629 | ||
db7acdaf RG |
630 | if not port: |
631 | port = cls._dnsDistPort | |
632 | ||
633 | sock.connect(("127.0.0.1", port)) | |
9396d955 | 634 | return sock |
0a2087eb | 635 | |
9396d955 | 636 | @classmethod |
7e8a05fa | 637 | def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]): |
a227f47d | 638 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
501af9ae | 639 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
a227f47d RG |
640 | if timeout: |
641 | sock.settimeout(timeout) | |
642 | ||
643 | # 2.7.9+ | |
644 | if hasattr(ssl, 'create_default_context'): | |
645 | sslctx = ssl.create_default_context(cafile=caCert) | |
7e8a05fa RG |
646 | if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'): |
647 | sslctx.set_alpn_protocols(alpn) | |
a227f47d RG |
648 | sslsock = sslctx.wrap_socket(sock, server_hostname=serverName) |
649 | else: | |
650 | sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED) | |
651 | ||
652 | sslsock.connect(("127.0.0.1", port)) | |
653 | return sslsock | |
654 | ||
655 | @classmethod | |
656 | def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0): | |
9396d955 RG |
657 | if not rawQuery: |
658 | wire = query.to_wire() | |
659 | else: | |
660 | wire = query | |
55baa1f2 | 661 | |
a227f47d RG |
662 | if response: |
663 | cls._toResponderQueue.put(response, True, timeout) | |
664 | ||
9396d955 RG |
665 | sock.send(struct.pack("!H", len(wire))) |
666 | sock.send(wire) | |
667 | ||
668 | @classmethod | |
a227f47d | 669 | def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0): |
9396d955 RG |
670 | message = None |
671 | data = sock.recv(2) | |
672 | if data: | |
673 | (datalen,) = struct.unpack("!H", data) | |
f05cd66c | 674 | print(datalen) |
9396d955 | 675 | data = sock.recv(datalen) |
ca404e94 | 676 | if data: |
f05cd66c | 677 | print(data) |
9396d955 | 678 | message = dns.message.from_wire(data) |
a227f47d | 679 | |
f05cd66c | 680 | print(useQueue) |
a227f47d RG |
681 | if useQueue and not cls._fromResponderQueue.empty(): |
682 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
f05cd66c | 683 | print(receivedQuery) |
a227f47d RG |
684 | return (receivedQuery, message) |
685 | else: | |
f05cd66c | 686 | print("queue empty") |
a227f47d | 687 | return message |
9396d955 | 688 | |
fda32c1c RG |
689 | @classmethod |
690 | def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True): | |
691 | conn = cls.openTLSConnection(port, serverName, caFile) | |
692 | cls.sendTCPQueryOverConnection(conn, query, response=response) | |
693 | if useQueue: | |
694 | return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue) | |
695 | return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue) | |
696 | ||
9396d955 RG |
697 | @classmethod |
698 | def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): | |
699 | message = None | |
700 | if useQueue: | |
701 | cls._toResponderQueue.put(response, True, timeout) | |
702 | ||
9bf515e1 RG |
703 | try: |
704 | sock = cls.openTCPConnection(timeout) | |
705 | except socket.timeout as e: | |
706 | print("Timeout while opening TCP connection: %s" % (str(e))) | |
707 | return (None, None) | |
9396d955 RG |
708 | |
709 | try: | |
110624de RG |
710 | cls.sendTCPQueryOverConnection(sock, query, rawQuery, timeout=timeout) |
711 | message = cls.recvTCPResponseOverConnection(sock, timeout=timeout) | |
ca404e94 | 712 | except socket.timeout as e: |
0e6892c6 | 713 | print("Timeout while sending or receiving TCP data: %s" % (str(e))) |
ca404e94 RG |
714 | except socket.error as e: |
715 | print("Network error: %s" % (str(e))) | |
ca404e94 RG |
716 | finally: |
717 | sock.close() | |
718 | ||
719 | receivedQuery = None | |
f05cd66c | 720 | print(useQueue) |
ca404e94 | 721 | if useQueue and not cls._fromResponderQueue.empty(): |
f05cd66c | 722 | print(receivedQuery) |
617dfe22 | 723 | receivedQuery = cls._fromResponderQueue.get(True, timeout) |
f05cd66c RG |
724 | else: |
725 | print("queue is empty") | |
9396d955 | 726 | |
ca404e94 | 727 | return (receivedQuery, message) |
617dfe22 | 728 | |
548c8b66 RG |
@classmethod
def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
    """Send a single query to dnsdist over TCP and collect every response sent back.

    :param query: a dns.message.Message, or the raw wire-format bytes if rawQuery is True
    :param responses: messages handed to the responder through cls._toResponderQueue
                      (only when useQueue is True)
    :param useQueue: exchange data with the responder queues
    :param timeout: socket timeout and queue timeout, in seconds (a falsy value
                    leaves the socket without a timeout)
    :param rawQuery: send ``query`` as-is instead of calling ``to_wire()``
    :returns: (query seen by the responder or None, list of parsed responses)

    Responses are read until the server closes the connection; a socket timeout
    or socket error ends the read loop and is printed rather than raised.
    """
    def recvExactly(sock, count):
        # recv() may return fewer bytes than requested even when more are on
        # the way, so keep reading until 'count' bytes have arrived or the
        # peer closes the connection (in which case fewer bytes are returned).
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    if useQueue:
        for response in responses:
            cls._toResponderQueue.put(response, True, timeout)
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    if timeout:
        sock.settimeout(timeout)

    try:
        sock.connect(("127.0.0.1", cls._dnsDistPort))
    except socket.error:
        # connection failures still propagate, but without leaking the socket
        sock.close()
        raise

    messages = []
    try:
        wire = query if rawQuery else query.to_wire()
        # sendall() keeps writing until everything is sent, unlike send()
        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)
        while True:
            data = recvExactly(sock, 2)
            if not data:
                # the server closed the connection: no more responses
                break
            (datalen,) = struct.unpack("!H", data)
            data = recvExactly(sock, datalen)
            messages.append(dns.message.from_wire(data))

    except socket.timeout as e:
        print("Timeout while receiving multiple TCP responses: %s" % (str(e)))
    except socket.error as e:
        print("Network error: %s" % (str(e)))
    finally:
        sock.close()

    receivedQuery = None
    if useQueue and not cls._fromResponderQueue.empty():
        receivedQuery = cls._fromResponderQueue.get(True, timeout)
    return (receivedQuery, messages)
769 | ||
def setUp(self):
    # Per-test initialisation (unittest calls this before every test):
    # reset the counters and make sure nothing left over by a previous,
    # possibly failed, test is still sitting in the responder queues.
    self._healthCheckCounter = 0
    self._responsesCounter.clear()

    self.clearResponderQueues()

    super(DNSDistTest, self).setUp()
783 | ||
3bef39c3 RG |
@classmethod
def clearToResponderQueue(cls):
    """Discard every response still waiting in cls._toResponderQueue."""
    try:
        from queue import Empty
    except ImportError:
        from Queue import Empty
    # A check-then-get loop (`while not q.empty(): q.get(False)`) raises
    # Empty if another thread takes the last item between the two calls;
    # presumably the responder threads (not shown here) consume this queue
    # concurrently, so stop on the Empty exception instead.
    while True:
        try:
            cls._toResponderQueue.get(False)
        except Empty:
            break
788 | ||
@classmethod
def clearFromResponderQueue(cls):
    # Throw away every query recorded by the responders that no test consumed.
    pending = cls._fromResponderQueue
    while True:
        if pending.empty():
            break
        pending.get(False)
793 | ||
@classmethod
def clearResponderQueues(cls):
    # Drain the queue towards the responders first, then the one coming back.
    for drain in (cls.clearToResponderQueue, cls.clearFromResponderQueue):
        drain()
798 | ||
1ea747c0 RG |
@staticmethod
def generateConsoleKey():
    # A fresh random secret key, as produced by libnacl.utils.salsa_key().
    newKey = libnacl.utils.salsa_key()
    return newKey
802 | ||
@classmethod
def _encryptConsole(cls, command, nonce):
    # UTF-8 encode the command, then seal it with libnacl.crypto_secretbox
    # using cls._consoleKey; without a key the encoded bytes go out as-is.
    payload = command.encode('UTF-8')
    key = cls._consoleKey
    if key is not None:
        return libnacl.crypto_secretbox(payload, nonce, key)
    return payload
809 | ||
@classmethod
def _decryptConsole(cls, command, nonce):
    # Open the libnacl secretbox with cls._consoleKey (or take the bytes
    # unchanged when no key is configured) and decode the result as UTF-8.
    key = cls._consoleKey
    plaintext = command if key is None else libnacl.crypto_secretbox_open(command, nonce, key)
    return plaintext.decode('UTF-8')
1ea747c0 RG |
817 | |
@classmethod
def sendConsoleCommand(cls, command, timeout=5.0, IPv6=False):
    """Send a command to the dnsdist console and return its decoded response.

    The exchange is: send our nonce, read the server's nonce of the same size,
    derive the reading/writing nonces from halves of both, then send a 4-byte
    big-endian length followed by the (encrypted, if a key is set) command,
    and read a length-prefixed response.

    :param command: console command, as a text string
    :param timeout: socket timeout in seconds (falsy means no timeout)
    :param IPv6: connect to ::1 instead of 127.0.0.1
    :returns: the response as text, or None if the server sent a nonce of
              an unexpected (non-zero) size
    :raises socket.error: on EOF while reading the nonce or the response size
    """
    def recvExactly(sock, count):
        # recv() may return fewer bytes than requested (a large console
        # response can arrive in several segments), so keep reading until we
        # have 'count' bytes or the peer closes the connection.
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    ourNonce = libnacl.utils.rand_nonce()
    sock = socket.socket(socket.AF_INET if not IPv6 else socket.AF_INET6, socket.SOCK_STREAM)
    # the socket is released on every path: normal return, early return and errors
    try:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("::1", cls._consolePort, 0, 0) if IPv6 else ("127.0.0.1", cls._consolePort))
        sock.sendall(ourNonce)
        theirNonce = recvExactly(sock, len(ourNonce))
        if len(theirNonce) != len(ourNonce):
            print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
            if len(theirNonce) == 0:
                raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
            return None

        halfNonceSize = int(len(ourNonce) / 2)
        readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
        writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
        msg = cls._encryptConsole(command, writingNonce)
        sock.sendall(struct.pack("!I", len(msg)))
        sock.sendall(msg)
        data = recvExactly(sock, 4)
        if not data:
            raise socket.error("Got EOF while reading the response size")

        (responseLen,) = struct.unpack("!I", data)
        data = recvExactly(sock, responseLen)
        return cls._decryptConsole(data, readingNonce)
    finally:
        sock.close()
5df86a8a RG |
851 | |
def compareOptions(self, a, b):
    # Both option lists must have the same length and match pairwise, in order.
    self.assertEqual(len(a), len(b))
    for left, right in zip(a, b):
        self.assertEqual(left, right)
5df86a8a RG |
856 | |
def checkMessageNoEDNS(self, expected, received):
    # The messages must match and 'received' must carry no EDNS at all
    # (dnspython reports edns == -1 when there is no OPT record) nor options.
    self.assertEqual(expected, received)
    ednsVersion = received.edns
    self.assertEqual(ednsVersion, -1)
    optionCount = len(received.options)
    self.assertEqual(optionCount, 0)
5df86a8a | 861 | |
def checkMessageEDNSWithoutOptions(self, expected, received):
    # Equal messages, EDNS version 0 on the received side and matching
    # advertised payload sizes. The option list itself is not inspected here.
    self.assertEqual(expected, received)
    receivedVersion = received.edns
    self.assertEqual(receivedVersion, 0)
    expectedPayload, receivedPayload = expected.payload, received.payload
    self.assertEqual(expectedPayload, receivedPayload)
e7c732b8 | 866 | |
def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
    # 'received' must equal 'expected', use EDNS version 0 with the same
    # payload size, and carry exactly 'withCookies' options. When cookies are
    # expected every option must have code 10 (cookie); otherwise none may.
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)
    self.assertEqual(len(received.options), withCookies)
    check = self.assertEqual if withCookies else self.assertNotEqual
    for option in received.options:
        check(option.otype, 10)
5df86a8a | 878 | |
def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
    # 'received' must equal 'expected', use EDNS version 0 with the same
    # payload size, and carry 1 + additionalOptions options, one of which is
    # ECS; non-ECS options are only tolerated when additionalOptions != 0.
    # Finally, the option lists of both messages must match pairwise.
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)
    self.assertEqual(len(received.options), 1 + additionalOptions)
    sawECS = False
    for option in received.options:
        isECS = option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE
        if not isECS:
            self.assertNotEqual(additionalOptions, 0)
        sawECS = sawECS or isECS

    self.compareOptions(expected.options, received.options)
    self.assertTrue(sawECS)
5df86a8a | 893 | |
346410cd CHB |
def checkMessageEDNS(self, expected, received):
    # Equal messages, EDNS version 0, identical payload sizes and option
    # lists of the same length that match pairwise.
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(expected.payload, received.payload)
    expectedOptions, receivedOptions = expected.options, received.options
    self.assertEqual(len(expectedOptions), len(receivedOptions))
    self.compareOptions(expectedOptions, receivedOptions)
900 | ||
cbf4e13a RG |
def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
    # Queries get exactly the same ECS checks as any other message.
    self.checkMessageEDNSWithECS(expected, received, additionalOptions=additionalOptions)
5df86a8a | 903 | |
346410cd CHB |
def checkQueryEDNS(self, expected, received):
    # Queries get exactly the same EDNS checks as any other message.
    self.checkMessageEDNS(expected=expected, received=received)
906 | ||
cbf4e13a RG |
def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
    # Responses get exactly the same ECS checks as any other message.
    self.checkMessageEDNSWithECS(expected, received, additionalOptions=additionalOptions)
5df86a8a RG |
909 | |
def checkQueryEDNSWithoutECS(self, expected, received):
    # Queries are checked like any message, with the default cookie count.
    self.checkMessageEDNSWithoutECS(expected=expected, received=received)
912 | ||
def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
    # Responses are checked like any message, forwarding the cookie count.
    self.checkMessageEDNSWithoutECS(expected, received, withCookies=withCookies)
915 | ||
def checkQueryNoEDNS(self, expected, received):
    # Queries without EDNS are checked like any other message.
    self.checkMessageNoEDNS(expected=expected, received=received)
918 | ||
def checkResponseNoEDNS(self, expected, received):
    # Responses without EDNS are checked like any other message.
    self.checkMessageNoEDNS(expected=expected, received=received)
b0d08f82 | 921 | |
58ae5410 RG |
@staticmethod
def generateNewCertificateAndKey(filePrefix):
    """Create a new key and a certificate signed by the test CA.

    Files written (all next to the current directory):
      <filePrefix>.key / .csr  -- new RSA 2048 key and signing request
                                  (config from 'configServer.conf')
      <filePrefix>.pem         -- certificate signed with ca.pem / ca.key
      <filePrefix>.chain       -- the certificate followed by ca.pem
      <filePrefix>.p12         -- PKCS#12 bundle, password 'passw0rd'

    :raises AssertionError: if an openssl invocation exits with a non-zero status
    """
    def runOpenSSL(step, cmd):
        # Popen()/communicate() never raise CalledProcessError, so the exit
        # status has to be checked explicitly for failures to be reported.
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True)
        output = process.communicate(input=b'')[0]
        if process.returncode != 0:
            raise AssertionError('openssl %s failed (%d): %s' % (step, process.returncode, output))
        return output

    # generate a new key and certificate signing request
    runOpenSSL('req', ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf'])
    # sign it with the test CA
    runOpenSSL('x509', ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req'])

    # chain file: our certificate followed by the CA one
    with open(filePrefix + '.chain', 'w') as outFile:
        for inFileName in [filePrefix + '.pem', 'ca.pem']:
            with open(inFileName) as inFile:
                outFile.write(inFile.read())

    # PKCS#12 bundle of the certificate and its key
    runOpenSSL('pkcs12', ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12'])
952 | ||
0e6892c6 RG |
def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=None, v6=False, sourcePort=None, destinationPort=None):
    """Assert that a received Proxy Protocol payload carries the expected data.

    :param receivedProxyPayload: raw payload to parse with ProxyProtocol
    :param source: expected source address
    :param destination: expected destination address
    :param isTCP: expect protocol 0x01 (TCP) instead of 0x02 (UDP)
    :param values: expected additional values, compared regardless of order;
                   None means no additional values. The caller's list is not
                   modified.
    :param v6: expect family 0x02 (IPv6) instead of 0x01 (IPv4)
    :param sourcePort: expected source port, only checked when truthy
    :param destinationPort: expected destination port; when not set, the
                            dnsdist port of the test class is expected
    """
    proxy = ProxyProtocol()
    self.assertTrue(proxy.parseHeader(receivedProxyPayload))
    self.assertEqual(proxy.version, 0x02)
    self.assertEqual(proxy.command, 0x01)
    self.assertEqual(proxy.family, 0x02 if v6 else 0x01)
    self.assertEqual(proxy.protocol, 0x01 if isTCP else 0x02)
    self.assertGreater(proxy.contentLen, 0)

    self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload))
    self.assertEqual(proxy.source, source)
    self.assertEqual(proxy.destination, destination)
    if sourcePort:
        self.assertEqual(proxy.sourcePort, sourcePort)
    if destinationPort:
        self.assertEqual(proxy.destinationPort, destinationPort)
    else:
        self.assertEqual(proxy.destinationPort, self._dnsDistPort)

    self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload))
    # Compare sorted copies: the order is not significant, and sorting the
    # caller's list (or a shared default) in place would be a side effect.
    expectedValues = sorted(values) if values is not None else []
    self.assertEqual(sorted(proxy.values), expectedValues)
fda32c1c RG |
982 | |
@classmethod
def getDOHGetURL(cls, baseurl, query, rawQuery=False):
    # RFC 8484 GET form: the wire-format query, base64url-encoded with the
    # trailing '=' padding removed, passed in the 'dns' query parameter.
    if not rawQuery:
        query = query.to_wire()
    encoded = base64.urlsafe_b64encode(query).decode('UTF8')
    return ''.join([baseurl, "?dns=", encoded.rstrip('=')])
991 | ||
@classmethod
def openDOHConnection(cls, port, caFile, timeout=2.0):
    # Return a new pycurl handle preset for DoH: HTTP/2 and the
    # application/dns-message content type in both directions.
    # NOTE(review): port, caFile and timeout are accepted but not used here,
    # and no transfer timeout is set on the handle; the callers in this file
    # set the CA, address resolution and headers themselves.
    dohHeaders = [
        "Content-type: application/dns-message",
        "Accept: application/dns-message",
    ]
    handle = pycurl.Curl()
    handle.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
    handle.setopt(pycurl.HTTPHEADER, dohHeaders)
    return handle
1000 | ||
@classmethod
def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None):
    """Send a DNS query over DoH with an HTTP GET request.

    :param port: local port of the DoH frontend (servername resolves to 127.0.0.1)
    :param servername: host name used in the URL and for TLS verification
    :param baseurl: URL the encoded query is appended to
    :param query: a dns.message.Message, or raw wire bytes when rawQuery is True
    :param response: if set, handed to the responder (toQueue, or
                     cls._toResponderQueue by default) before sending
    :param timeout: used for the responder queue operations
    :param caFile: CA bundle used to verify the server certificate
    :param useQueue: fetch the query seen by the responder (from fromQueue,
                     or cls._fromResponderQueue by default), if any
    :param rawResponse: return the raw HTTP body instead of a parsed message
    :param customHeaders: HTTP request headers to set on the handle
    :param useHTTPS: enable peer and host certificate verification
    :param conn: existing pycurl handle to reuse; a new one is opened otherwise
    :returns: (query seen by the responder or None, response or None)

    Sets cls._rcode to the HTTP status and cls._response_headers to the
    raw response headers.
    """
    getURL = cls.getDOHGetURL(baseurl, query, rawQuery)

    if not conn:
        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)

    if useHTTPS:
        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
        if caFile:
            conn.setopt(pycurl.CAINFO, caFile)

    headerBuffer = BytesIO()
    conn.setopt(pycurl.URL, getURL)
    # pin the server name to the local dnsdist instance
    conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
    conn.setopt(pycurl.HTTPHEADER, customHeaders)
    conn.setopt(pycurl.HEADERFUNCTION, headerBuffer.write)

    if response:
        outbound = toQueue if toQueue else cls._toResponderQueue
        outbound.put(response, True, timeout)

    cls._response_headers = ''
    body = conn.perform_rb()
    cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
    if rawResponse:
        message = body
    elif cls._rcode == 200:
        message = dns.message.from_wire(body)
    else:
        message = None

    receivedQuery = None
    if useQueue:
        inbound = fromQueue if fromQueue else cls._fromResponderQueue
        if not inbound.empty():
            receivedQuery = inbound.get(True, timeout)

    cls._response_headers = headerBuffer.getvalue()
    return (receivedQuery, message)
1050 | ||
@classmethod
def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
    """Send a DNS query over DoH with an HTTP POST request on a new handle.

    :param port: local port of the DoH frontend (servername resolves to 127.0.0.1)
    :param servername: host name used in the URL and for TLS verification
    :param baseurl: URL the query is POSTed to
    :param query: a dns.message.Message, or raw wire bytes when rawQuery is True
    :param response: if set, put in cls._toResponderQueue before sending
    :param timeout: used for the responder queue operations
    :param caFile: CA bundle used to verify the server certificate
    :param useQueue: fetch the query seen by the responder, if any
    :param rawResponse: return the raw HTTP body instead of a parsed message
    :param customHeaders: HTTP request headers to set on the handle
    :param useHTTPS: enable peer and host certificate verification
    :returns: (query seen by the responder or None, response or None)

    Sets cls._rcode to the HTTP status and cls._response_headers to the
    raw response headers.
    """
    conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
    headerBuffer = BytesIO()
    conn.setopt(pycurl.URL, baseurl)
    # pin the server name to the local dnsdist instance
    conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
    # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
    conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
    if useHTTPS:
        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
        if caFile:
            conn.setopt(pycurl.CAINFO, caFile)

    conn.setopt(pycurl.HTTPHEADER, customHeaders)
    conn.setopt(pycurl.HEADERFUNCTION, headerBuffer.write)
    conn.setopt(pycurl.POST, True)
    conn.setopt(pycurl.POSTFIELDS, query if rawQuery else query.to_wire())

    if response:
        cls._toResponderQueue.put(response, True, timeout)

    cls._response_headers = ''
    body = conn.perform_rb()
    cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
    if rawResponse:
        message = body
    elif cls._rcode == 200:
        message = dns.message.from_wire(body)
    else:
        message = None

    receivedQuery = None
    if useQueue and not cls._fromResponderQueue.empty():
        receivedQuery = cls._fromResponderQueue.get(True, timeout)

    cls._response_headers = headerBuffer.getvalue()
    return (receivedQuery, message)
334e3549 | 1094 | |
41f36765 RG |
def sendDOHQueryWrapper(self, query, response, useQueue=True):
    # DoH GET against the default DoH frontend of this test class.
    port, name, url = self._dohServerPort, self._serverName, self._dohBaseURL
    return self.sendDOHQuery(port, name, url, query,
                             response=response, caFile=self._caCert, useQueue=useQueue)
334e3549 | 1097 | |
2f4ac048 RG |
def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
    # DoH GET against the nghttp2-based DoH frontend of this test class.
    port, name, url = self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL
    return self.sendDOHQuery(port, name, url, query,
                             response=response, caFile=self._caCert, useQueue=useQueue)
1100 | ||
def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True):
    # DoH GET against the h2o-based DoH frontend of this test class.
    port, name, url = self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL
    return self.sendDOHQuery(port, name, url, query,
                             response=response, caFile=self._caCert, useQueue=useQueue)
1103 | ||
334e3549 RG |
def sendDOTQueryWrapper(self, query, response, useQueue=True):
    # DNS over TLS against this test class's TLS frontend, trusting its CA.
    port, serverName, caFile = self._tlsServerPort, self._serverName, self._caCert
    return self.sendDOTQuery(port, serverName, query, response, caFile, useQueue=useQueue)
3dc49a89 | 1106 | |
0a6676a4 RG |
def sendDOQQueryWrapper(self, query, response, useQueue=True):
    # DNS over QUIC against this test class's DoQ frontend, trusting its CA.
    return self.sendDOQQuery(self._doqServerPort, query,
                             serverName=self._serverName,
                             caFile=self._caCert,
                             response=response,
                             useQueue=useQueue)
1109 | ||
655fe34d CHB |
def sendDOH3QueryWrapper(self, query, response, useQueue=True):
    # DNS over HTTP/3 against this test class's DoH3 frontend, trusting its CA.
    return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query,
                              serverName=self._serverName,
                              caFile=self._caCert,
                              response=response,
                              useQueue=useQueue)
@classmethod
def getDOQConnection(cls, port, caFile=None, source=None, source_port=0):
    """Open a synchronous DNS-over-QUIC connection to 127.0.0.1:port.

    :param port: local port of the DoQ frontend
    :param caFile: passed to dns.quic.SyncQuicManager as ``verify_mode``
    :param source: passed through to manager.connect() (local address, presumably)
    :param source_port: passed through to manager.connect() (local port, presumably)
    :returns: the connection object returned by SyncQuicManager.connect()
    """
    # `import dns` / `import dns.message` do not load the quic submodule, so
    # import it explicitly instead of relying on another module doing it first.
    import dns.quic

    manager = dns.quic.SyncQuicManager(verify_mode=caFile)
    return manager.connect('127.0.0.1', port, source, source_port)
3dc49a89 CHB |
1120 | |
@classmethod
def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None):
    """Send a query over DNS-over-QUIC to 127.0.0.1:port using quic_query.

    :param response: if set, handed to the responder (toQueue, or
                     cls._toResponderQueue by default) before the query is sent
    :param timeout: passed to quic_query and used for queue operations
    :param caFile: passed to quic_query as ``verify``
    :param useQueue: fetch the query seen by the responder (from fromQueue,
                     or cls._fromResponderQueue by default), if any
    :param serverName: passed to quic_query as ``server_hostname``
    :returns: (query seen by the responder or None, response message)

    NOTE(review): rawQuery and connection are accepted but not used here.
    """
    if response:
        outbound = toQueue if toQueue else cls._toResponderQueue
        outbound.put(response, True, timeout)

    message, _ = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName)

    if not useQueue:
        return (None, message)

    inbound = fromQueue if fromQueue else cls._fromResponderQueue
    receivedQuery = None if inbound.empty() else inbound.get(True, timeout)
    return (receivedQuery, message)
4f0b10a9 CHB |
1143 | |
@classmethod
def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False):
    """Send a query over DNS-over-HTTP/3 using doh3_query.

    :param baseurl: DoH URL passed to doh3_query
    :param response: if set, handed to the responder (toQueue, or
                     cls._toResponderQueue by default) before the query is sent
    :param timeout: passed to doh3_query and used for queue operations
    :param caFile: passed to doh3_query as ``verify``
    :param useQueue: fetch the query seen by the responder (from fromQueue,
                     or cls._fromResponderQueue by default), if any
    :param serverName: passed to doh3_query as ``server_hostname``
    :param post: passed to doh3_query (POST instead of GET, presumably)
    :returns: (query seen by the responder or None, response message)

    NOTE(review): rawQuery and connection are accepted but not used here.
    """
    if response:
        outbound = toQueue if toQueue else cls._toResponderQueue
        outbound.put(response, True, timeout)

    message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post)

    if not useQueue:
        return (None, message)

    inbound = fromQueue if fromQueue else cls._fromResponderQueue
    receivedQuery = None if inbound.empty() else inbound.get(True, timeout)
    return (receivedQuery, message)