]>
Commit | Line | Data |
---|---|---|
ca404e94 RG |
1 | #!/usr/bin/env python2 |
2 | ||
95f0b802 | 3 | import copy |
ca404e94 RG |
4 | import os |
5 | import socket | |
a227f47d | 6 | import ssl |
ca404e94 RG |
7 | import struct |
8 | import subprocess | |
9 | import sys | |
10 | import threading | |
11 | import time | |
12 | import unittest | |
5df86a8a | 13 | import clientsubnetoption |
b1bec9f0 RG |
14 | import dns |
15 | import dns.message | |
1ea747c0 RG |
16 | import libnacl |
17 | import libnacl.utils | |
ca404e94 | 18 | |
6bd430bf PD |
19 | from eqdnsmessage import AssertEqualDNSMessageMixin |
20 | ||
b4f23783 | 21 | # Python2/3 compatibility hacks |
7a0ea291 | 22 | try: |
23 | from queue import Queue | |
24 | except ImportError: | |
b4f23783 | 25 | from Queue import Queue |
7a0ea291 | 26 | |
27 | try: | |
b4f23783 | 28 | range = xrange |
7a0ea291 | 29 | except NameError: |
30 | pass | |
b4f23783 CH |
31 | |
32 | ||
class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # Address/port dnsdist is told to listen on (its -l option).
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    # Port the default UDP and TCP responder threads bind on 127.0.0.1.
    _testServerPort = 5350
    # Tests put the responses to serve on _toResponderQueue; responders put the
    # queries they received from dnsdist on _fromResponderQueue.
    # NOTE: these are class attributes, shared by every subclass.
    _toResponderQueue = Queue()
    _fromResponderQueue = Queue()
    _queueTimeout = 1
    # Seconds to wait after launching dnsdist (0.5 when DNSDIST_FAST_TESTS is set).
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Responder thread name -> number of non-health-check queries handled.
    # NOTE: a mutable class attribute, shared by every subclass.
    _responsesCounter = {}
    # dnsdist configuration template; it is %-formatted with the values of the
    # class attributes named in _config_params, in that order.
    _config_template = """
    """
    _config_params = ['_testServerPort']
    # Each entry is passed to dnsdist as an --acl argument.
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    # Console key; None means console traffic is neither encrypted nor decrypted.
    _consoleKey = None
    # Queries whose name ends with this are counted as health checks and answered
    # with an empty response instead of a queued one.
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    # When no response is available, answer SERVFAIL (True) or nothing (False).
    _answerUnexpected = True
    # Expected `--check-config` output (bytes); None means the default
    # "Configuration '<file>' OK!" line.
    _checkConfigExpectedOutput = None

    @classmethod
    def startResponders(cls):
        """Start the default UDP and TCP responder threads on _testServerPort."""
        print("Launching responders..")

        responderArgs = (cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue)
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=responderArgs)
        # Daemon threads: the responder loops never return on their own.
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=responderArgs)
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls):
        """Render the config, validate it with --check-config, then launch dnsdist.

        The config is written to configs/dnsdist_<ClassName>.conf and dnsdist's
        stdout/stderr go to configs/dnsdist_<ClassName>.log.

        Raises:
            AssertionError: --check-config failed or printed unexpected output.
            SystemExit: dnsdist had already exited after the startup delay.
        """
        print("Launching dnsdist..")
        confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
        params = tuple(getattr(cls, param) for param in cls._config_params)
        print(params)
        with open(confFile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '--supervised', '-C', confFile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        try:
            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
        if cls._checkConfigExpectedOutput is not None:
            expectedOutput = cls._checkConfigExpectedOutput
        else:
            expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
        with open(logFile, 'w') as fdLog:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        # A non-None poll() means dnsdist has already exited and been reaped, so
        # there is nothing left to kill (and kill() on a reaped pid can raise on
        # Python 2): just propagate its exit status.
        if cls._dnsdist.poll() is not None:
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP socket used by sendUDPQuery, connected to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders, then dnsdist, then the client UDP socket."""
        cls.startResponders()
        cls.startDNSDist()
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop dnsdist (terminate, then kill if still running) and close the UDP socket."""
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                # Give dnsdist a moment to exit on its own before forcing it.
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
                cls._dnsdist.wait()

        # Release the UDP socket created by setUpSockets(), if any.
        sock = getattr(cls, '_sock', None)
        if sock is not None:
            sock.close()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Increment the handled-queries counter of the calling responder thread."""
        threadName = threading.current_thread().name
        cls._responsesCounter[threadName] = cls._responsesCounter.get(threadName, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
        """Pick the response a responder should send for `request`.

        Args:
            request: the parsed query received from dnsdist.
            fromQueue: queue the tests put responses on.
            toQueue: queue the received queries are reported on.
            synthesize: if not None, an RCODE to answer with instead of
                consuming a queued response.

        Returns:
            A dns.message response, or None when the query is skipped (question
            count != 1) or nothing should be sent (no queued response,
            no `synthesize`, and _answerUnexpected is False).
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
            response = dns.message.make_response(request)
        else:
            cls._ResponderIncrementCounter()
            # Only report the query when a test is waiting for it, i.e. has
            # queued a response.
            if not fromQueue.empty():
                toQueue.put(request, True, cls._queueTimeout)
                if synthesize is None:
                    response = fromQueue.get(True, cls._queueTimeout)
                    if response:
                        # Copy so the test's object keeps its own ID.
                        response = copy.copy(response)
                        response.id = request.id

        if not response:
            if synthesize is not None:
                response = dns.message.make_response(request)
                response.set_rcode(synthesize)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @staticmethod
    def _recvExactly(sock, count):
        """Read exactly `count` bytes from a stream socket.

        A single recv() may return fewer bytes than requested, so this keeps
        reading. Fewer than `count` bytes are returned only if the peer closed
        the connection first. Socket errors/timeouts propagate.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @staticmethod
    def _parseRequest(data, trailingDataResponse, protocol):
        """Parse a raw query according to the responders' trailing-data policy.

        Returns:
            (request, forceRcode). forceRcode is None unless the query carried
            trailing data and trailingDataResponse is an RCODE to answer with.

        Raises:
            dns.message.TrailingJunk: trailing data found and
                trailingDataResponse is False.
        """
        ignoreTrailing = trailingDataResponse is True
        try:
            return (dns.message.from_wire(data, ignore_trailing=ignoreTrailing), None)
        except dns.message.TrailingJunk:
            if trailingDataResponse is False:
                raise
            print("%s query with trailing data, synthesizing response" % (protocol))
            return (dns.message.from_wire(data, ignore_trailing=True), trailingDataResponse)

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
        # trailingDataResponse=True means "ignore trailing data".
        # Other values are either False (meaning "raise an exception")
        # or are interpreted as a response RCODE for queries with trailing data.
        # callback is invoked for every -even healthcheck ones- query and should return a raw response
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        try:
            # Serves forever; only an exception (e.g. TrailingJunk) ends the loop.
            while True:
                data, addr = sock.recvfrom(4096)
                request, forceRcode = cls._parseRequest(data, trailingDataResponse, 'UDP')

                wire = None
                if callback:
                    wire = callback(request)
                else:
                    response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                    if response:
                        wire = response.to_wire()

                if not wire:
                    continue

                sock.settimeout(2.0)
                sock.sendto(wire, addr)
                sock.settimeout(None)
        finally:
            sock.close()

    @classmethod
    def _serveTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback):
        """Read one query from an accepted responder connection and answer it.

        With multipleResponses, every response still queued on fromQueue is
        sent as well, on the same connection.
        """
        header = cls._recvExactly(conn, 2)
        if len(header) < 2:
            # The peer closed the connection without sending a full query.
            return

        (datalen,) = struct.unpack("!H", header)
        data = cls._recvExactly(conn, datalen)
        request, forceRcode = cls._parseRequest(data, trailingDataResponse, 'TCP')

        # Reset per connection so a previous connection's answer is never re-sent.
        wire = None
        if callback:
            wire = callback(request)
        else:
            response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
            if response:
                wire = response.to_wire(max_size=65535)

        if not wire:
            return

        conn.sendall(struct.pack("!H", len(wire)))
        conn.sendall(wire)

        while multipleResponses:
            if fromQueue.empty():
                break

            response = fromQueue.get(True, cls._queueTimeout)
            if not response:
                break

            response = copy.copy(response)
            response.id = request.id
            wire = response.to_wire(max_size=65535)
            try:
                conn.sendall(struct.pack("!H", len(wire)))
                conn.sendall(wire)
            except socket.error:
                # some of the tests are going to close
                # the connection on us, just deal with it
                break

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
        # trailingDataResponse=True means "ignore trailing data".
        # Other values are either False (meaning "raise an exception")
        # or are interpreted as a response RCODE for queries with trailing data.
        # callback is invoked for every -even healthcheck ones- query and should return a raw response
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        try:
            # Serves forever; one query is handled per accepted connection.
            while True:
                (conn, _) = sock.accept()
                conn.settimeout(5.0)
                try:
                    cls._serveTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback)
                except socket.error as e:
                    # A single broken or timed-out connection must not kill the
                    # responder for the remaining tests. Parse errors such as
                    # TrailingJunk still propagate.
                    print("Network error in the TCP responder: %s" % str(e))
                finally:
                    conn.close()
        finally:
            sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send `query` to dnsdist over UDP and wait for the answer.

        Args:
            query: a dns.message, or raw wire bytes when rawQuery is True.
            response: queued for the responder when useQueue is True.
            timeout: receive timeout for this query; falsy keeps the socket's
                current timeout.

        Returns:
            (receivedQuery, message): the query seen by the responder (or None)
            and dnsdist's parsed answer (or None on timeout).
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        # Restore the socket's own timeout afterwards rather than leaving it
        # fully blocking.
        previousTimeout = cls._sock.gettimeout()
        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(previousTimeout)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a new TCP socket connected to dnsdist's listening port."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """Return a TLS socket connected to 127.0.0.1:`port`, verified against `caCert`."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        # 2.7.9+
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            # Older Pythons: the certificate is checked, but not the hostname.
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        sslsock.connect(("127.0.0.1", port))
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """Send a length-prefixed query over an open TCP/TLS connection.

        If `response` is given, it is first queued for the responder.
        """
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        # sendall(): send() may write only part of the buffer.
        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """Read one length-prefixed DNS message from an open TCP/TLS connection.

        Returns:
            The parsed message (None if the connection closed first) when
            useQueue is False; otherwise a (receivedQuery, message) tuple whose
            receivedQuery is None if the responder reported no query.
        """
        message = None
        header = cls._recvExactly(sock, 2)
        if len(header) == 2:
            (datalen,) = struct.unpack("!H", header)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if not useQueue:
            return message

        receivedQuery = None
        if not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, message)

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send `query` to dnsdist over a fresh TCP connection and read one answer.

        Network errors and timeouts are printed, not raised.

        Returns:
            (receivedQuery, message), either of which may be None.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """Send `query` over TCP and collect every answer until close or timeout.

        Returns:
            (receivedQuery, messages): the query seen by the responder (or None)
            and the list of parsed answers, in order.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)
        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            while True:
                header = cls._recvExactly(sock, 2)
                if len(header) < 2:
                    break
                (datalen,) = struct.unpack("!H", header)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        # This function is called before every tests

        # Clear the responses counters. Iterate over a snapshot of the keys:
        # responder threads may add a key concurrently.
        for key in list(self._responsesCounter):
            self._responsesCounter[key] = 0

        # NOTE(review): this binds an *instance* attribute, while _getResponse
        # increments cls._healthCheckCounter on the responder's class, so the
        # class-level counter is not reset here.
        self._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

        super(DNSDistTest, self).setUp()

    @classmethod
    def clearToResponderQueue(cls):
        """Drop every response still queued for the responders."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Drop every query the responders reported but no test consumed."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Empty both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a fresh random secretbox key for the console."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode `command` and encrypt it, unless no console key is set."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Decrypt `command` (unless no console key is set) and UTF-8 decode it."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """Run `command` on dnsdist's console and return its decoded output.

        Returns None if dnsdist sent a nonce of the wrong size.

        Raises:
            socket.error: EOF while reading the nonce or the response size.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        try:
            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            # Each direction uses a nonce made of one half from each side.
            halfNonceSize = len(ourNonce) // 2
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if not data:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            # Console output can be large: a single recv() may return only part of it.
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert that two EDNS option lists are equal element by element."""
        self.assertEqual(len(a), len(b))
        for optionA, optionB in zip(a, b):
            self.assertEqual(optionA, optionB)

    def checkMessageNoEDNS(self, expected, received):
        """Assert that the messages are equal and `received` carries no EDNS."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        """Assert that the messages are equal and `received` has EDNS0 with the expected payload."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """Assert EDNS0 with no options, or with exactly `withCookies` option code 10 entries."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        """Assert EDNS0 with one ECS option plus `additionalOptions` other options."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(expected.payload, received.payload)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)
b0d08f82 | 588 |