]>
Commit | Line | Data |
---|---|---|
ca404e94 RG |
1 | #!/usr/bin/env python2 |
2 | ||
95f0b802 | 3 | import copy |
ffabdc3e | 4 | import errno |
ca404e94 RG |
5 | import os |
6 | import socket | |
a227f47d | 7 | import ssl |
ca404e94 RG |
8 | import struct |
9 | import subprocess | |
10 | import sys | |
11 | import threading | |
12 | import time | |
13 | import unittest | |
9d71a0cf | 14 | |
5df86a8a | 15 | import clientsubnetoption |
9d71a0cf | 16 | |
b1bec9f0 RG |
17 | import dns |
18 | import dns.message | |
9d71a0cf | 19 | |
1ea747c0 RG |
20 | import libnacl |
21 | import libnacl.utils | |
ca404e94 | 22 | |
9d71a0cf RG |
23 | import h2.connection |
24 | import h2.events | |
25 | import h2.config | |
26 | ||
6bd430bf | 27 | from eqdnsmessage import AssertEqualDNSMessageMixin |
0e6892c6 | 28 | from proxyprotocol import ProxyProtocol |
6bd430bf | 29 | |
b4f23783 | 30 | # Python2/3 compatibility hacks |
7a0ea291 | 31 | try: |
32 | from queue import Queue | |
33 | except ImportError: | |
b4f23783 | 34 | from Queue import Queue |
7a0ea291 | 35 | |
36 | try: | |
b4f23783 | 37 | range = xrange |
7a0ea291 | 38 | except NameError: |
39 | pass | |
b4f23783 CH |
40 | |
41 | ||
6bd430bf | 42 | class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): |
ca404e94 RG |
43 | """ |
44 | Set up a dnsdist instance and responder threads. | |
45 | Queries sent to dnsdist are relayed to the responder threads, | |
46 | who reply with the response provided by the tests themselves | |
47 | on a queue. Responder threads also queue the queries received | |
48 | from dnsdist on a separate queue, allowing the tests to check | |
49 | that the queries sent from dnsdist were as expected. | |
50 | """ | |
51 | _dnsDistPort = 5340 | |
b052847c | 52 | _dnsDistListeningAddr = "127.0.0.1" |
ca404e94 | 53 | _testServerPort = 5350 |
b4f23783 CH |
54 | _toResponderQueue = Queue() |
55 | _fromResponderQueue = Queue() | |
617dfe22 | 56 | _queueTimeout = 1 |
b1bec9f0 | 57 | _dnsdistStartupDelay = 2.0 |
ca404e94 | 58 | _dnsdist = None |
ec5f5c6b | 59 | _responsesCounter = {} |
18a0e7c6 | 60 | _config_template = """ |
18a0e7c6 CH |
61 | """ |
62 | _config_params = ['_testServerPort'] | |
63 | _acl = ['127.0.0.1/32'] | |
1ea747c0 RG |
64 | _consolePort = 5199 |
65 | _consoleKey = None | |
98650fde RG |
66 | _healthCheckName = 'a.root-servers.net.' |
67 | _healthCheckCounter = 0 | |
e44df0f1 | 68 | _answerUnexpected = True |
f73ce0e3 | 69 | _checkConfigExpectedOutput = None |
2a3cafcd | 70 | _verboseMode = False |
db7acdaf | 71 | _skipListeningOnCL = False |
7373e3a6 RG |
72 | _backgroundThreads = {} |
73 | _UDPResponder = None | |
74 | _TCPResponder = None | |
ca404e94 | 75 | |
ffabdc3e OM |
76 | @classmethod |
77 | def waitForTCPSocket(cls, ipaddress, port): | |
78 | for try_number in range(0, 20): | |
79 | try: | |
80 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
81 | sock.settimeout(1.0) | |
82 | sock.connect((ipaddress, port)) | |
83 | sock.close() | |
84 | return | |
85 | except Exception as err: | |
86 | if err.errno != errno.ECONNREFUSED: | |
87 | print(f'Error occurred: {try_number} {err}', file=sys.stderr) | |
88 | time.sleep(0.1) | |
89 | # We assume the dnsdist instance does not listen. That's fine. | |
90 | ||
ca404e94 RG |
91 | @classmethod |
92 | def startResponders(cls): | |
93 | print("Launching responders..") | |
ec5f5c6b | 94 | |
5df86a8a | 95 | cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]) |
ca404e94 RG |
96 | cls._UDPResponder.setDaemon(True) |
97 | cls._UDPResponder.start() | |
5df86a8a | 98 | cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]) |
ca404e94 RG |
99 | cls._TCPResponder.setDaemon(True) |
100 | cls._TCPResponder.start() | |
101 | ||
102 | @classmethod | |
aac2124b | 103 | def startDNSDist(cls): |
ca404e94 | 104 | print("Launching dnsdist..") |
aac2124b | 105 | confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__)) |
18a0e7c6 CH |
106 | params = tuple([getattr(cls, param) for param in cls._config_params]) |
107 | print(params) | |
aac2124b | 108 | with open(confFile, 'w') as conf: |
18a0e7c6 CH |
109 | conf.write("-- Autogenerated by dnsdisttests.py\n") |
110 | conf.write(cls._config_template % params) | |
f3853a40 | 111 | conf.write("setSecurityPollSuffix('')") |
18a0e7c6 | 112 | |
db7acdaf RG |
113 | if cls._skipListeningOnCL: |
114 | dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile ] | |
115 | else: | |
116 | dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile, | |
117 | '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort) ] | |
118 | ||
2a3cafcd RG |
119 | if cls._verboseMode: |
120 | dnsdistcmd.append('-v') | |
121 | ||
18a0e7c6 CH |
122 | for acl in cls._acl: |
123 | dnsdistcmd.extend(['--acl', acl]) | |
124 | print(' '.join(dnsdistcmd)) | |
125 | ||
6b44773a CH |
126 | # validate config with --check-config, which sets client=true, possibly exposing bugs. |
127 | testcmd = dnsdistcmd + ['--check-config'] | |
ff0bc6a6 JS |
128 | try: |
129 | output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True) | |
130 | except subprocess.CalledProcessError as exc: | |
131 | raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output)) | |
f73ce0e3 RG |
132 | if cls._checkConfigExpectedOutput is not None: |
133 | expectedOutput = cls._checkConfigExpectedOutput | |
134 | else: | |
135 | expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode() | |
2a3cafcd | 136 | if not cls._verboseMode and output != expectedOutput: |
6b44773a CH |
137 | raise AssertionError('dnsdist --check-config failed: %s' % output) |
138 | ||
aac2124b RG |
139 | logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__)) |
140 | with open(logFile, 'w') as fdLog: | |
141 | cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog) | |
ca404e94 | 142 | |
ffabdc3e | 143 | cls.waitForTCPSocket(cls._dnsDistListeningAddr, cls._dnsDistPort); |
ca404e94 RG |
144 | |
145 | if cls._dnsdist.poll() is not None: | |
ffabdc3e OM |
146 | print(f"\n*** startDNSDist log for {logFile} ***") |
147 | with open(logFile, 'r') as fdLog: | |
148 | print(fdLog.read()) | |
149 | print(f"*** End startDNSDist log for {logFile} ***") | |
150 | raise AssertionError('%s failed (%d)' % (dnsdistcmd, cls._dnsdist.returncode)) | |
ca404e94 RG |
151 | |
152 | @classmethod | |
153 | def setUpSockets(cls): | |
154 | print("Setting up UDP socket..") | |
155 | cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | |
1ade83b2 | 156 | cls._sock.settimeout(2.0) |
ca404e94 RG |
157 | cls._sock.connect(("127.0.0.1", cls._dnsDistPort)) |
158 | ||
e4284d05 OM |
159 | @classmethod |
160 | def killProcess(cls, p): | |
161 | # Don't try to kill it if it's already dead | |
162 | if p.poll() is not None: | |
163 | return | |
164 | try: | |
165 | p.terminate() | |
166 | for count in range(10): | |
167 | x = p.poll() | |
168 | if x is not None: | |
169 | break | |
170 | time.sleep(0.1) | |
171 | if x is None: | |
172 | print("kill...", p, file=sys.stderr) | |
173 | p.kill() | |
174 | p.wait() | |
175 | except OSError as e: | |
176 | # There is a race-condition with the poll() and | |
177 | # kill() statements, when the process is dead on the | |
178 | # kill(), this is fine | |
179 | if e.errno != errno.ESRCH: | |
180 | raise | |
181 | ||
ca404e94 RG |
182 | @classmethod |
183 | def setUpClass(cls): | |
184 | ||
185 | cls.startResponders() | |
aac2124b | 186 | cls.startDNSDist() |
ca404e94 RG |
187 | cls.setUpSockets() |
188 | ||
189 | print("Launching tests..") | |
190 | ||
191 | @classmethod | |
192 | def tearDownClass(cls): | |
ffabdc3e OM |
193 | cls._sock.close() |
194 | # tell the background threads to stop, if any | |
195 | for backgroundThread in cls._backgroundThreads: | |
196 | cls._backgroundThreads[backgroundThread] = False | |
e4284d05 | 197 | cls.killProcess(cls._dnsdist) |
7373e3a6 | 198 | |
ca404e94 | 199 | @classmethod |
fe1c60f2 | 200 | def _ResponderIncrementCounter(cls): |
ec5f5c6b RG |
201 | if threading.currentThread().name in cls._responsesCounter: |
202 | cls._responsesCounter[threading.currentThread().name] += 1 | |
203 | else: | |
204 | cls._responsesCounter[threading.currentThread().name] = 1 | |
205 | ||
fe1c60f2 | 206 | @classmethod |
4aa08b62 | 207 | def _getResponse(cls, request, fromQueue, toQueue, synthesize=None): |
fe1c60f2 RG |
208 | response = None |
209 | if len(request.question) != 1: | |
210 | print("Skipping query with question count %d" % (len(request.question))) | |
211 | return None | |
98650fde RG |
212 | healthCheck = str(request.question[0].name).endswith(cls._healthCheckName) |
213 | if healthCheck: | |
214 | cls._healthCheckCounter += 1 | |
4aa08b62 | 215 | response = dns.message.make_response(request) |
98650fde | 216 | else: |
fe1c60f2 | 217 | cls._ResponderIncrementCounter() |
5df86a8a | 218 | if not fromQueue.empty(): |
4aa08b62 | 219 | toQueue.put(request, True, cls._queueTimeout) |
90186270 RG |
220 | response = fromQueue.get(True, cls._queueTimeout) |
221 | if response: | |
222 | response = copy.copy(response) | |
223 | response.id = request.id | |
224 | ||
225 | if synthesize is not None: | |
226 | response = dns.message.make_response(request) | |
227 | response.set_rcode(synthesize) | |
fe1c60f2 | 228 | |
e44df0f1 | 229 | if not response: |
90186270 | 230 | if cls._answerUnexpected: |
e44df0f1 RG |
231 | response = dns.message.make_response(request) |
232 | response.set_rcode(dns.rcode.SERVFAIL) | |
fe1c60f2 RG |
233 | |
234 | return response | |
235 | ||
ec5f5c6b | 236 | @classmethod |
a620f197 | 237 | def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None): |
7373e3a6 | 238 | cls._backgroundThreads[threading.get_native_id()] = True |
3ef7ab0d RG |
239 | # trailingDataResponse=True means "ignore trailing data". |
240 | # Other values are either False (meaning "raise an exception") | |
241 | # or are interpreted as a response RCODE for queries with trailing data. | |
a620f197 | 242 | # callback is invoked for every -even healthcheck ones- query and should return a raw response |
4aa08b62 | 243 | ignoreTrailing = trailingDataResponse is True |
3ef7ab0d | 244 | |
ca404e94 RG |
245 | sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
246 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | |
ec5f5c6b | 247 | sock.bind(("127.0.0.1", port)) |
7373e3a6 | 248 | sock.settimeout(1.0) |
ca404e94 | 249 | while True: |
7373e3a6 RG |
250 | try: |
251 | data, addr = sock.recvfrom(4096) | |
252 | except socket.timeout: | |
253 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
254 | del cls._backgroundThreads[threading.get_native_id()] | |
255 | break | |
256 | else: | |
257 | continue | |
258 | ||
4aa08b62 RG |
259 | forceRcode = None |
260 | try: | |
261 | request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | |
262 | except dns.message.TrailingJunk as e: | |
51f07ad4 | 263 | print('trailing data exception in UDPResponder') |
3ef7ab0d | 264 | if trailingDataResponse is False or forceRcode is True: |
4aa08b62 RG |
265 | raise |
266 | print("UDP query with trailing data, synthesizing response") | |
267 | request = dns.message.from_wire(data, ignore_trailing=True) | |
268 | forceRcode = trailingDataResponse | |
269 | ||
f3913dd2 | 270 | wire = None |
a620f197 RG |
271 | if callback: |
272 | wire = callback(request) | |
273 | else: | |
f8662974 RG |
274 | if request.edns > 1: |
275 | forceRcode = dns.rcode.BADVERS | |
a620f197 RG |
276 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) |
277 | if response: | |
278 | wire = response.to_wire() | |
87c605c4 | 279 | |
f3913dd2 RG |
280 | if not wire: |
281 | continue | |
282 | ||
a620f197 | 283 | sock.sendto(wire, addr) |
7373e3a6 | 284 | |
ca404e94 RG |
285 | sock.close() |
286 | ||
287 | @classmethod | |
645a1ca4 RG |
288 | def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None): |
289 | ignoreTrailing = trailingDataResponse is True | |
290 | data = conn.recv(2) | |
291 | if not data: | |
292 | conn.close() | |
293 | return | |
294 | ||
295 | (datalen,) = struct.unpack("!H", data) | |
296 | data = conn.recv(datalen) | |
297 | forceRcode = None | |
298 | try: | |
299 | request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) | |
300 | except dns.message.TrailingJunk as e: | |
301 | if trailingDataResponse is False or forceRcode is True: | |
302 | raise | |
303 | print("TCP query with trailing data, synthesizing response") | |
304 | request = dns.message.from_wire(data, ignore_trailing=True) | |
305 | forceRcode = trailingDataResponse | |
306 | ||
307 | if callback: | |
308 | wire = callback(request) | |
309 | else: | |
f8662974 RG |
310 | if request.edns > 1: |
311 | forceRcode = dns.rcode.BADVERS | |
645a1ca4 RG |
312 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) |
313 | if response: | |
314 | wire = response.to_wire(max_size=65535) | |
315 | ||
316 | if not wire: | |
317 | conn.close() | |
318 | return | |
319 | ||
320 | conn.send(struct.pack("!H", len(wire))) | |
321 | conn.send(wire) | |
322 | ||
323 | while multipleResponses: | |
936dd73c RG |
324 | # do not block, and stop as soon as the queue is empty, either the next response is already here or we are done |
325 | # otherwise we might read responses intended for the next connection | |
645a1ca4 RG |
326 | if fromQueue.empty(): |
327 | break | |
328 | ||
936dd73c | 329 | response = fromQueue.get(False) |
645a1ca4 RG |
330 | if not response: |
331 | break | |
332 | ||
333 | response = copy.copy(response) | |
334 | response.id = request.id | |
335 | wire = response.to_wire(max_size=65535) | |
336 | try: | |
337 | conn.send(struct.pack("!H", len(wire))) | |
338 | conn.send(wire) | |
339 | except socket.error as e: | |
340 | # some of the tests are going to close | |
341 | # the connection on us, just deal with it | |
342 | break | |
343 | ||
344 | conn.close() | |
345 | ||
346 | @classmethod | |
144eebeb | 347 | def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1'): |
7373e3a6 | 348 | cls._backgroundThreads[threading.get_native_id()] = True |
3ef7ab0d RG |
349 | # trailingDataResponse=True means "ignore trailing data". |
350 | # Other values are either False (meaning "raise an exception") | |
351 | # or are interpreted as a response RCODE for queries with trailing data. | |
a620f197 | 352 | # callback is invoked for every -even healthcheck ones- query and should return a raw response |
3ef7ab0d | 353 | |
ca404e94 | 354 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
501af9ae | 355 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
ca404e94 RG |
356 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) |
357 | try: | |
144eebeb | 358 | sock.bind((listeningAddr, port)) |
ca404e94 RG |
359 | except socket.error as e: |
360 | print("Error binding in the TCP responder: %s" % str(e)) | |
361 | sys.exit(1) | |
362 | ||
363 | sock.listen(100) | |
7373e3a6 | 364 | sock.settimeout(1.0) |
7d2856a6 RG |
365 | if tlsContext: |
366 | sock = tlsContext.wrap_socket(sock, server_side=True) | |
367 | ||
ca404e94 | 368 | while True: |
7d2856a6 RG |
369 | try: |
370 | (conn, _) = sock.accept() | |
371 | except ssl.SSLError: | |
372 | continue | |
ae3b96d9 RG |
373 | except ConnectionResetError: |
374 | continue | |
7373e3a6 RG |
375 | except socket.timeout: |
376 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
377 | del cls._backgroundThreads[threading.get_native_id()] | |
378 | break | |
379 | else: | |
380 | continue | |
ae3b96d9 | 381 | |
6ac8517d | 382 | conn.settimeout(5.0) |
645a1ca4 RG |
383 | if multipleConnections: |
384 | thread = threading.Thread(name='TCP Connection Handler', | |
385 | target=cls.handleTCPConnection, | |
386 | args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback]) | |
387 | thread.setDaemon(True) | |
388 | thread.start() | |
a620f197 | 389 | else: |
645a1ca4 | 390 | cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback) |
548c8b66 | 391 | |
ca404e94 RG |
392 | sock.close() |
393 | ||
c4c72a2c RG |
394 | @classmethod |
395 | def handleDoHConnection(cls, config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol): | |
396 | ignoreTrailing = trailingDataResponse is True | |
144eebeb RG |
397 | try: |
398 | h2conn = h2.connection.H2Connection(config=config) | |
399 | h2conn.initiate_connection() | |
400 | conn.sendall(h2conn.data_to_send()) | |
401 | except ssl.SSLEOFError as e: | |
402 | print("Unexpected EOF: %s" % (e)) | |
403 | return | |
404 | ||
c4c72a2c RG |
405 | dnsData = {} |
406 | ||
407 | if useProxyProtocol: | |
408 | # try to read the entire Proxy Protocol header | |
409 | proxy = ProxyProtocol() | |
410 | header = conn.recv(proxy.HEADER_SIZE) | |
411 | if not header: | |
412 | print('unable to get header') | |
413 | conn.close() | |
414 | return | |
415 | ||
416 | if not proxy.parseHeader(header): | |
417 | print('unable to parse header') | |
418 | print(header) | |
419 | conn.close() | |
420 | return | |
421 | ||
422 | proxyContent = conn.recv(proxy.contentLen) | |
423 | if not proxyContent: | |
424 | print('unable to get content') | |
425 | conn.close() | |
426 | return | |
427 | ||
428 | payload = header + proxyContent | |
429 | toQueue.put(payload, True, cls._queueTimeout) | |
430 | ||
431 | # be careful, HTTP/2 headers and data might be in different recv() results | |
432 | requestHeaders = None | |
433 | while True: | |
434 | data = conn.recv(65535) | |
435 | if not data: | |
436 | break | |
437 | ||
438 | events = h2conn.receive_data(data) | |
439 | for event in events: | |
440 | if isinstance(event, h2.events.RequestReceived): | |
441 | requestHeaders = event.headers | |
442 | if isinstance(event, h2.events.DataReceived): | |
443 | h2conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) | |
444 | if not event.stream_id in dnsData: | |
445 | dnsData[event.stream_id] = b'' | |
446 | dnsData[event.stream_id] = dnsData[event.stream_id] + (event.data) | |
447 | if event.stream_ended: | |
448 | forceRcode = None | |
449 | status = 200 | |
450 | try: | |
451 | request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=ignoreTrailing) | |
452 | except dns.message.TrailingJunk as e: | |
453 | if trailingDataResponse is False or forceRcode is True: | |
454 | raise | |
455 | print("DOH query with trailing data, synthesizing response") | |
456 | request = dns.message.from_wire(dnsData[event.stream_id], ignore_trailing=True) | |
457 | forceRcode = trailingDataResponse | |
458 | ||
459 | if callback: | |
460 | status, wire = callback(request, requestHeaders, fromQueue, toQueue) | |
461 | else: | |
462 | response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) | |
463 | if response: | |
464 | wire = response.to_wire(max_size=65535) | |
465 | ||
466 | if not wire: | |
467 | conn.close() | |
468 | conn = None | |
469 | break | |
470 | ||
471 | headers = [ | |
472 | (':status', str(status)), | |
473 | ('content-length', str(len(wire))), | |
474 | ('content-type', 'application/dns-message'), | |
475 | ] | |
476 | h2conn.send_headers(stream_id=event.stream_id, headers=headers) | |
477 | h2conn.send_data(stream_id=event.stream_id, data=wire, end_stream=True) | |
478 | ||
479 | data_to_send = h2conn.data_to_send() | |
480 | if data_to_send: | |
481 | conn.sendall(data_to_send) | |
482 | ||
483 | if conn is None: | |
484 | break | |
485 | ||
486 | if conn is not None: | |
487 | conn.close() | |
488 | ||
9d71a0cf | 489 | @classmethod |
0e6892c6 | 490 | def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False): |
7373e3a6 | 491 | cls._backgroundThreads[threading.get_native_id()] = True |
9d71a0cf RG |
492 | # trailingDataResponse=True means "ignore trailing data". |
493 | # Other values are either False (meaning "raise an exception") | |
494 | # or are interpreted as a response RCODE for queries with trailing data. | |
495 | # callback is invoked for every -even healthcheck ones- query and should return a raw response | |
9d71a0cf RG |
496 | |
497 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
498 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) | |
499 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) | |
500 | try: | |
501 | sock.bind(("127.0.0.1", port)) | |
502 | except socket.error as e: | |
503 | print("Error binding in the TCP responder: %s" % str(e)) | |
504 | sys.exit(1) | |
505 | ||
506 | sock.listen(100) | |
7373e3a6 | 507 | sock.settimeout(1.0) |
9d71a0cf RG |
508 | if tlsContext: |
509 | sock = tlsContext.wrap_socket(sock, server_side=True) | |
510 | ||
511 | config = h2.config.H2Configuration(client_side=False) | |
512 | ||
513 | while True: | |
514 | try: | |
515 | (conn, _) = sock.accept() | |
516 | except ssl.SSLError: | |
517 | continue | |
518 | except ConnectionResetError: | |
519 | continue | |
7373e3a6 RG |
520 | except socket.timeout: |
521 | if cls._backgroundThreads.get(threading.get_native_id(), False) == False: | |
522 | del cls._backgroundThreads[threading.get_native_id()] | |
523 | break | |
524 | else: | |
525 | continue | |
ae3b96d9 | 526 | |
9d71a0cf | 527 | conn.settimeout(5.0) |
c4c72a2c RG |
528 | thread = threading.Thread(name='DoH Connection Handler', |
529 | target=cls.handleDoHConnection, | |
530 | args=[config, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback, tlsContext, useProxyProtocol]) | |
531 | thread.setDaemon(True) | |
532 | thread.start() | |
9d71a0cf RG |
533 | |
534 | sock.close() | |
535 | ||
ca404e94 | 536 | @classmethod |
55baa1f2 | 537 | def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): |
90186270 | 538 | if useQueue and response is not None: |
617dfe22 | 539 | cls._toResponderQueue.put(response, True, timeout) |
ca404e94 RG |
540 | |
541 | if timeout: | |
542 | cls._sock.settimeout(timeout) | |
543 | ||
544 | try: | |
55baa1f2 RG |
545 | if not rawQuery: |
546 | query = query.to_wire() | |
547 | cls._sock.send(query) | |
ca404e94 | 548 | data = cls._sock.recv(4096) |
b1bec9f0 | 549 | except socket.timeout: |
ca404e94 RG |
550 | data = None |
551 | finally: | |
552 | if timeout: | |
553 | cls._sock.settimeout(None) | |
554 | ||
555 | receivedQuery = None | |
556 | message = None | |
557 | if useQueue and not cls._fromResponderQueue.empty(): | |
617dfe22 | 558 | receivedQuery = cls._fromResponderQueue.get(True, timeout) |
ca404e94 RG |
559 | if data: |
560 | message = dns.message.from_wire(data) | |
561 | return (receivedQuery, message) | |
562 | ||
563 | @classmethod | |
db7acdaf | 564 | def openTCPConnection(cls, timeout=None, port=None): |
ca404e94 | 565 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
501af9ae | 566 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
ca404e94 RG |
567 | if timeout: |
568 | sock.settimeout(timeout) | |
569 | ||
db7acdaf RG |
570 | if not port: |
571 | port = cls._dnsDistPort | |
572 | ||
573 | sock.connect(("127.0.0.1", port)) | |
9396d955 | 574 | return sock |
0a2087eb | 575 | |
9396d955 | 576 | @classmethod |
a227f47d RG |
577 | def openTLSConnection(cls, port, serverName, caCert=None, timeout=None): |
578 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
501af9ae | 579 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
a227f47d RG |
580 | if timeout: |
581 | sock.settimeout(timeout) | |
582 | ||
583 | # 2.7.9+ | |
584 | if hasattr(ssl, 'create_default_context'): | |
585 | sslctx = ssl.create_default_context(cafile=caCert) | |
586 | sslsock = sslctx.wrap_socket(sock, server_hostname=serverName) | |
587 | else: | |
588 | sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED) | |
589 | ||
590 | sslsock.connect(("127.0.0.1", port)) | |
591 | return sslsock | |
592 | ||
593 | @classmethod | |
594 | def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0): | |
9396d955 RG |
595 | if not rawQuery: |
596 | wire = query.to_wire() | |
597 | else: | |
598 | wire = query | |
55baa1f2 | 599 | |
a227f47d RG |
600 | if response: |
601 | cls._toResponderQueue.put(response, True, timeout) | |
602 | ||
9396d955 RG |
603 | sock.send(struct.pack("!H", len(wire))) |
604 | sock.send(wire) | |
605 | ||
606 | @classmethod | |
a227f47d | 607 | def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0): |
f05cd66c | 608 | print("reading data") |
9396d955 RG |
609 | message = None |
610 | data = sock.recv(2) | |
611 | if data: | |
612 | (datalen,) = struct.unpack("!H", data) | |
f05cd66c | 613 | print(datalen) |
9396d955 | 614 | data = sock.recv(datalen) |
ca404e94 | 615 | if data: |
f05cd66c | 616 | print(data) |
9396d955 | 617 | message = dns.message.from_wire(data) |
a227f47d | 618 | |
f05cd66c | 619 | print(useQueue) |
a227f47d RG |
620 | if useQueue and not cls._fromResponderQueue.empty(): |
621 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
f05cd66c RG |
622 | print("Got from queue") |
623 | print(receivedQuery) | |
a227f47d RG |
624 | return (receivedQuery, message) |
625 | else: | |
f05cd66c | 626 | print("queue empty") |
a227f47d | 627 | return message |
9396d955 RG |
628 | |
629 | @classmethod | |
630 | def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False): | |
631 | message = None | |
632 | if useQueue: | |
633 | cls._toResponderQueue.put(response, True, timeout) | |
634 | ||
635 | sock = cls.openTCPConnection(timeout) | |
636 | ||
637 | try: | |
638 | cls.sendTCPQueryOverConnection(sock, query, rawQuery) | |
639 | message = cls.recvTCPResponseOverConnection(sock) | |
ca404e94 | 640 | except socket.timeout as e: |
0e6892c6 | 641 | print("Timeout while sending or receiving TCP data: %s" % (str(e))) |
ca404e94 RG |
642 | except socket.error as e: |
643 | print("Network error: %s" % (str(e))) | |
ca404e94 RG |
644 | finally: |
645 | sock.close() | |
646 | ||
647 | receivedQuery = None | |
f05cd66c | 648 | print(useQueue) |
ca404e94 | 649 | if useQueue and not cls._fromResponderQueue.empty(): |
f05cd66c RG |
650 | print("Got from queue") |
651 | print(receivedQuery) | |
617dfe22 | 652 | receivedQuery = cls._fromResponderQueue.get(True, timeout) |
f05cd66c RG |
653 | else: |
654 | print("queue is empty") | |
9396d955 | 655 | |
ca404e94 | 656 | return (receivedQuery, message) |
617dfe22 | 657 | |
548c8b66 RG |
658 | @classmethod |
659 | def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False): | |
660 | if useQueue: | |
661 | for response in responses: | |
662 | cls._toResponderQueue.put(response, True, timeout) | |
663 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
501af9ae | 664 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
548c8b66 RG |
665 | if timeout: |
666 | sock.settimeout(timeout) | |
667 | ||
668 | sock.connect(("127.0.0.1", cls._dnsDistPort)) | |
669 | messages = [] | |
670 | ||
671 | try: | |
672 | if not rawQuery: | |
673 | wire = query.to_wire() | |
674 | else: | |
675 | wire = query | |
676 | ||
677 | sock.send(struct.pack("!H", len(wire))) | |
678 | sock.send(wire) | |
679 | while True: | |
680 | data = sock.recv(2) | |
681 | if not data: | |
682 | break | |
683 | (datalen,) = struct.unpack("!H", data) | |
684 | data = sock.recv(datalen) | |
685 | messages.append(dns.message.from_wire(data)) | |
686 | ||
687 | except socket.timeout as e: | |
0e6892c6 | 688 | print("Timeout while receiving multiple TCP responses: %s" % (str(e))) |
548c8b66 RG |
689 | except socket.error as e: |
690 | print("Network error: %s" % (str(e))) | |
691 | finally: | |
692 | sock.close() | |
693 | ||
694 | receivedQuery = None | |
695 | if useQueue and not cls._fromResponderQueue.empty(): | |
696 | receivedQuery = cls._fromResponderQueue.get(True, timeout) | |
697 | return (receivedQuery, messages) | |
698 | ||
617dfe22 | 699 | def setUp(self): |
936dd73c | 700 | # This function is called before every test |
617dfe22 RG |
701 | |
702 | # Clear the responses counters | |
fd71df4e | 703 | self._responsesCounter.clear() |
617dfe22 | 704 | |
98650fde RG |
705 | self._healthCheckCounter = 0 |
706 | ||
617dfe22 RG |
707 | # Make sure the queues are empty, in case |
708 | # a previous test failed | |
936dd73c | 709 | self.clearResponderQueues() |
1ea747c0 | 710 | |
6bd430bf PD |
711 | super(DNSDistTest, self).setUp() |
712 | ||
3bef39c3 RG |
713 | @classmethod |
714 | def clearToResponderQueue(cls): | |
715 | while not cls._toResponderQueue.empty(): | |
716 | cls._toResponderQueue.get(False) | |
717 | ||
718 | @classmethod | |
719 | def clearFromResponderQueue(cls): | |
720 | while not cls._fromResponderQueue.empty(): | |
721 | cls._fromResponderQueue.get(False) | |
722 | ||
723 | @classmethod | |
724 | def clearResponderQueues(cls): | |
725 | cls.clearToResponderQueue() | |
726 | cls.clearFromResponderQueue() | |
727 | ||
1ea747c0 RG |
728 | @staticmethod |
729 | def generateConsoleKey(): | |
730 | return libnacl.utils.salsa_key() | |
731 | ||
732 | @classmethod | |
733 | def _encryptConsole(cls, command, nonce): | |
b4f23783 | 734 | command = command.encode('UTF-8') |
1ea747c0 RG |
735 | if cls._consoleKey is None: |
736 | return command | |
737 | return libnacl.crypto_secretbox(command, nonce, cls._consoleKey) | |
738 | ||
739 | @classmethod | |
740 | def _decryptConsole(cls, command, nonce): | |
741 | if cls._consoleKey is None: | |
b4f23783 CH |
742 | result = command |
743 | else: | |
744 | result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey) | |
745 | return result.decode('UTF-8') | |
1ea747c0 RG |
746 | |
747 | @classmethod | |
748 | def sendConsoleCommand(cls, command, timeout=1.0): | |
749 | ourNonce = libnacl.utils.rand_nonce() | |
750 | theirNonce = None | |
751 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
501af9ae | 752 | sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) |
1ea747c0 RG |
753 | if timeout: |
754 | sock.settimeout(timeout) | |
755 | ||
756 | sock.connect(("127.0.0.1", cls._consolePort)) | |
757 | sock.send(ourNonce) | |
758 | theirNonce = sock.recv(len(ourNonce)) | |
7b925432 | 759 | if len(theirNonce) != len(ourNonce): |
05a5b575 | 760 | print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce))) |
bdfa6902 RG |
761 | if len(theirNonce) == 0: |
762 | raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce))) | |
7b925432 | 763 | return None |
1ea747c0 | 764 | |
b4f23783 | 765 | halfNonceSize = int(len(ourNonce) / 2) |
333ea16e RG |
766 | readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:] |
767 | writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:] | |
333ea16e | 768 | msg = cls._encryptConsole(command, writingNonce) |
1ea747c0 RG |
769 | sock.send(struct.pack("!I", len(msg))) |
770 | sock.send(msg) | |
771 | data = sock.recv(4) | |
9c9b4998 RG |
772 | if not data: |
773 | raise socket.error("Got EOF while reading the response size") | |
774 | ||
1ea747c0 RG |
775 | (responseLen,) = struct.unpack("!I", data) |
776 | data = sock.recv(responseLen) | |
333ea16e | 777 | response = cls._decryptConsole(data, readingNonce) |
75b536de | 778 | sock.close() |
1ea747c0 | 779 | return response |
5df86a8a RG |
780 | |
781 | def compareOptions(self, a, b): | |
4bfebc93 | 782 | self.assertEqual(len(a), len(b)) |
b4f23783 | 783 | for idx in range(len(a)): |
4bfebc93 | 784 | self.assertEqual(a[idx], b[idx]) |
5df86a8a RG |
785 | |
786 | def checkMessageNoEDNS(self, expected, received): | |
4bfebc93 CH |
787 | self.assertEqual(expected, received) |
788 | self.assertEqual(received.edns, -1) | |
789 | self.assertEqual(len(received.options), 0) | |
5df86a8a | 790 | |
e7c732b8 | 791 | def checkMessageEDNSWithoutOptions(self, expected, received): |
4bfebc93 CH |
792 | self.assertEqual(expected, received) |
793 | self.assertEqual(received.edns, 0) | |
794 | self.assertEqual(expected.payload, received.payload) | |
e7c732b8 | 795 | |
5df86a8a | 796 | def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0): |
4bfebc93 CH |
797 | self.assertEqual(expected, received) |
798 | self.assertEqual(received.edns, 0) | |
799 | self.assertEqual(expected.payload, received.payload) | |
800 | self.assertEqual(len(received.options), withCookies) | |
5df86a8a RG |
801 | if withCookies: |
802 | for option in received.options: | |
4bfebc93 | 803 | self.assertEqual(option.otype, 10) |
be90d6bd RG |
804 | else: |
805 | for option in received.options: | |
4bfebc93 | 806 | self.assertNotEqual(option.otype, 10) |
5df86a8a | 807 | |
cbf4e13a | 808 | def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0): |
4bfebc93 CH |
809 | self.assertEqual(expected, received) |
810 | self.assertEqual(received.edns, 0) | |
811 | self.assertEqual(expected.payload, received.payload) | |
812 | self.assertEqual(len(received.options), 1 + additionalOptions) | |
cbf4e13a RG |
813 | hasECS = False |
814 | for option in received.options: | |
815 | if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE: | |
816 | hasECS = True | |
817 | else: | |
4bfebc93 | 818 | self.assertNotEqual(additionalOptions, 0) |
cbf4e13a | 819 | |
5df86a8a | 820 | self.compareOptions(expected.options, received.options) |
cbf4e13a | 821 | self.assertTrue(hasECS) |
5df86a8a | 822 | |
346410cd CHB |
823 | def checkMessageEDNS(self, expected, received): |
824 | self.assertEqual(expected, received) | |
825 | self.assertEqual(received.edns, 0) | |
826 | self.assertEqual(expected.payload, received.payload) | |
827 | self.assertEqual(len(expected.options), len(received.options)) | |
828 | self.compareOptions(expected.options, received.options) | |
829 | ||
cbf4e13a RG |
830 | def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0): |
831 | self.checkMessageEDNSWithECS(expected, received, additionalOptions) | |
5df86a8a | 832 | |
346410cd CHB |
833 | def checkQueryEDNS(self, expected, received): |
834 | self.checkMessageEDNS(expected, received) | |
835 | ||
cbf4e13a RG |
836 | def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0): |
837 | self.checkMessageEDNSWithECS(expected, received, additionalOptions) | |
5df86a8a RG |
838 | |
839 | def checkQueryEDNSWithoutECS(self, expected, received): | |
840 | self.checkMessageEDNSWithoutECS(expected, received) | |
841 | ||
842 | def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0): | |
843 | self.checkMessageEDNSWithoutECS(expected, received, withCookies) | |
844 | ||
845 | def checkQueryNoEDNS(self, expected, received): | |
846 | self.checkMessageNoEDNS(expected, received) | |
847 | ||
848 | def checkResponseNoEDNS(self, expected, received): | |
849 | self.checkMessageNoEDNS(expected, received) | |
b0d08f82 | 850 | |
1d896c34 RG |
851 | def generateNewCertificateAndKey(self): |
852 | # generate and sign a new cert | |
853 | cmd = ['openssl', 'req', '-new', '-newkey', 'rsa:2048', '-nodes', '-keyout', 'server.key', '-out', 'server.csr', '-config', 'configServer.conf'] | |
854 | output = None | |
855 | try: | |
856 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True) | |
857 | output = process.communicate(input='') | |
858 | except subprocess.CalledProcessError as exc: | |
859 | raise AssertionError('openssl req failed (%d): %s' % (exc.returncode, exc.output)) | |
860 | cmd = ['openssl', 'x509', '-req', '-days', '1', '-CA', 'ca.pem', '-CAkey', 'ca.key', '-CAcreateserial', '-in', 'server.csr', '-out', 'server.pem', '-extfile', 'configServer.conf', '-extensions', 'v3_req'] | |
861 | output = None | |
862 | try: | |
863 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True) | |
864 | output = process.communicate(input='') | |
865 | except subprocess.CalledProcessError as exc: | |
866 | raise AssertionError('openssl x509 failed (%d): %s' % (exc.returncode, exc.output)) | |
867 | ||
868 | with open('server.chain', 'w') as outFile: | |
869 | for inFileName in ['server.pem', 'ca.pem']: | |
870 | with open(inFileName) as inFile: | |
871 | outFile.write(inFile.read()) | |
0e6892c6 | 872 | |
5ac11505 CHB |
873 | cmd = ['openssl', 'pkcs12', '-export', '-passout', 'pass:passw0rd', '-clcerts', '-in', 'server.pem', '-CAfile', 'ca.pem', '-inkey', 'server.key', '-out', 'server.p12'] |
874 | output = None | |
875 | try: | |
876 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True) | |
877 | output = process.communicate(input='') | |
878 | except subprocess.CalledProcessError as exc: | |
879 | raise AssertionError('openssl pkcs12 failed (%d): %s' % (exc.returncode, exc.output)) | |
880 | ||
0e6892c6 RG |
881 | def checkMessageProxyProtocol(self, receivedProxyPayload, source, destination, isTCP, values=[], v6=False, sourcePort=None, destinationPort=None): |
882 | proxy = ProxyProtocol() | |
883 | self.assertTrue(proxy.parseHeader(receivedProxyPayload)) | |
884 | self.assertEqual(proxy.version, 0x02) | |
885 | self.assertEqual(proxy.command, 0x01) | |
886 | if v6: | |
887 | self.assertEqual(proxy.family, 0x02) | |
888 | else: | |
889 | self.assertEqual(proxy.family, 0x01) | |
890 | if not isTCP: | |
891 | self.assertEqual(proxy.protocol, 0x02) | |
892 | else: | |
893 | self.assertEqual(proxy.protocol, 0x01) | |
894 | self.assertGreater(proxy.contentLen, 0) | |
895 | ||
896 | self.assertTrue(proxy.parseAddressesAndPorts(receivedProxyPayload)) | |
897 | self.assertEqual(proxy.source, source) | |
898 | self.assertEqual(proxy.destination, destination) | |
899 | if sourcePort: | |
900 | self.assertEqual(proxy.sourcePort, sourcePort) | |
901 | if destinationPort: | |
902 | self.assertEqual(proxy.destinationPort, destinationPort) | |
903 | else: | |
904 | self.assertEqual(proxy.destinationPort, self._dnsDistPort) | |
905 | ||
906 | self.assertTrue(proxy.parseAdditionalValues(receivedProxyPayload)) | |
907 | proxy.values.sort() | |
908 | values.sort() | |
909 | self.assertEqual(proxy.values, values) |