]>
Commit | Line | Data |
---|---|---|
ca404e94 RG |
1 | #!/usr/bin/env python2 |
2 | ||
fda32c1c | 3 | import base64 |
95f0b802 | 4 | import copy |
ffabdc3e | 5 | import errno |
ca404e94 RG |
6 | import os |
7 | import socket | |
a227f47d | 8 | import ssl |
ca404e94 RG |
9 | import struct |
10 | import subprocess | |
11 | import sys | |
12 | import threading | |
13 | import time | |
14 | import unittest | |
9d71a0cf | 15 | |
5df86a8a | 16 | import clientsubnetoption |
9d71a0cf | 17 | |
b1bec9f0 RG |
18 | import dns |
19 | import dns.message | |
9d71a0cf | 20 | |
1ea747c0 RG |
21 | import libnacl |
22 | import libnacl.utils | |
ca404e94 | 23 | |
9d71a0cf RG |
24 | import h2.connection |
25 | import h2.events | |
26 | import h2.config | |
27 | ||
fda32c1c RG |
28 | import pycurl |
29 | from io import BytesIO | |
30 | ||
e7000cce | 31 | from doqclient import quic_query |
4f0b10a9 | 32 | from doh3client import doh3_query |
e7000cce | 33 | |
6bd430bf | 34 | from eqdnsmessage import AssertEqualDNSMessageMixin |
0e6892c6 | 35 | from proxyprotocol import ProxyProtocol |
6bd430bf | 36 | |
b4f23783 | 37 | # Python2/3 compatibility hacks |
7a0ea291 | 38 | try: |
39 | from queue import Queue | |
40 | except ImportError: | |
b4f23783 | 41 | from Queue import Queue |
7a0ea291 | 42 | |
43 | try: | |
b4f23783 | 44 | range = xrange |
7a0ea291 | 45 | except NameError: |
46 | pass | |
b4f23783 | 47 | |
630eb526 RG |
48 | def getWorkerID(): |
49 | if not 'PYTEST_XDIST_WORKER' in os.environ: | |
50 | return 0 | |
51 | workerName = os.environ['PYTEST_XDIST_WORKER'] | |
52 | return int(workerName[2:]) | |
53 | ||
54 | workerPorts = {} | |
55 | ||
56 | def pickAvailablePort(): | |
57 | global workerPorts | |
58 | workerID = getWorkerID() | |
59 | if workerID in workerPorts: | |
60 | port = workerPorts[workerID] + 1 | |
61 | else: | |
62 | port = 11000 + (workerID * 1000) | |
63 | workerPorts[workerID] = port | |
64 | return port | |
b4f23783 | 65 | |
6bd430bf | 66 | class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): |
ca404e94 RG |
67 | """ |
68 | Set up a dnsdist instance and responder threads. | |
69 | Queries sent to dnsdist are relayed to the responder threads, | |
70 | who reply with the response provided by the tests themselves | |
71 | on a queue. Responder threads also queue the queries received | |
72 | from dnsdist on a separate queue, allowing the tests to check | |
73 | that the queries sent from dnsdist were as expected. | |
74 | """ | |
b052847c | 75 | _dnsDistListeningAddr = "127.0.0.1" |
b4f23783 CH |
76 | _toResponderQueue = Queue() |
77 | _fromResponderQueue = Queue() | |
617dfe22 | 78 | _queueTimeout = 1 |
ca404e94 | 79 | _dnsdist = None |
ec5f5c6b | 80 | _responsesCounter = {} |
18a0e7c6 | 81 | _config_template = """ |
18a0e7c6 CH |
82 | """ |
83 | _config_params = ['_testServerPort'] | |
84 | _acl = ['127.0.0.1/32'] | |
1ea747c0 | 85 | _consoleKey = None |
98650fde RG |
86 | _healthCheckName = 'a.root-servers.net.' |
87 | _healthCheckCounter = 0 | |
e44df0f1 | 88 | _answerUnexpected = True |
f73ce0e3 | 89 | _checkConfigExpectedOutput = None |
2a3cafcd | 90 | _verboseMode = False |
9bf515e1 | 91 | _sudoMode = False |
db7acdaf | 92 | _skipListeningOnCL = False |
1953ab6c RG |
93 | _alternateListeningAddr = None |
94 | _alternateListeningPort = None | |
7373e3a6 RG |
95 | _backgroundThreads = {} |
96 | _UDPResponder = None | |
97 | _TCPResponder = None | |
b38aaeb5 | 98 | _extraStartupSleep = 0 |
630eb526 RG |
99 | _dnsDistPort = pickAvailablePort() |
100 | _consolePort = pickAvailablePort() | |
101 | _testServerPort = pickAvailablePort() | |
ca404e94 | 102 | |
ffabdc3e OM |
103 | @classmethod |
104 | def waitForTCPSocket(cls, ipaddress, port): | |
105 | for try_number in range(0, 20): | |
106 | try: | |
107 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
108 | sock.settimeout(1.0) | |
109 | sock.connect((ipaddress, port)) | |
110 | sock.close() | |
111 | return | |
112 | except Exception as err: | |
113 | if err.errno != errno.ECONNREFUSED: | |
114 | print(f'Error occurred: {try_number} {err}', file=sys.stderr) | |
115 | time.sleep(0.1) | |
116 | # We assume the dnsdist instance does not listen. That's fine. | |
117 | ||
ca404e94 RG |
118 | @classmethod |
119 | def startResponders(cls): | |
120 | print("Launching responders..") | |
630eb526 | 121 | cls._testServerPort = pickAvailablePort() |
ec5f5c6b | 122 | |
5df86a8a | 123 | cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]) |
630eb526 | 124 | cls._UDPResponder.daemon = True |
ca404e94 | 125 | cls._UDPResponder.start() |
5df86a8a | 126 | cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]) |
630eb526 | 127 | cls._TCPResponder.daemon = True |
ca404e94 | 128 | cls._TCPResponder.start() |
d53e77f2 | 129 | cls.waitForTCPSocket("127.0.0.1", cls._testServerPort); |
ca404e94 RG |
130 | |
131 | @classmethod | |
aac2124b | 132 | def startDNSDist(cls): |
630eb526 RG |
133 | cls._dnsDistPort = pickAvailablePort() |
134 | cls._consolePort = pickAvailablePort() | |
135 | ||
ca404e94 | 136 | print("Launching dnsdist..") |
aac2124b | 137 | confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__)) |
18a0e7c6 CH |
138 | params = tuple([getattr(cls, param) for param in cls._config_params]) |
139 | print(params) | |
aac2124b | 140 | with open(confFile, 'w') as conf: |
18a0e7c6 | 141 | conf.write("-- Autogenerated by dnsdisttests.py\n") |
630eb526 | 142 | conf.write(f"-- dnsdist will listen on {cls._dnsDistPort}") |
18a0e7c6 | 143 | conf.write(cls._config_template % params) |
f3853a40 | 144 | conf.write("setSecurityPollSuffix('')") |
18a0e7c6 | 145 | |
db7acdaf RG |
146 | if cls._skipListeningOnCL: |
147 | dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile ] | |
148 | else: | |
149 | dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile, | |
150 | '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ] | |
151 | ||
2a3cafcd RG |
152 | if cls._verboseMode: |
153 | dnsdistcmd.append('-v') | |
9bf515e1 RG |
154 | if cls._sudoMode: |
155 | if 'LD_LIBRARY_PATH' in os.environ: | |
156 | dnsdistcmd.insert(0, 'LD_LIBRARY_PATH=' + os.environ['LD_LIBRARY_PATH']) | |
157 | dnsdistcmd.insert(0, 'sudo') | |
2a3cafcd | 158 | |
18a0e7c6 CH |
159 | for acl in cls._acl: |
160 | dnsdistcmd.extend(['--acl', acl]) | |
161 | print(' '.join(dnsdistcmd)) | |
162 | ||
6b44773a CH |
163 | # validate config with --check-config, which sets client=true, possibly exposing bugs. |
164 | testcmd = dnsdistcmd + ['--check-config'] | |
ff0bc6a6 JS |
165 | try: |
166 | output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True) | |
167 | except subprocess.CalledProcessError as exc: | |
168 | raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output)) | |
f73ce0e3 RG |
169 | if cls._checkConfigExpectedOutput is not None: |
170 | expectedOutput = cls._checkConfigExpectedOutput | |
171 | else: | |
172 | expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode() | |
2a3cafcd | 173 | if not cls._verboseMode and output != expectedOutput: |
630eb526 | 174 | raise AssertionError('dnsdist --check-config failed: %s (expected %s)' % (output, expectedOutput)) |
6b44773a | 175 | |
aac2124b RG |
176 | logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__)) |
177 | with open(logFile, 'w') as fdLog: | |
178 | cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog) | |
ca404e94 | 179 | |
1953ab6c RG |
180 | if cls._alternateListeningAddr and cls._alternateListeningPort: |
181 | cls.waitForTCPSocket(cls._alternateListeningAddr, cls._alternateListeningPort) | |
182 | else: | |
183 | cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort) | |
ca404e94 RG |
184 | |
185 | if cls._dnsdist.poll() is not None: | |
ffabdc3e OM |
186 | print(f"\n*** startDNSDist log for {logFile} ***") |
187 | with open(logFile, 'r') as fdLog: | |
188 | print(fdLog.read()) | |
189 | print(f"*** End startDNSDist log for {logFile} ***") | |
190 | raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode)) | |
b38aaeb5 | 191 | time.sleep(cls._extraStartupSleep) |
ca404e94 RG |
192 | |
193 | @classmethod | |
194 | def setUpSockets(cls): | |
195 | print("Setting up UDP socket..") | |
196 | cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | |
1ade83b2 | 197 | cls._sock.settimeout(2.0) |
ca404e94 RG |
198 | cls._sock.connect(("127.0.0.1", cls._dnsDistPort)) |
199 | ||
e4284d05 OM |
200 | @classmethod |
201 | def killProcess(cls, p): | |
202 | # Don't try to kill it if it's already dead | |
203 | if p.poll() is not None: | |
204 | return | |
205 | try: | |
206 | p.terminate() | |
6a51a279 | 207 | for count in range(20): |
e4284d05 OM |
208 | x = p.poll() |
209 | if x is not None: | |
210 | break | |
211 | time.sleep(0.1) | |
212 | if x is None: | |
213 | print("kill...", p, file=sys.stderr) | |
214 | p.kill() | |
215 | p.wait() | |
216 | except OSError as e: | |
217 | # There is a race-condition with the poll() and | |
218 | # kill() statements, when the process is dead on the | |
219 | # kill(), this is fine | |
220 | if e.errno != errno.ESRCH: | |
221 | raise | |
222 | ||
ca404e94 RG |
223 | @classmethod |
224 | def setUpClass(cls): | |
225 | ||
226 | cls.startResponders() | |
aac2124b | 227 | cls.startDNSDist() |
ca404e94 RG |
228 | cls.setUpSockets() |
229 | ||
230 | print("Launching tests..") | |
231 | ||
232 | @classmethod | |
233 | def tearDownClass(cls): | |
ffabdc3e OM |
234 | cls._sock.close() |
235 | # tell the background threads to stop, if any | |
236 | for backgroundThread in cls._backgroundThreads: | |
237 | cls._backgroundThreads[backgroundThread] = False | |
e4284d05 | 238 | cls.killProcess(cls._dnsdist) |
7373e3a6 | 239 | |
ca404e94 | 240 | @classmethod |
fe1c60f2 | 241 | def _ResponderIncrementCounter(cls): |
630eb526 RG |
242 | if threading.current_thread().name in cls._responsesCounter: |
243 | cls._responsesCounter[threading.current_thread().name] += 1 | |
ec5f5c6b | 244 | else: |
630eb526 | 245 | cls._responsesCounter[threading.current_thread().name] = 1 |
ec5f5c6b | 246 | |
fe1c60f2 | 247 | @classmethod |
4aa08b62 | 248 | def _getResponse(cls, request, fromQueue, toQueue, synthesize=None): |
fe1c60f2 RG |
249 | response = None |
250 | if len(request.question) != 1: | |
251 | print("Skipping query with question count %d" % (len(request.question))) | |
252 | return None | |
98650fde RG |
253 | healthCheck = str(request.question[0].name).endswith(cls._healthCheckName) |
254 | if healthCheck: | |
255 | cls._healthCheckCounter += 1 | |
4aa08b62 | 256 | response = dns.message.make_response(request) |
98650fde | 257 | else: |
fe1c60f2 | 258 | cls._ResponderIncrementCounter() |
5df86a8a | 259 | if not fromQueue.empty(): |
4aa08b62 | 260 | toQueue.put(request, True, cls._queueTimeout) |
90186270 RG |
261 | response = fromQueue.get(True, cls._queueTimeout) |
262 | if response: | |
263 | response = copy.copy(response) | |
264 | response.id = request.id | |
265 | ||
266 | if synthesize is not None: | |
267 | response = dns.message.make_response(request) | |
268 | response.set_rcode(synthesize) | |
fe1c60f2 | 269 | |
e44df0f1 | 270 | if not response: |
90186270 | 271 | if cls._answerUnexpected: |
e44df0f1 RG |
272 | response = dns.message.make_response(request) |
273 | response.set_rcode(dns.rcode.SERVFAIL) | |
fe1c60f2 RG |
274 | |
275 | return response | |
276 | ||
ec5f5c6b | 277 | @classmethod |
a620f197 | 278 | def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None): |
7373e3a6 | 279 | cls._backgroundThreads[threading.get_native_id()] = True |
3ef7ab0d RG |
280 | # trailingDataResponse=True means "ignore trailing data". |
281 | # Other values are either False (meaning "raise an exception") | |
282 | # or are interpreted as a response RCODE for queries with trailing data. | |
a620f197 | 283 | # callback is invoked for every -even healthcheck ones- query and should return a raw response |
4aa08b62 | 284 | ignoreTrailing = trailingDataResponse is True |
3ef7ab0d | 285 | |
ca404e94 RG |
286 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
287 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | |
ec5f5c6b | 288 | sock.bind(("127.0.0.1", port)) |
28d4c42d | 289 | sock.settimeout(0.5) |
ca404e94 | 290 | while True: |
7373e3a6 RG |
291 | try: |
292 | data, addr = sock.recvfrom(4096) | |
293 | except socket.timeout: | |
294 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
295 | del cls._backgroundThreads[threading.get_native_id()] | |
296 | break | |
297 | else: | |
298 | continue | |
299 | ||
4aa08b62 RG |
300 | forceRcode = None |
301 | try: | |
302 | request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | |
303 | except dns.message.TrailingJunk as e: | |
51f07ad4 | 304 | print('trailing data exception in UDPResponder') |
3ef7ab0d | 305 | if trailingDataResponse is False or forceRcode is True: |
4aa08b62 RG |
306 | raise |
307 | print("UDP query with trailing data, synthesizing response") | |
308 | request = dns.message.from_wire(data, ignore_trailing=True) | |
309 | forceRcode = trailingDataResponse | |
310 | ||
f3913dd2 | 311 | wire = None |
a620f197 RG |
312 | if callback: |
313 | wire = callback(request) | |
314 | else: | |
f8662974 RG |
315 | if request.edns > 1: |
316 | forceRcode = dns.rcode.BADVERS | |
a620f197 RG |
317 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) |
318 | if response: | |
319 | wire = response.to_wire() | |
87c605c4 | 320 | |
f3913dd2 RG |
321 | if not wire: |
322 | continue | |
323 | ||
a620f197 | 324 | sock.sendto(wire, addr) |
7373e3a6 | 325 | |
ca404e94 RG |
326 | sock.close() |
327 | ||
328 | @classmethod | |
e572dbf5 | 329 | def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, partialWrite=False): |
645a1ca4 | 330 | ignoreTrailing = trailingDataResponse is True |
fabe8e3a RG |
331 | try: |
332 | data = conn.recv(2) | |
333 | except Exception as err: | |
334 | data = None | |
335 | print(f'Error while reading query size in TCP responder thread {err=}, {type(err)=}') | |
645a1ca4 RG |
336 | if not data: |
337 | conn.close() | |
338 | return | |
339 | ||
340 | (datalen,) = struct.unpack("!H", data) | |
341 | data = conn.recv(datalen) | |
342 | forceRcode = None | |
343 | try: | |
344 | request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | |
345 | except dns.message.TrailingJunk as e: | |
346 | if trailingDataResponse is False or forceRcode is True: | |
347 | raise | |
348 | print("TCP query with trailing data, synthesizing response") | |
349 | request = dns.message.from_wire(data, ignore_trailing=True) | |
350 | forceRcode = trailingDataResponse | |
351 | ||
352 | if callback: | |
353 | wire = callback(request) | |
354 | else: | |
f8662974 RG |
355 | if request.edns > 1: |
356 | forceRcode = dns.rcode.BADVERS | |
645a1ca4 RG |
357 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) |
358 | if response: | |
359 | wire = response.to_wire(max_size=65535) | |
360 | ||
361 | if not wire: | |
362 | conn.close() | |
363 | return | |
364 | ||
e572dbf5 RG |
365 | wireLen = struct.pack("!H", len(wire)) |
366 | if partialWrite: | |
367 | for b in wireLen: | |
368 | conn.send(bytes([b])) | |
369 | time.sleep(0.5) | |
370 | else: | |
371 | conn.send(wireLen) | |
645a1ca4 RG |
372 | conn.send(wire) |
373 | ||
374 | while multipleResponses: | |
936dd73c RG |
375 | # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done |
376 | # otherwise we might read responses intended for the next connection | |
645a1ca4 RG |
377 | if fromQueue.empty(): |
378 | break | |
379 | ||
936dd73c | 380 | response = fromQueue.get(False) |
645a1ca4 RG |
381 | if not response: |
382 | break | |
383 | ||
384 | response = copy.copy(response) | |
385 | response.id = request.id | |
386 | wire = response.to_wire(max_size=65535) | |
387 | try: | |
388 | conn.send(struct.pack("!H", len(wire))) | |
389 | conn.send(wire) | |
390 | except socket.error as e: | |
391 | # some of the tests are going to close | |
392 | # the connection on us, just deal with it | |
393 | break | |
394 | ||
395 | conn.close() | |
396 | ||
397 | @classmethod | |
e572dbf5 | 398 | def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1', partialWrite=False): |
7373e3a6 | 399 | cls._backgroundThreads[threading.get_native_id()] = True |
3ef7ab0d RG |
400 | # trailingDataResponse=True means "ignore trailing data". |
401 | # Other values are either False (meaning "raise an exception") | |
402 | # or are interpreted as a response RCODE for queries with trailing data. | |
a620f197 | 403 | # callback is invoked for every -even healthcheck ones- query and should return a raw response |
3ef7ab0d | 404 | |
ca404e94 | 405 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
501af9ae | 406 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
ca404e94 RG |
407 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) |
408 | try: | |
144eebeb | 409 | sock.bind((listeningAddr, port)) |
ca404e94 RG |
410 | except socket.error as e: |
411 | print("Error binding in the TCP responder: %s" % str(e)) | |
412 | sys.exit(1) | |
413 | ||
414 | sock.listen(100) | |
28d4c42d | 415 | sock.settimeout(0.5) |
7d2856a6 RG |
416 | if tlsContext: |
417 | sock = tlsContext.wrap_socket(sock, server_side=True) | |
418 | ||
ca404e94 | 419 | while True: |
7d2856a6 RG |
420 | try: |
421 | (conn, _) = sock.accept() | |
422 | except ssl.SSLError: | |
423 | continue | |
ae3b96d9 RG |
424 | except ConnectionResetError: |
425 | continue | |
7373e3a6 RG |
426 | except socket.timeout: |
427 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
428 | del cls._backgroundThreads[threading.get_native_id()] | |
429 | break | |
430 | else: | |
431 | continue | |
ae3b96d9 | 432 | |
6ac8517d | 433 | conn.settimeout(5.0) |
645a1ca4 RG |
434 | if multipleConnections: |
435 | thread = threading.Thread(name='TCP Connection Handler', | |
436 | target=cls.handleTCPConnection, | |
e572dbf5 | 437 | args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite]) |
630eb526 | 438 | thread.daemon = True |
645a1ca4 | 439 | thread.start() |
a620f197 | 440 | else: |
e572dbf5 | 441 | cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, partialWrite) |
548c8b66 | 442 | |
ca404e94 RG |
443 | sock.close() |
444 | ||
c4c72a2c RG |
445 | @classmethod |
446 | def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol): | |
447 | ignoreTrailing = trailingDataResponse is True | |
144eebeb RG |
448 | try: |
449 | h2conn = h2.connection.H2Connection(config=config) | |
450 | h2conn.initiate_connection() | |
451 | conn.sendall(h2conn.data_to_send()) | |
452 | except ssl.SSLEOFError as e: | |
453 | print("Unexpected EOF: %s" % (e)) | |
454 | return | |
fabe8e3a RG |
455 | except Exception as err: |
456 | print(f'Unexpected exception in DoH responder thread (connection init) {err=}, {type(err)=}') | |
457 | return | |
144eebeb | 458 | |
c4c72a2c RG |
459 | dnsData = {} |
460 | ||
461 | if useProxyProtocol: | |
462 | # try to read the entire Proxy Protocol header | |
463 | proxy = ProxyProtocol() | |
464 | header = conn.recv(proxy.HEADER_SIZE) | |
465 | if not header: | |
466 | print('unable to get header') | |
467 | conn.close() | |
468 | return | |
469 | ||
470 | if not proxy.parseHeader(header): | |
471 | print('unable to parse header') | |
472 | print(header) | |
473 | conn.close() | |
474 | return | |
475 | ||
476 | proxyContent = conn.recv(proxy.contentLen) | |
477 | if not proxyContent: | |
478 | print('unable to get content') | |
479 | conn.close() | |
480 | return | |
481 | ||
482 | payload = header + proxyContent | |
483 | toQueue.put(payload, True, cls._queueTimeout) | |
484 | ||
485 | # be careful, HTTP/2 headers and data might be in different recv() results | |
486 | requestHeaders = None | |
487 | while True: | |
fabe8e3a RG |
488 | try: |
489 | data = conn.recv(65535) | |
490 | except Exception as err: | |
491 | data = None | |
492 | print(f'Unexpected exception in DoH responder thread {err=}, {type(err)=}') | |
c4c72a2c RG |
493 | if not data: |
494 | break | |
495 | ||
496 | events = h2conn.receive_data(data) | |
497 | for event in events: | |
498 | if isinstance(event, h2.events.RequestReceived): | |
499 | requestHeaders = event.headers | |
500 | if isinstance(event, h2.events.DataReceived): | |
501 | h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) | |
502 | if not event.stream_id in dnsData: | |
503 | dnsData[event.stream_id] = b'' | |
504 | dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data) | |
505 | if event.stream_ended: | |
506 | forceRcode = None | |
507 | status = 200 | |
508 | try: | |
509 | request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing) | |
510 | except dns.message.TrailingJunk as e: | |
511 | if trailingDataResponse is False or forceRcode is True: | |
512 | raise | |
513 | print("DOH query with trailing data, synthesizing response") | |
514 | request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True) | |
515 | forceRcode = trailingDataResponse | |
516 | ||
517 | if callback: | |
518 | status, wire = callback(request, requestHeaders, fromQueue, toQueue) | |
519 | else: | |
520 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) | |
521 | if response: | |
522 | wire = response.to_wire(max_size=65535) | |
523 | ||
524 | if not wire: | |
525 | conn.close() | |
526 | conn = None | |
527 | break | |
528 | ||
529 | headers = [ | |
530 | (':status', str(status)), | |
531 | ('content-length', str(len(wire))), | |
532 | ('content-type', 'application/dns-message'), | |
533 | ] | |
534 | h2conn.send_headers(stream_id=event.stream_id, headers=headers) | |
535 | h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True) | |
536 | ||
537 | data_to_send = h2conn.data_to_send() | |
538 | if data_to_send: | |
539 | conn.sendall(data_to_send) | |
540 | ||
541 | if conn is None: | |
542 | break | |
543 | ||
544 | if conn is not None: | |
545 | conn.close() | |
546 | ||
9d71a0cf | 547 | @classmethod |
0e6892c6 | 548 | def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False): |
7373e3a6 | 549 | cls._backgroundThreads[threading.get_native_id()] = True |
9d71a0cf RG |
550 | # trailingDataResponse=True means "ignore trailing data". |
551 | # Other values are either False (meaning "raise an exception") | |
552 | # or are interpreted as a response RCODE for queries with trailing data. | |
553 | # callback is invoked for every -even healthcheck ones- query and should return a raw response | |
9d71a0cf RG |
554 | |
555 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
556 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) | |
557 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | |
558 | try: | |
559 | sock.bind(("127.0.0.1", port)) | |
560 | except socket.error as e: | |
561 | print("Error binding in the TCP responder: %s" % str(e)) | |
562 | sys.exit(1) | |
563 | ||
564 | sock.listen(100) | |
28d4c42d | 565 | sock.settimeout(0.5) |
9d71a0cf RG |
566 | if tlsContext: |
567 | sock = tlsContext.wrap_socket(sock, server_side=True) | |
568 | ||
569 | config = h2.config.H2Configuration(client_side=False) | |
570 | ||
571 | while True: | |
572 | try: | |
573 | (conn, _) = sock.accept() | |
574 | except ssl.SSLError: | |
575 | continue | |
576 | except ConnectionResetError: | |
577 | continue | |
7373e3a6 RG |
578 | except socket.timeout: |
579 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
580 | del cls._backgroundThreads[threading.get_native_id()] | |
581 | break | |
582 | else: | |
583 | continue | |
ae3b96d9 | 584 | |
9d71a0cf | 585 | conn.settimeout(5.0) |
c4c72a2c RG |
586 | thread = threading.Thread(name='DoH Connection Handler', |
587 | target=cls.handleDoHConnection, | |
588 | args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol]) | |
630eb526 | 589 | thread.daemon = True |
c4c72a2c | 590 | thread.start() |
9d71a0cf RG |
591 | |
592 | sock.close() | |
593 | ||
ca404e94 | 594 | @classmethod |
55baa1f2 | 595 | def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): |
90186270 | 596 | if useQueue and response is not None: |
617dfe22 | 597 | cls._toResponderQueue.put(response, True, timeout) |
ca404e94 RG |
598 | |
599 | if timeout: | |
600 | cls._sock.settimeout(timeout) | |
601 | ||
602 | try: | |
55baa1f2 RG |
603 | if not rawQuery: |
604 | query = query.to_wire() | |
605 | cls._sock.send(query) | |
ca404e94 | 606 | data = cls._sock.recv(4096) |
b1bec9f0 | 607 | except socket.timeout: |
ca404e94 RG |
608 | data = None |
609 | finally: | |
610 | if timeout: | |
611 | cls._sock.settimeout(None) | |
612 | ||
613 | receivedQuery = None | |
614 | message = None | |
615 | if useQueue and not cls._fromResponderQueue.empty(): | |
617dfe22 | 616 | receivedQuery = cls._fromResponderQueue.get(True, timeout) |
ca404e94 RG |
617 | if data: |
618 | message = dns.message.from_wire(data) | |
619 | return (receivedQuery, message) | |
620 | ||
621 | @classmethod | |
db7acdaf | 622 | def openTCPConnection(cls, timeout=None, port=None): |
ca404e94 | 623 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
501af9ae | 624 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
ca404e94 RG |
625 | if timeout: |
626 | sock.settimeout(timeout) | |
627 | ||
db7acdaf RG |
628 | if not port: |
629 | port = cls._dnsDistPort | |
630 | ||
631 | sock.connect(("127.0.0.1", port)) | |
9396d955 | 632 | return sock |
0a2087eb | 633 | |
9396d955 | 634 | @classmethod |
7e8a05fa | 635 | def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]): |
a227f47d | 636 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
501af9ae | 637 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
a227f47d RG |
638 | if timeout: |
639 | sock.settimeout(timeout) | |
640 | ||
641 | # 2.7.9+ | |
642 | if hasattr(ssl, 'create_default_context'): | |
643 | sslctx = ssl.create_default_context(cafile=caCert) | |
7e8a05fa RG |
644 | if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'): |
645 | sslctx.set_alpn_protocols(alpn) | |
a227f47d RG |
646 | sslsock = sslctx.wrap_socket(sock, server_hostname=serverName) |
647 | else: | |
648 | sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED) | |
649 | ||
650 | sslsock.connect(("127.0.0.1", port)) | |
651 | return sslsock | |
652 | ||
653 | @classmethod | |
654 | def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0): | |
9396d955 RG |
655 | if not rawQuery: |
656 | wire = query.to_wire() | |
657 | else: | |
658 | wire = query | |
55baa1f2 | 659 | |
a227f47d RG |
660 | if response: |
661 | cls._toResponderQueue.put(response, True, timeout) | |
662 | ||
9396d955 RG |
663 | sock.send(struct.pack("!H", len(wire))) |
664 | sock.send(wire) | |
665 | ||
666 | @classmethod | |
a227f47d | 667 | def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0): |
9396d955 RG |
668 | message = None |
669 | data = sock.recv(2) | |
670 | if data: | |
671 | (datalen,) = struct.unpack("!H", data) | |
f05cd66c | 672 | print(datalen) |
9396d955 | 673 | data = sock.recv(datalen) |
ca404e94 | 674 | if data: |
f05cd66c | 675 | print(data) |
9396d955 | 676 | message = dns.message.from_wire(data) |
a227f47d | 677 | |
f05cd66c | 678 | print(useQueue) |
a227f47d RG |
679 | if useQueue and not cls._fromResponderQueue.empty(): |
680 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
f05cd66c | 681 | print(receivedQuery) |
a227f47d RG |
682 | return (receivedQuery, message) |
683 | else: | |
f05cd66c | 684 | print("queue empty") |
a227f47d | 685 | return message |
9396d955 | 686 | |
fda32c1c RG |
687 | @classmethod |
688 | def sendDOTQuery(cls, port, serverName, query, response, caFile, useQueue=True): | |
689 | conn = cls.openTLSConnection(port, serverName, caFile) | |
690 | cls.sendTCPQueryOverConnection(conn, query, response=response) | |
691 | if useQueue: | |
692 | return cls.recvTCPResponseOverConnection(conn, useQueue=useQueue) | |
693 | return None, cls.recvTCPResponseOverConnection(conn, useQueue=useQueue) | |
694 | ||
9396d955 RG |
695 | @classmethod |
696 | def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): | |
697 | message = None | |
698 | if useQueue: | |
699 | cls._toResponderQueue.put(response, True, timeout) | |
700 | ||
9bf515e1 RG |
701 | try: |
702 | sock = cls.openTCPConnection(timeout) | |
703 | except socket.timeout as e: | |
704 | print("Timeout while opening TCP connection: %s" % (str(e))) | |
705 | return (None, None) | |
9396d955 RG |
706 | |
707 | try: | |
708 | cls.sendTCPQueryOverConnection(sock, query, rawQuery) | |
709 | message = cls.recvTCPResponseOverConnection(sock) | |
ca404e94 | 710 | except socket.timeout as e: |
0e6892c6 | 711 | print("Timeout while sending or receiving TCP data: %s" % (str(e))) |
ca404e94 RG |
712 | except socket.error as e: |
713 | print("Network error: %s" % (str(e))) | |
ca404e94 RG |
714 | finally: |
715 | sock.close() | |
716 | ||
717 | receivedQuery = None | |
f05cd66c | 718 | print(useQueue) |
ca404e94 | 719 | if useQueue and not cls._fromResponderQueue.empty(): |
f05cd66c | 720 | print(receivedQuery) |
617dfe22 | 721 | receivedQuery = cls._fromResponderQueue.get(True, timeout) |
f05cd66c RG |
722 | else: |
723 | print("queue is empty") | |
9396d955 | 724 | |
ca404e94 | 725 | return (receivedQuery, message) |
617dfe22 | 726 | |
548c8b66 RG |
727 | @classmethod |
728 | def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False): | |
729 | if useQueue: | |
730 | for response in responses: | |
731 | cls._toResponderQueue.put(response, True, timeout) | |
732 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
501af9ae | 733 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
548c8b66 RG |
734 | if timeout: |
735 | sock.settimeout(timeout) | |
736 | ||
737 | sock.connect(("127.0.0.1", cls._dnsDistPort)) | |
738 | messages = [] | |
739 | ||
740 | try: | |
741 | if not rawQuery: | |
742 | wire = query.to_wire() | |
743 | else: | |
744 | wire = query | |
745 | ||
746 | sock.send(struct.pack("!H", len(wire))) | |
747 | sock.send(wire) | |
748 | while True: | |
749 | data = sock.recv(2) | |
750 | if not data: | |
751 | break | |
752 | (datalen,) = struct.unpack("!H", data) | |
753 | data = sock.recv(datalen) | |
754 | messages.append(dns.message.from_wire(data)) | |
755 | ||
756 | except socket.timeout as e: | |
0e6892c6 | 757 | print("Timeout while receiving multiple TCP responses: %s" % (str(e))) |
548c8b66 RG |
758 | except socket.error as e: |
759 | print("Network error: %s" % (str(e))) | |
760 | finally: | |
761 | sock.close() | |
762 | ||
763 | receivedQuery = None | |
764 | if useQueue and not cls._fromResponderQueue.empty(): | |
765 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
766 | return (receivedQuery, messages) | |
767 | ||
617dfe22 | 768 | def setUp(self): |
936dd73c | 769 | # This function is called before every test |
617dfe22 RG |
770 | |
771 | # Clear the responses counters | |
fd71df4e | 772 | self._responsesCounter.clear() |
617dfe22 | 773 | |
98650fde RG |
774 | self._healthCheckCounter = 0 |
775 | ||
617dfe22 RG |
776 | # Make sure the queues are empty, in case |
777 | # a previous test failed | |
936dd73c | 778 | self.clearResponderQueues() |
1ea747c0 | 779 | |
6bd430bf PD |
780 | super(DNSDistTest, self).setUp() |
781 | ||
3bef39c3 RG |
782 | @classmethod |
783 | def clearToResponderQueue(cls): | |
784 | while not cls._toResponderQueue.empty(): | |
785 | cls._toResponderQueue.get(False) | |
786 | ||
787 | @classmethod | |
788 | def clearFromResponderQueue(cls): | |
789 | while not cls._fromResponderQueue.empty(): | |
790 | cls._fromResponderQueue.get(False) | |
791 | ||
792 | @classmethod | |
793 | def clearResponderQueues(cls): | |
794 | cls.clearToResponderQueue() | |
795 | cls.clearFromResponderQueue() | |
796 | ||
1ea747c0 RG |
797 | @staticmethod |
798 | def generateConsoleKey(): | |
799 | return libnacl.utils.salsa_key() | |
800 | ||
801 | @classmethod | |
802 | def _encryptConsole(cls, command, nonce): | |
b4f23783 | 803 | command = command.encode('UTF-8') |
1ea747c0 RG |
804 | if cls._consoleKey is None: |
805 | return command | |
806 | return libnacl.crypto_secretbox(command, nonce, cls._consoleKey) | |
807 | ||
808 | @classmethod | |
809 | def _decryptConsole(cls, command, nonce): | |
810 | if cls._consoleKey is None: | |
b4f23783 CH |
811 | result = command |
812 | else: | |
813 | result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey) | |
814 | return result.decode('UTF-8') | |
1ea747c0 RG |
815 | |
816 | @classmethod | |
7d862cb3 | 817 | def sendConsoleCommand(cls, command, timeout=5.0): |
1ea747c0 RG |
818 | ourNonce = libnacl.utils.rand_nonce() |
819 | theirNonce = None | |
820 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
501af9ae | 821 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
1ea747c0 RG |
822 | if timeout: |
823 | sock.settimeout(timeout) | |
824 | ||
825 | sock.connect(("127.0.0.1", cls._consolePort)) | |
826 | sock.send(ourNonce) | |
827 | theirNonce = sock.recv(len(ourNonce)) | |
7b925432 | 828 | if len(theirNonce) != len(ourNonce): |
05a5b575 | 829 | print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce))) |
bdfa6902 RG |
830 | if len(theirNonce) == 0: |
831 | raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce))) | |
7b925432 | 832 | return None |
1ea747c0 | 833 | |
b4f23783 | 834 | halfNonceSize = int(len(ourNonce) / 2) |
333ea16e RG |
835 | readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:] |
836 | writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:] | |
333ea16e | 837 | msg = cls._encryptConsole(command, writingNonce) |
1ea747c0 RG |
838 | sock.send(struct.pack("!I", len(msg))) |
839 | sock.send(msg) | |
840 | data = sock.recv(4) | |
9c9b4998 RG |
841 | if not data: |
842 | raise socket.error("Got EOF while reading the response size") | |
843 | ||
1ea747c0 RG |
844 | (responseLen,) = struct.unpack("!I", data) |
845 | data = sock.recv(responseLen) | |
333ea16e | 846 | response = cls._decryptConsole(data, readingNonce) |
75b536de | 847 | sock.close() |
1ea747c0 | 848 | return response |
5df86a8a RG |
849 | |
850 | def compareOptions(self, a, b): | |
4bfebc93 | 851 | self.assertEqual(len(a), len(b)) |
b4f23783 | 852 | for idx in range(len(a)): |
4bfebc93 | 853 | self.assertEqual(a[idx], b[idx]) |
5df86a8a RG |
854 | |
855 | def checkMessageNoEDNS(self, expected, received): | |
4bfebc93 CH |
856 | self.assertEqual(expected, received) |
857 | self.assertEqual(received.edns, -1) | |
858 | self.assertEqual(len(received.options), 0) | |
5df86a8a | 859 | |
e7c732b8 | 860 | def checkMessageEDNSWithoutOptions(self, expected, received): |
4bfebc93 CH |
861 | self.assertEqual(expected, received) |
862 | self.assertEqual(received.edns, 0) | |
863 | self.assertEqual(expected.payload, received.payload) | |
e7c732b8 | 864 | |
5df86a8a | 865 | def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0): |
4bfebc93 CH |
866 | self.assertEqual(expected, received) |
867 | self.assertEqual(received.edns, 0) | |
868 | self.assertEqual(expected.payload, received.payload) | |
869 | self.assertEqual(len(received.options), withCookies) | |
5df86a8a RG |
870 | if withCookies: |
871 | for option in received.options: | |
4bfebc93 | 872 | self.assertEqual(option.otype, 10) |
be90d6bd RG |
873 | else: |
874 | for option in received.options: | |
4bfebc93 | 875 | self.assertNotEqual(option.otype, 10) |
5df86a8a | 876 | |
cbf4e13a | 877 | def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0): |
4bfebc93 CH |
878 | self.assertEqual(expected, received) |
879 | self.assertEqual(received.edns, 0) | |
880 | self.assertEqual(expected.payload, received.payload) | |
881 | self.assertEqual(len(received.options), 1 + additionalOptions) | |
cbf4e13a RG |
882 | hasECS = False |
883 | for option in received.options: | |
884 | if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE: | |
885 | hasECS = True | |
886 | else: | |
4bfebc93 | 887 | self.assertNotEqual(additionalOptions, 0) |
cbf4e13a | 888 | |
5df86a8a | 889 | self.compareOptions(expected.options, received.options) |
cbf4e13a | 890 | self.assertTrue(hasECS) |
5df86a8a | 891 | |
346410cd CHB |
892 | def checkMessageEDNS(self, expected, received): |
893 | self.assertEqual(expected, received) | |
894 | self.assertEqual(received.edns, 0) | |
895 | self.assertEqual(expected.payload, received.payload) | |
896 | self.assertEqual(len(expected.options), len(received.options)) | |
897 | self.compareOptions(expected.options, received.options) | |
898 | ||
cbf4e13a RG |
899 | def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0): |
900 | self.checkMessageEDNSWithECS(expected, received, additionalOptions) | |
5df86a8a | 901 | |
346410cd CHB |
902 | def checkQueryEDNS(self, expected, received): |
903 | self.checkMessageEDNS(expected, received) | |
904 | ||
cbf4e13a RG |
905 | def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0): |
906 | self.checkMessageEDNSWithECS(expected, received, additionalOptions) | |
5df86a8a RG |
907 | |
908 | def checkQueryEDNSWithoutECS(self, expected, received): | |
909 | self.checkMessageEDNSWithoutECS(expected, received) | |
910 | ||
911 | def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0): | |
912 | self.checkMessageEDNSWithoutECS(expected, received, withCookies) | |
913 | ||
914 | def checkQueryNoEDNS(self, expected, received): | |
915 | self.checkMessageNoEDNS(expected, received) | |
916 | ||
917 | def checkResponseNoEDNS(self, expected, received): | |
918 | self.checkMessageNoEDNS(expected, received) | |
b0d08f82 | 919 | |
58ae5410 RG |
920 | @staticmethod |
921 | def generateNewCertificateAndKey(filePrefix): | |
1d896c34 | 922 | # generate and sign a new cert |
58ae5410 | 923 | cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', filePrefix + '.key', '-out', filePrefix + '.csr', '-config', 'configServer.conf'] |
1d896c34 RG |
924 | output = None |
925 | try: | |
926 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True) | |
927 | output = process.communicate(input='') | |
928 | except subprocess.CalledProcessError as exc: | |
929 | raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output)) | |
58ae5410 | 930 | cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', filePrefix + '.csr', '-out', filePrefix + '.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req'] |
1d896c34 RG |
931 | output = None |
932 | try: | |
933 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True) | |
934 | output = process.communicate(input='') | |
935 | except subprocess.CalledProcessError as exc: | |
936 | raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output)) | |
937 | ||
58ae5410 RG |
938 | with open(filePrefix + '.chain', 'w') as outFile: |
939 | for inFileName in [filePrefix + '.pem', 'ca.pem']: | |
1d896c34 RG |
940 | with open(inFileName) as inFile: |
941 | outFile.write(inFile.read()) | |
0e6892c6 | 942 | |
58ae5410 | 943 | cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', filePrefix + '.pem', '-CAfile', 'ca.pem', '-inkey', filePrefix + '.key', '-out', filePrefix + '.p12'] |
5ac11505 CHB |
944 | output = None |
945 | try: | |
946 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True) | |
947 | output = process.communicate(input='') | |
948 | except subprocess.CalledProcessError as exc: | |
949 | raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc.returncode, exc.output)) | |
950 | ||
0e6892c6 RG |
951 | def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None): |
952 | proxy = ProxyProtocol() | |
953 | self.assertTrue(proxy.parseHeader(receivedProxyPayload)) | |
954 | self.assertEqual(proxy.version, 0x02) | |
955 | self.assertEqual(proxy.command, 0x01) | |
956 | if v6: | |
957 | self.assertEqual(proxy.family, 0x02) | |
958 | else: | |
959 | self.assertEqual(proxy.family, 0x01) | |
960 | if not isTCP: | |
961 | self.assertEqual(proxy.protocol, 0x02) | |
962 | else: | |
963 | self.assertEqual(proxy.protocol, 0x01) | |
964 | self.assertGreater(proxy.contentLen, 0) | |
965 | ||
966 | self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload)) | |
967 | self.assertEqual(proxy.source, source) | |
968 | self.assertEqual(proxy.destination, destination) | |
969 | if sourcePort: | |
970 | self.assertEqual(proxy.sourcePort, sourcePort) | |
971 | if destinationPort: | |
972 | self.assertEqual(proxy.destinationPort, destinationPort) | |
973 | else: | |
974 | self.assertEqual(proxy.destinationPort, self._dnsDistPort) | |
975 | ||
976 | self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload)) | |
977 | proxy.values.sort() | |
978 | values.sort() | |
979 | self.assertEqual(proxy.values, values) | |
fda32c1c RG |
980 | |
981 | @classmethod | |
982 | def getDOHGetURL(cls, baseurl, query, rawQuery=False): | |
983 | if rawQuery: | |
984 | wire = query | |
985 | else: | |
986 | wire = query.to_wire() | |
987 | param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') | |
988 | return baseurl + "?dns=" + param | |
989 | ||
990 | @classmethod | |
991 | def openDOHConnection(cls, port, caFile, timeout=2.0): | |
992 | conn = pycurl.Curl() | |
993 | conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2) | |
994 | ||
995 | conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message", | |
996 | "Accept: application/dns-message"]) | |
997 | return conn | |
998 | ||
999 | @classmethod | |
41f36765 | 1000 | def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None): |
fda32c1c | 1001 | url = cls.getDOHGetURL(baseurl, query, rawQuery) |
c02b7e13 RG |
1002 | |
1003 | if not conn: | |
c02b7e13 RG |
1004 | conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) |
1005 | # this means "really do HTTP/2, not HTTP/1 with Upgrade headers" | |
1006 | conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE) | |
1007 | ||
fda32c1c RG |
1008 | if useHTTPS: |
1009 | conn.setopt(pycurl.SSL_VERIFYPEER, 1) | |
1010 | conn.setopt(pycurl.SSL_VERIFYHOST, 2) | |
1011 | if caFile: | |
1012 | conn.setopt(pycurl.CAINFO, caFile) | |
1013 | ||
c02b7e13 RG |
1014 | response_headers = BytesIO() |
1015 | #conn.setopt(pycurl.VERBOSE, True) | |
1016 | conn.setopt(pycurl.URL, url) | |
1017 | conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) | |
1018 | ||
fda32c1c RG |
1019 | conn.setopt(pycurl.HTTPHEADER, customHeaders) |
1020 | conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) | |
1021 | ||
1022 | if response: | |
1023 | if toQueue: | |
1024 | toQueue.put(response, True, timeout) | |
1025 | else: | |
1026 | cls._toResponderQueue.put(response, True, timeout) | |
1027 | ||
1028 | receivedQuery = None | |
1029 | message = None | |
1030 | cls._response_headers = '' | |
1031 | data = conn.perform_rb() | |
1032 | cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) | |
1033 | if cls._rcode == 200 and not rawResponse: | |
1034 | message = dns.message.from_wire(data) | |
1035 | elif rawResponse: | |
1036 | message = data | |
1037 | ||
1038 | if useQueue: | |
1039 | if fromQueue: | |
1040 | if not fromQueue.empty(): | |
1041 | receivedQuery = fromQueue.get(True, timeout) | |
1042 | else: | |
1043 | if not cls._fromResponderQueue.empty(): | |
1044 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
1045 | ||
1046 | cls._response_headers = response_headers.getvalue() | |
1047 | return (receivedQuery, message) | |
1048 | ||
1049 | @classmethod | |
1050 | def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True): | |
1051 | url = baseurl | |
1052 | conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) | |
1053 | response_headers = BytesIO() | |
1054 | #conn.setopt(pycurl.VERBOSE, True) | |
1055 | conn.setopt(pycurl.URL, url) | |
1056 | conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) | |
7e8a05fa RG |
1057 | # this means "really do HTTP/2, not HTTP/1 with Upgrade headers" |
1058 | conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE) | |
fda32c1c RG |
1059 | if useHTTPS: |
1060 | conn.setopt(pycurl.SSL_VERIFYPEER, 1) | |
1061 | conn.setopt(pycurl.SSL_VERIFYHOST, 2) | |
1062 | if caFile: | |
1063 | conn.setopt(pycurl.CAINFO, caFile) | |
1064 | ||
1065 | conn.setopt(pycurl.HTTPHEADER, customHeaders) | |
1066 | conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) | |
1067 | conn.setopt(pycurl.POST, True) | |
1068 | data = query | |
1069 | if not rawQuery: | |
1070 | data = data.to_wire() | |
1071 | ||
1072 | conn.setopt(pycurl.POSTFIELDS, data) | |
1073 | ||
1074 | if response: | |
1075 | cls._toResponderQueue.put(response, True, timeout) | |
1076 | ||
1077 | receivedQuery = None | |
1078 | message = None | |
1079 | cls._response_headers = '' | |
1080 | data = conn.perform_rb() | |
1081 | cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) | |
1082 | if cls._rcode == 200 and not rawResponse: | |
1083 | message = dns.message.from_wire(data) | |
1084 | elif rawResponse: | |
1085 | message = data | |
1086 | ||
1087 | if useQueue and not cls._fromResponderQueue.empty(): | |
1088 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
1089 | ||
1090 | cls._response_headers = response_headers.getvalue() | |
1091 | return (receivedQuery, message) | |
334e3549 | 1092 | |
41f36765 RG |
1093 | def sendDOHQueryWrapper(self, query, response, useQueue=True): |
1094 | return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue) | |
334e3549 | 1095 | |
2f4ac048 RG |
1096 | def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True): |
1097 | return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue) | |
1098 | ||
1099 | def sendDOHWithH2OQueryWrapper(self, query, response, useQueue=True): | |
1100 | return self.sendDOHQuery(self._dohWithH2OServerPort, self._serverName, self._dohWithH2OBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue) | |
1101 | ||
334e3549 RG |
1102 | def sendDOTQueryWrapper(self, query, response, useQueue=True): |
1103 | return self.sendDOTQuery(self._tlsServerPort, self._serverName, query, response, self._caCert, useQueue=useQueue) | |
3dc49a89 | 1104 | |
0a6676a4 RG |
1105 | def sendDOQQueryWrapper(self, query, response, useQueue=True): |
1106 | return self.sendDOQQuery(self._doqServerPort, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName) | |
1107 | ||
655fe34d CHB |
1108 | def sendDOH3QueryWrapper(self, query, response, useQueue=True): |
1109 | return self.sendDOH3Query(self._doh3ServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, serverName=self._serverName) | |
3dc49a89 | 1110 | @classmethod |
0a6676a4 | 1111 | def getDOQConnection(cls, port, caFile=None, source=None, source_port=0): |
3dc49a89 CHB |
1112 | |
1113 | manager = dns.quic.SyncQuicManager( | |
1114 | verify_mode=caFile | |
1115 | ) | |
1116 | ||
0a6676a4 | 1117 | return manager.connect('127.0.0.1', port, source, source_port) |
3dc49a89 CHB |
1118 | |
1119 | @classmethod | |
0a6676a4 | 1120 | def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None): |
3dc49a89 CHB |
1121 | |
1122 | if response: | |
1123 | if toQueue: | |
1124 | toQueue.put(response, True, timeout) | |
1125 | else: | |
1126 | cls._toResponderQueue.put(response, True, timeout) | |
1127 | ||
e7000cce | 1128 | message = quic_query(query, '127.0.0.1', timeout, port, verify=caFile, server_hostname=serverName) |
3dc49a89 CHB |
1129 | |
1130 | receivedQuery = None | |
1131 | ||
1132 | if useQueue: | |
1133 | if fromQueue: | |
1134 | if not fromQueue.empty(): | |
1135 | receivedQuery = fromQueue.get(True, timeout) | |
1136 | else: | |
1137 | if not cls._fromResponderQueue.empty(): | |
1138 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
1139 | ||
1140 | return (receivedQuery, message) | |
4f0b10a9 CHB |
1141 | |
1142 | @classmethod | |
d0439b42 | 1143 | def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False): |
4f0b10a9 CHB |
1144 | |
1145 | if response: | |
1146 | if toQueue: | |
1147 | toQueue.put(response, True, timeout) | |
1148 | else: | |
1149 | cls._toResponderQueue.put(response, True, timeout) | |
1150 | ||
d0439b42 | 1151 | message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post) |
4f0b10a9 CHB |
1152 | |
1153 | receivedQuery = None | |
1154 | ||
1155 | if useQueue: | |
1156 | if fromQueue: | |
1157 | if not fromQueue.empty(): | |
1158 | receivedQuery = fromQueue.get(True, timeout) | |
1159 | else: | |
1160 | if not cls._fromResponderQueue.empty(): | |
1161 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
1162 | ||
1163 | return (receivedQuery, message) |