]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/dnsdisttests.py
Merge pull request #7159 from rgacogne/rec41-revert-6980
[thirdparty/pdns.git] / regression-tests.dnsdist / dnsdisttests.py
1 #!/usr/bin/env python2
2
3 import copy
4 import Queue
5 import os
6 import socket
7 import struct
8 import subprocess
9 import sys
10 import threading
11 import time
12 import unittest
13 import clientsubnetoption
14 import dns
15 import dns.message
16 import libnacl
17 import libnacl.utils
18
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # Port and address dnsdist is told to listen on ('-l' option).
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    # Port the UDP and TCP responder threads bind to on 127.0.0.1.
    _testServerPort = 5350
    # Responses handed from the tests to the responder threads.
    _toResponderQueue = Queue.Queue()
    # Queries handed back from the responder threads to the tests.
    _fromResponderQueue = Queue.Queue()
    # Timeout, in seconds, used by the responders for their queue operations.
    _queueTimeout = 1
    # Seconds to wait after spawning dnsdist (unless DNSDIST_FAST_TESTS is set).
    _dnsdistStartupDelay = 2.0
    _dnsdist = None
    # Per responder-thread-name count of non health-check queries received.
    _responsesCounter = {}
    # When true, the standard output of dnsdist is sent to os.devnull.
    _shutUp = True
    # Configuration template, filled in with the attributes named in
    # _config_params using the '%' operator.
    _config_template = """
"""
    _config_params = ['_testServerPort']
    # Each entry is passed to dnsdist as an '--acl' option.
    _acl = ['127.0.0.1/32']
    _consolePort = 5199
    # When None, console commands are exchanged without encryption.
    _consoleKey = None

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responders as daemon threads."""
        print("Launching responders..")

        responderArgs = [cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]

        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=responderArgs)
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=responderArgs)
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls, shutUp=True):
        """
        Write the configuration file and spawn dnsdist.

        The binary is taken from the DNSDISTBIN environment variable.
        If the process has already exited once the startup delay has
        elapsed, the test run is aborted with dnsdist's return code.
        """
        print("Launching dnsdist..")
        conffile = 'dnsdist_test.conf'
        params = tuple(getattr(cls, param) for param in cls._config_params)
        print(params)
        with open(conffile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        if shutUp:
            with open(os.devnull, 'w') as fdDevNull:
                cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
        else:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        if cls._dnsdist.poll() is not None:
            # The process is already gone and has been reaped by poll():
            # sending it a signal would either fail or hit an unrelated
            # process, so just report its exit status.
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP socket used to send queries to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders and dnsdist, then set up the client socket."""
        cls.startResponders()
        cls.startDNSDist(cls._shutUp)
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Terminate dnsdist, killing it if it does not exit after a short delay."""
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
                cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Increment the responses counter of the calling thread."""
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue):
        """
        Build the response a responder should send for 'request'.

        Queries whose name ends with 'tests.powerdns.com.' are counted
        and answered with the next response available on 'fromQueue',
        if any, in which case the query is put on 'toQueue'. Everything
        else is treated as a health check and gets a generic response.
        Returns None if the query does not hold exactly one question.
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthcheck = not str(request.question[0].name).endswith('tests.powerdns.com.')
        if not healthcheck:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                response = fromQueue.get(True, cls._queueTimeout)
                if response:
                    # work on a copy, the test still holds the original
                    response = copy.copy(response)
                    response.id = request.id
                    toQueue.put(request, True, cls._queueTimeout)

        if not response:
            # unexpected query, or health check
            response = dns.message.make_response(request)

        return response

    @staticmethod
    def _recvExactly(sock, length):
        """
        Read 'length' bytes from a stream socket.

        A single recv() may return fewer bytes than requested, so keep
        reading until everything has been received or the peer closes
        the connection. The returned data is shorter than 'length'
        (possibly empty) only in the latter case.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @staticmethod
    def _sendTCPMessage(sock, wire):
        """Send 'wire' over a stream socket, prefixed with its 16-bit length."""
        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def _getReceivedQuery(cls, useQueue, timeout):
        """Return the query a responder queued for the tests, or None."""
        if useQueue and not cls._fromResponderQueue.empty():
            return cls._fromResponderQueue.get(True, timeout)
        return None

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
        """Answer, forever, the UDP queries received on 127.0.0.1:port."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        try:
            while True:
                data, addr = sock.recvfrom(4096)
                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
                response = cls._getResponse(request, fromQueue, toQueue)

                if not response:
                    continue

                # only the send is subject to a timeout, not the wait
                # for the next query
                sock.settimeout(2.0)
                sock.sendto(response.to_wire(), addr)
                sock.settimeout(None)
        finally:
            sock.close()

    @classmethod
    def _handleTCPConnection(cls, conn, fromQueue, toQueue, ignoreTrailing, multipleResponses):
        """
        Read one query from 'conn' and send the response(s) to it.

        Returns without answering if the peer closes the connection
        before a complete query has been received.
        """
        data = cls._recvExactly(conn, 2)
        if len(data) < 2:
            return
        (datalen,) = struct.unpack("!H", data)
        data = cls._recvExactly(conn, datalen)
        if len(data) < datalen:
            return

        request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
        response = cls._getResponse(request, fromQueue, toQueue)
        if not response:
            return

        cls._sendTCPMessage(conn, response.to_wire())

        while multipleResponses:
            if fromQueue.empty():
                break

            response = fromQueue.get(True, cls._queueTimeout)
            if not response:
                break

            response = copy.copy(response)
            response.id = request.id
            try:
                cls._sendTCPMessage(conn, response.to_wire())
            except socket.error:
                # some of the tests are going to close
                # the connection on us, just deal with it
                break

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
        """
        Answer, forever, the TCP queries received on 127.0.0.1:port.

        One query is read per connection. When 'multipleResponses' is
        set, every response available on 'fromQueue' is sent over the
        connection before it is closed.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        try:
            while True:
                (conn, _) = sock.accept()
                conn.settimeout(2.0)
                try:
                    cls._handleTCPConnection(conn, fromQueue, toQueue, ignoreTrailing, multipleResponses)
                except socket.error as e:
                    # a timeout or a reset on one connection should not
                    # take the whole responder down
                    print("Network error in the TCP responder: %s" % str(e))
                finally:
                    conn.close()
        finally:
            sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send 'query' to dnsdist over UDP.

        When 'useQueue' is set, 'response' is queued for the responder
        first. 'query' is sent as is when 'rawQuery' is set, and
        converted to wire format otherwise.
        Returns a (receivedQuery, message) tuple: the query queued by
        the responder, if any, and the parsed response received from
        dnsdist, or None if the receive timed out.
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = cls._getReceivedQuery(useQueue, timeout)
        message = None
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a new TCP socket connected to dnsdist."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False):
        """Send 'query' over 'sock', prefixed with its 16-bit length."""
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        cls._sendTCPMessage(sock, wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock):
        """
        Read one length-prefixed response from 'sock'.

        Returns the parsed message, or None if the connection was
        closed before a length prefix could be read.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if len(data) == 2:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)
        return message

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send 'query' to dnsdist over a new TCP connection.

        Timeouts and network errors are reported on the standard
        output, in which case the returned message is None.
        Returns a (receivedQuery, message) tuple.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = cls._getReceivedQuery(useQueue, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """
        Send 'query' over a new TCP connection and read every response.

        Responses are read until the connection is closed, a timeout
        occurs or a network error is raised.
        Returns a (receivedQuery, messages) tuple, 'messages' being the
        list of the parsed responses.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            while True:
                data = cls._recvExactly(sock, 2)
                if len(data) < 2:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = cls._getReceivedQuery(useQueue, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        # This function is called before every tests

        # Clear the responses counters
        for key in self._responsesCounter:
            self._responsesCounter[key] = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @staticmethod
    def _drainQueue(queue):
        """Discard every item currently held by 'queue'."""
        try:
            while True:
                queue.get(False)
        except Queue.Empty:
            # a responder thread might consume items concurrently, so
            # rely on get() rather than on a prior call to empty()
            pass

    @classmethod
    def clearToResponderQueue(cls):
        """Discard the responses not yet consumed by the responders."""
        cls._drainQueue(cls._toResponderQueue)

    @classmethod
    def clearFromResponderQueue(cls):
        """Discard the queries not yet consumed by the tests."""
        cls._drainQueue(cls._fromResponderQueue)

    @classmethod
    def clearResponderQueues(cls):
        """Empty both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new key suitable for _consoleKey."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """Encrypt 'command' with the console key, if one is set."""
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Decrypt 'command' with the console key, if one is set."""
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """
        Send 'command' to the console listening on _consolePort.

        Nonces are exchanged first, then the command and its response
        are sent prefixed with their 32-bit length, encrypted if a
        console key is set.
        Returns the response, or None if the nonce received from the
        console does not have the expected size.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                return None

            # each direction uses a nonce made of one half of ours
            # and one half of theirs
            halfNonceSize = len(ourNonce) // 2
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            (responseLen,) = struct.unpack("!I", data)
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Check that the two lists of options hold equal items, in order."""
        self.assertEqual(len(a), len(b))
        for left, right in zip(a, b):
            self.assertEqual(left, right)

    def checkMessageNoEDNS(self, expected, received):
        """Check that 'received' equals 'expected' and carries no EDNS."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """
        Check that 'received' equals 'expected' and uses EDNS version 0.

        'withCookies' is the exact number of options expected in the
        message, each of them having option code 10.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received):
        """
        Check that 'received' equals 'expected', uses EDNS version 0 and
        holds a single option, a client subnet one, equal to the
        expected one.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), 1)
        self.assertEqual(received.options[0].otype, clientsubnetoption.ASSIGNED_OPTION_CODE)
        self.compareOptions(expected.options, received.options)

    def checkQueryEDNSWithECS(self, expected, received):
        """Same as checkMessageEDNSWithECS(), for a query."""
        self.checkMessageEDNSWithECS(expected, received)

    def checkResponseEDNSWithECS(self, expected, received):
        """Same as checkMessageEDNSWithECS(), for a response."""
        self.checkMessageEDNSWithECS(expected, received)

    def checkQueryEDNSWithoutECS(self, expected, received):
        """Same as checkMessageEDNSWithoutECS(), for a query."""
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        """Same as checkMessageEDNSWithoutECS(), for a response."""
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        """Same as checkMessageNoEDNS(), for a query."""
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        """Same as checkMessageNoEDNS(), for a response."""
        self.checkMessageNoEDNS(expected, received)