13 import clientsubnetoption
19 class DNSDistTest(unittest
.TestCase
):
21 Set up a dnsdist instance and responder threads.
22 Queries sent to dnsdist are relayed to the responder threads,
23 who reply with the response provided by the tests themselves
24 on a queue. Responder threads also queue the queries received
25 from dnsdist on a separate queue, allowing the tests to check
26 that the queries sent from dnsdist were as expected.
29 _dnsDistListeningAddr
= "127.0.0.1"
30 _testServerPort
= 5350
31 _toResponderQueue
= Queue
.Queue()
32 _fromResponderQueue
= Queue
.Queue()
34 _dnsdistStartupDelay
= 2.0
36 _responsesCounter
= {}
38 _config_template
= """
40 _config_params
= ['_testServerPort']
41 _acl
= ['127.0.0.1/32']
46 def startResponders(cls
):
47 print("Launching responders..")
49 cls
._UDPResponder
= threading
.Thread(name
='UDP Responder', target
=cls
.UDPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
50 cls
._UDPResponder
.setDaemon(True)
51 cls
._UDPResponder
.start()
52 cls
._TCPResponder
= threading
.Thread(name
='TCP Responder', target
=cls
.TCPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
53 cls
._TCPResponder
.setDaemon(True)
54 cls
._TCPResponder
.start()
57 def startDNSDist(cls
, shutUp
=True):
58 print("Launching dnsdist..")
59 conffile
= 'dnsdist_test.conf'
60 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
62 with
open(conffile
, 'w') as conf
:
63 conf
.write("-- Autogenerated by dnsdisttests.py\n")
64 conf
.write(cls
._config
_template
% params
)
66 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '-C', conffile
,
67 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
69 dnsdistcmd
.extend(['--acl', acl
])
70 print(' '.join(dnsdistcmd
))
73 with
open(os
.devnull
, 'w') as fdDevNull
:
74 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdDevNull
)
76 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True)
78 if 'DNSDIST_FAST_TESTS' in os
.environ
:
81 delay
= cls
._dnsdistStartupDelay
85 if cls
._dnsdist
.poll() is not None:
87 sys
.exit(cls
._dnsdist
.returncode
)
90 def setUpSockets(cls
):
91 print("Setting up UDP socket..")
92 cls
._sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
93 cls
._sock
.settimeout(2.0)
94 cls
._sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
100 cls
.startDNSDist(cls
._shutUp
)
103 print("Launching tests..")
106 def tearDownClass(cls
):
107 if 'DNSDIST_FAST_TESTS' in os
.environ
:
112 cls
._dnsdist
.terminate()
113 if cls
._dnsdist
.poll() is None:
115 if cls
._dnsdist
.poll() is None:
120 def _ResponderIncrementCounter(cls
):
121 if threading
.currentThread().name
in cls
._responsesCounter
:
122 cls
._responsesCounter
[threading
.currentThread().name
] += 1
124 cls
._responsesCounter
[threading
.currentThread().name
] = 1
127 def _getResponse(cls
, request
, fromQueue
, toQueue
):
129 if len(request
.question
) != 1:
130 print("Skipping query with question count %d" % (len(request
.question
)))
132 healthcheck
= not str(request
.question
[0].name
).endswith('tests.powerdns.com.')
134 cls
._ResponderIncrementCounter
()
135 if not fromQueue
.empty():
136 response
= fromQueue
.get(True, cls
._queueTimeout
)
138 response
= copy
.copy(response
)
139 response
.id = request
.id
140 toQueue
.put(request
, True, cls
._queueTimeout
)
143 # unexpected query, or health check
144 response
= dns
.message
.make_response(request
)
149 def UDPResponder(cls
, port
, fromQueue
, toQueue
, ignoreTrailing
=False):
150 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
151 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
152 sock
.bind(("127.0.0.1", port
))
154 data
, addr
= sock
.recvfrom(4096)
155 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
156 response
= cls
._getResponse
(request
, fromQueue
, toQueue
)
162 sock
.sendto(response
.to_wire(), addr
)
163 sock
.settimeout(None)
167 def TCPResponder(cls
, port
, fromQueue
, toQueue
, ignoreTrailing
=False, multipleResponses
=False):
168 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
169 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
171 sock
.bind(("127.0.0.1", port
))
172 except socket
.error
as e
:
173 print("Error binding in the TCP responder: %s" % str(e
))
178 (conn
, _
) = sock
.accept()
181 (datalen
,) = struct
.unpack("!H", data
)
182 data
= conn
.recv(datalen
)
183 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
184 response
= cls
._getResponse
(request
, fromQueue
, toQueue
)
190 wire
= response
.to_wire()
191 conn
.send(struct
.pack("!H", len(wire
)))
194 while multipleResponses
:
195 if fromQueue
.empty():
198 response
= fromQueue
.get(True, cls
._queueTimeout
)
202 response
= copy
.copy(response
)
203 response
.id = request
.id
204 wire
= response
.to_wire()
206 conn
.send(struct
.pack("!H", len(wire
)))
208 except socket
.error
as e
:
209 # some of the tests are going to close
210 # the connection on us, just deal with it
218 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
220 cls
._toResponderQueue
.put(response
, True, timeout
)
223 cls
._sock
.settimeout(timeout
)
227 query
= query
.to_wire()
228 cls
._sock
.send(query
)
229 data
= cls
._sock
.recv(4096)
230 except socket
.timeout
:
234 cls
._sock
.settimeout(None)
238 if useQueue
and not cls
._fromResponderQueue
.empty():
239 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
241 message
= dns
.message
.from_wire(data
)
242 return (receivedQuery
, message
)
245 def openTCPConnection(cls
, timeout
=None):
246 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
248 sock
.settimeout(timeout
)
250 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
254 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False):
256 wire
= query
.to_wire()
260 sock
.send(struct
.pack("!H", len(wire
)))
264 def recvTCPResponseOverConnection(cls
, sock
):
268 (datalen
,) = struct
.unpack("!H", data
)
269 data
= sock
.recv(datalen
)
271 message
= dns
.message
.from_wire(data
)
275 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
278 cls
._toResponderQueue
.put(response
, True, timeout
)
280 sock
= cls
.openTCPConnection(timeout
)
283 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
)
284 message
= cls
.recvTCPResponseOverConnection(sock
)
285 except socket
.timeout
as e
:
286 print("Timeout: %s" % (str(e
)))
287 except socket
.error
as e
:
288 print("Network error: %s" % (str(e
)))
293 if useQueue
and not cls
._fromResponderQueue
.empty():
294 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
296 return (receivedQuery
, message
)
299 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
301 for response
in responses
:
302 cls
._toResponderQueue
.put(response
, True, timeout
)
303 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
305 sock
.settimeout(timeout
)
307 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
312 wire
= query
.to_wire()
316 sock
.send(struct
.pack("!H", len(wire
)))
322 (datalen
,) = struct
.unpack("!H", data
)
323 data
= sock
.recv(datalen
)
324 messages
.append(dns
.message
.from_wire(data
))
326 except socket
.timeout
as e
:
327 print("Timeout: %s" % (str(e
)))
328 except socket
.error
as e
:
329 print("Network error: %s" % (str(e
)))
334 if useQueue
and not cls
._fromResponderQueue
.empty():
335 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
336 return (receivedQuery
, messages
)
339 # This function is called before every tests
341 # Clear the responses counters
342 for key
in self
._responsesCounter
:
343 self
._responsesCounter
[key
] = 0
345 # Make sure the queues are empty, in case
346 # a previous test failed
347 while not self
._toResponderQueue
.empty():
348 self
._toResponderQueue
.get(False)
350 while not self
._fromResponderQueue
.empty():
351 self
._fromResponderQueue
.get(False)
354 def clearToResponderQueue(cls
):
355 while not cls
._toResponderQueue
.empty():
356 cls
._toResponderQueue
.get(False)
359 def clearFromResponderQueue(cls
):
360 while not cls
._fromResponderQueue
.empty():
361 cls
._fromResponderQueue
.get(False)
364 def clearResponderQueues(cls
):
365 cls
.clearToResponderQueue()
366 cls
.clearFromResponderQueue()
369 def generateConsoleKey():
370 return libnacl
.utils
.salsa_key()
373 def _encryptConsole(cls
, command
, nonce
):
374 if cls
._consoleKey
is None:
376 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
379 def _decryptConsole(cls
, command
, nonce
):
380 if cls
._consoleKey
is None:
382 return libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
385 def sendConsoleCommand(cls
, command
, timeout
=1.0):
386 ourNonce
= libnacl
.utils
.rand_nonce()
388 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
390 sock
.settimeout(timeout
)
392 sock
.connect(("127.0.0.1", cls
._consolePort
))
394 theirNonce
= sock
.recv(len(ourNonce
))
395 if len(theirNonce
) != len(ourNonce
):
396 print("Received a nonce of size %, expecting %, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
399 halfNonceSize
= len(ourNonce
) / 2
400 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
401 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
402 msg
= cls
._encryptConsole
(command
, writingNonce
)
403 sock
.send(struct
.pack("!I", len(msg
)))
406 (responseLen
,) = struct
.unpack("!I", data
)
407 data
= sock
.recv(responseLen
)
408 response
= cls
._decryptConsole
(data
, readingNonce
)
411 def compareOptions(self
, a
, b
):
412 self
.assertEquals(len(a
), len(b
))
413 for idx
in xrange(len(a
)):
414 self
.assertEquals(a
[idx
], b
[idx
])
416 def checkMessageNoEDNS(self
, expected
, received
):
417 self
.assertEquals(expected
, received
)
418 self
.assertEquals(received
.edns
, -1)
419 self
.assertEquals(len(received
.options
), 0)
421 def checkMessageEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
422 self
.assertEquals(expected
, received
)
423 self
.assertEquals(received
.edns
, 0)
424 self
.assertEquals(len(received
.options
), withCookies
)
426 for option
in received
.options
:
427 self
.assertEquals(option
.otype
, 10)
429 def checkMessageEDNSWithECS(self
, expected
, received
):
430 self
.assertEquals(expected
, received
)
431 self
.assertEquals(received
.edns
, 0)
432 self
.assertEquals(len(received
.options
), 1)
433 self
.assertEquals(received
.options
[0].otype
, clientsubnetoption
.ASSIGNED_OPTION_CODE
)
434 self
.compareOptions(expected
.options
, received
.options
)
436 def checkQueryEDNSWithECS(self
, expected
, received
):
437 self
.checkMessageEDNSWithECS(expected
, received
)
439 def checkResponseEDNSWithECS(self
, expected
, received
):
440 self
.checkMessageEDNSWithECS(expected
, received
)
442 def checkQueryEDNSWithoutECS(self
, expected
, received
):
443 self
.checkMessageEDNSWithoutECS(expected
, received
)
445 def checkResponseEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
446 self
.checkMessageEDNSWithoutECS(expected
, received
, withCookies
)
448 def checkQueryNoEDNS(self
, expected
, received
):
449 self
.checkMessageNoEDNS(expected
, received
)
451 def checkResponseNoEDNS(self
, expected
, received
):
452 self
.checkMessageNoEDNS(expected
, received
)