]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #4815 from rgacogne/dnsdist-console-no-replay
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import Queue
5 import os
6 import socket
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import dns
14 import dns.message
15 import libnacl
16 import libnacl.utils
17
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    which reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # Port dnsdist listens on (passed to it with -l)
    _dnsDistPort = 5340
    # Port the UDP and TCP responder threads bind to on 127.0.0.1
    _testServerPort = 5350
    # Responses the tests want the responders to send back
    _toResponderQueue = Queue.Queue()
    # Queries the responders received from dnsdist
    _fromResponderQueue = Queue.Queue()
    # Timeout (seconds) used by the responders on queue operations
    _queueTimeout = 1
    # Seconds to wait after launching dnsdist (unless DNSDIST_FAST_TESTS is set)
    _dnsdistStartupDelay = 2.0
    # subprocess.Popen handle of the running dnsdist, if any
    _dnsdist = None
    # Test queries received, keyed by the name of the responder thread
    _responsesCounter = {}
    # When True, dnsdist's stdout is redirected to /dev/null
    _shutUp = True
    # dnsdist configuration, %-formatted with the attributes named in _config_params
    _config_template = """
    """
    _config_params = ['_testServerPort']
    # Each entry is passed to dnsdist as an --acl argument
    _acl = ['127.0.0.1/32']
    # Port of the dnsdist console used by sendConsoleCommand()
    _consolePort = 5199
    # Console key; None means console traffic is sent unencrypted
    _consoleKey = None

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder threads on _testServerPort.

        Both are daemon threads, so they do not keep the interpreter alive
        once the tests are done.
        """
        print("Launching responders..")

        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls, shutUp=True):
        """Write the configuration file and launch dnsdist.

        The dnsdist binary is taken from the DNSDISTBIN environment variable.
        If dnsdist is no longer running once the startup delay has elapsed,
        the test run is aborted with a non-zero exit status.
        """
        print("Launching dnsdist..")
        conffile = 'dnsdist_test.conf'
        params = tuple([getattr(cls, param) for param in cls._config_params])
        print(params)
        with open(conffile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
                      '-l', '127.0.0.1:%d' % cls._dnsDistPort]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        if shutUp:
            with open(os.devnull, 'w') as fdDevNull:
                cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
        else:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        if cls._dnsdist.poll() is not None:
            # poll() has already reaped the process. Calling kill() now would
            # signal a pid that no longer exists (OSError on Python 2) or that
            # may have been reused by an unrelated process.
            print("dnsdist exited prematurely with return code %d" % cls._dnsdist.returncode)
            # Even a clean (0) exit is a failure here: the tests cannot run.
            sys.exit(cls._dnsdist.returncode or 1)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP socket used by sendUDPQuery(), connected to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders and dnsdist, then open the UDP client socket."""
        cls.startResponders()
        cls.startDNSDist(cls._shutUp)
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop dnsdist: terminate() first, kill() if it is still running after a delay."""
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
            # reap the process (returns immediately if poll() already did)
            cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one more test query for the calling responder thread."""
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @classmethod
    def _getResponse(cls, request):
        """Return the response a responder should send back for request.

        Queries whose name ends with 'tests.powerdns.com.' are counted. If
        the test queued a response, a copy of it (with the query's ID) is
        returned and the query is put on _fromResponderQueue. Any other
        query (health checks, or nothing queued) gets a response built with
        dns.message.make_response(). Queries that do not carry exactly one
        question yield None, meaning no answer is sent.
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthcheck = not str(request.question[0].name).endswith('tests.powerdns.com.')
        if not healthcheck:
            cls._ResponderIncrementCounter()
            if not cls._toResponderQueue.empty():
                try:
                    response = cls._toResponderQueue.get(True, cls._queueTimeout)
                except Queue.Empty:
                    # The UDP and TCP responders share this queue, so the other
                    # thread may have taken the last entry since empty() was
                    # checked; don't let that kill this responder thread.
                    response = None
                if response:
                    response = copy.copy(response)
                    response.id = request.id
                    cls._fromResponderQueue.put(request, True, cls._queueTimeout)

        if not response:
            # unexpected query, or health check
            response = dns.message.make_response(request)

        return response

    @classmethod
    def UDPResponder(cls, port, ignoreTrailing=False):
        """Answer DNS-over-UDP queries received on 127.0.0.1:port, forever."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        while True:
            data, addr = sock.recvfrom(4096)
            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
            response = cls._getResponse(request)

            if not response:
                continue

            sock.settimeout(2.0)
            sock.sendto(response.to_wire(), addr)
            sock.settimeout(None)
        sock.close()

    @classmethod
    def _handleTCPConnection(cls, conn, ignoreTrailing, multipleResponses):
        """Read one length-prefixed query from conn and write the response(s).

        Returns early, without answering, if the client closes the
        connection before a complete query has been received.
        """
        conn.settimeout(2.0)
        data = cls._recvExactly(conn, 2)
        if len(data) != 2:
            return
        (datalen,) = struct.unpack("!H", data)
        data = cls._recvExactly(conn, datalen)
        if len(data) != datalen:
            return
        request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
        response = cls._getResponse(request)

        if not response:
            return

        wire = response.to_wire()
        conn.sendall(struct.pack("!H", len(wire)))
        conn.sendall(wire)

        # Optionally send every further queued response on the same
        # connection, each rewritten to carry the query's ID.
        while multipleResponses:
            if cls._toResponderQueue.empty():
                break

            try:
                response = cls._toResponderQueue.get(True, cls._queueTimeout)
            except Queue.Empty:
                break
            if not response:
                break

            response = copy.copy(response)
            response.id = request.id
            wire = response.to_wire()
            conn.sendall(struct.pack("!H", len(wire)))
            conn.sendall(wire)

    @classmethod
    def TCPResponder(cls, port, ignoreTrailing=False, multipleResponses=False):
        """Answer DNS-over-TCP queries received on 127.0.0.1:port, forever.

        Each accepted connection carries one query. A network error or
        timeout on a single connection is reported and the connection
        dropped, so the responder thread keeps serving later tests.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            try:
                cls._handleTCPConnection(conn, ignoreTrailing, multipleResponses)
            except socket.error as e:
                # socket.timeout is a subclass of socket.error
                print("Network error in the TCP responder: %s" % str(e))
            finally:
                conn.close()

        sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send query to dnsdist over UDP and wait for its answer.

        If useQueue is True, response is first queued for the responders and
        the query they received is fetched afterwards. rawQuery means query
        is already in wire format. Returns (receivedQuery, message); either
        may be None (message is None when no answer arrived in time). When
        timeout is set, the socket is left in blocking mode afterwards.
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @staticmethod
    def _recvExactly(sock, count):
        """Receive from sock until count bytes have arrived or the peer closes.

        The result is shorter than count only if the connection was closed
        first. A single recv() may legitimately return fewer bytes than
        requested, hence the loop.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist, with timeout applied if set."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False):
        """Write query to sock, preceded by its 2-byte big-endian length.

        rawQuery means query is already in wire format. sendall() is used
        because send() may write only part of the buffer.
        """
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock):
        """Read one length-prefixed DNS message from sock.

        Returns None if the connection is closed before a complete length
        prefix arrives, or if the announced message is empty.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if len(data) == 2:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)
        return message

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send query to dnsdist over a new TCP connection and read one answer.

        Arguments and return value are the same as for sendUDPQuery().
        Network errors and timeouts are printed, and message is then None.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """Send query over TCP and collect every answer until the connection ends.

        All of responses are queued for the responders when useQueue is
        True. Returns (receivedQuery, messages) where messages is the list
        of DNS messages read before EOF, a timeout or a network error.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            while True:
                message = cls.recvTCPResponseOverConnection(sock)
                if message is None:
                    break
                messages.append(message)

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        """Reset per-test state; called before every test.

        The responses counters are zeroed and both responder queues are
        drained, in case a previous test failed and left entries behind.
        """
        for key in self._responsesCounter:
            self._responsesCounter[key] = 0

        self.clearResponderQueues()

    @classmethod
    def clearToResponderQueue(cls):
        """Discard every response still queued for the responders."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Discard every query the responders queued but no test consumed."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Drain both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new random key suitable for the dnsdist console."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """Encrypt command with _consoleKey, or return it unchanged if no key is set."""
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Decrypt command with _consoleKey, or return it unchanged if no key is set."""
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """Run command on the dnsdist console and return its decrypted output.

        Both sides exchange a nonce. The reading and writing nonces are each
        built from one half of ours and one half of theirs. The command and
        the reply are framed with a 4-byte big-endian length. The socket is
        always closed.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))

            # Integer division: "/" would give a float under Python 3 and
            # break the slicing below.
            halfNonceSize = len(ourNonce) // 2
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]

            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            (responseLen,) = struct.unpack("!I", data)
            data = cls._recvExactly(sock, responseLen)
        finally:
            sock.close()

        response = cls._decryptConsole(data, readingNonce)
        return response