18 class DNSDistTest(unittest
.TestCase
):
20 Set up a dnsdist instance and responder threads.
21 Queries sent to dnsdist are relayed to the responder threads,
22 who reply with the response provided by the tests themselves
23 on a queue. Responder threads also queue the queries received
24 from dnsdist on a separate queue, allowing the tests to check
25 that the queries sent from dnsdist were as expected.
28 _testServerPort
= 5350
29 _toResponderQueue
= Queue
.Queue()
30 _fromResponderQueue
= Queue
.Queue()
32 _dnsdistStartupDelay
= 2.0
34 _responsesCounter
= {}
36 _config_template
= """
38 _config_params
= ['_testServerPort']
39 _acl
= ['127.0.0.1/32']
44 def startResponders(cls
):
45 print("Launching responders..")
47 cls
._UDPResponder
= threading
.Thread(name
='UDP Responder', target
=cls
.UDPResponder
, args
=[cls
._testServerPort
])
48 cls
._UDPResponder
.setDaemon(True)
49 cls
._UDPResponder
.start()
50 cls
._TCPResponder
= threading
.Thread(name
='TCP Responder', target
=cls
.TCPResponder
, args
=[cls
._testServerPort
])
51 cls
._TCPResponder
.setDaemon(True)
52 cls
._TCPResponder
.start()
55 def startDNSDist(cls
, shutUp
=True):
56 print("Launching dnsdist..")
57 conffile
= 'dnsdist_test.conf'
58 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
60 with
open(conffile
, 'w') as conf
:
61 conf
.write("-- Autogenerated by dnsdisttests.py\n")
62 conf
.write(cls
._config
_template
% params
)
64 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '-C', conffile
,
65 '-l', '127.0.0.1:%d' % cls
._dnsDistPort
]
67 dnsdistcmd
.extend(['--acl', acl
])
68 print(' '.join(dnsdistcmd
))
71 with
open(os
.devnull
, 'w') as fdDevNull
:
72 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdDevNull
)
74 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True)
76 if 'DNSDIST_FAST_TESTS' in os
.environ
:
79 delay
= cls
._dnsdistStartupDelay
83 if cls
._dnsdist
.poll() is not None:
85 sys
.exit(cls
._dnsdist
.returncode
)
88 def setUpSockets(cls
):
89 print("Setting up UDP socket..")
90 cls
._sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
91 cls
._sock
.settimeout(2.0)
92 cls
._sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
98 cls
.startDNSDist(cls
._shutUp
)
101 print("Launching tests..")
104 def tearDownClass(cls
):
105 if 'DNSDIST_FAST_TESTS' in os
.environ
:
110 cls
._dnsdist
.terminate()
111 if cls
._dnsdist
.poll() is None:
113 if cls
._dnsdist
.poll() is None:
118 def _ResponderIncrementCounter(cls
):
119 if threading
.currentThread().name
in cls
._responsesCounter
:
120 cls
._responsesCounter
[threading
.currentThread().name
] += 1
122 cls
._responsesCounter
[threading
.currentThread().name
] = 1
125 def _getResponse(cls
, request
):
127 if len(request
.question
) != 1:
128 print("Skipping query with question count %d" % (len(request
.question
)))
130 healthcheck
= not str(request
.question
[0].name
).endswith('tests.powerdns.com.')
132 cls
._ResponderIncrementCounter
()
133 if not cls
._toResponderQueue
.empty():
134 response
= cls
._toResponderQueue
.get(True, cls
._queueTimeout
)
136 response
= copy
.copy(response
)
137 response
.id = request
.id
138 cls
._fromResponderQueue
.put(request
, True, cls
._queueTimeout
)
141 # unexpected query, or health check
142 response
= dns
.message
.make_response(request
)
147 def UDPResponder(cls
, port
, ignoreTrailing
=False):
148 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
149 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
150 sock
.bind(("127.0.0.1", port
))
152 data
, addr
= sock
.recvfrom(4096)
153 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
154 response
= cls
._getResponse
(request
)
160 sock
.sendto(response
.to_wire(), addr
)
161 sock
.settimeout(None)
165 def TCPResponder(cls
, port
, ignoreTrailing
=False, multipleResponses
=False):
166 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
167 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
169 sock
.bind(("127.0.0.1", port
))
170 except socket
.error
as e
:
171 print("Error binding in the TCP responder: %s" % str(e
))
176 (conn
, _
) = sock
.accept()
179 (datalen
,) = struct
.unpack("!H", data
)
180 data
= conn
.recv(datalen
)
181 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
182 response
= cls
._getResponse
(request
)
188 wire
= response
.to_wire()
189 conn
.send(struct
.pack("!H", len(wire
)))
192 while multipleResponses
:
193 if cls
._toResponderQueue
.empty():
196 response
= cls
._toResponderQueue
.get(True, cls
._queueTimeout
)
200 response
= copy
.copy(response
)
201 response
.id = request
.id
202 wire
= response
.to_wire()
203 conn
.send(struct
.pack("!H", len(wire
)))
211 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
213 cls
._toResponderQueue
.put(response
, True, timeout
)
216 cls
._sock
.settimeout(timeout
)
220 query
= query
.to_wire()
221 cls
._sock
.send(query
)
222 data
= cls
._sock
.recv(4096)
223 except socket
.timeout
:
227 cls
._sock
.settimeout(None)
231 if useQueue
and not cls
._fromResponderQueue
.empty():
232 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
234 message
= dns
.message
.from_wire(data
)
235 return (receivedQuery
, message
)
238 def openTCPConnection(cls
, timeout
=None):
239 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
241 sock
.settimeout(timeout
)
243 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
247 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False):
249 wire
= query
.to_wire()
253 sock
.send(struct
.pack("!H", len(wire
)))
257 def recvTCPResponseOverConnection(cls
, sock
):
261 (datalen
,) = struct
.unpack("!H", data
)
262 data
= sock
.recv(datalen
)
264 message
= dns
.message
.from_wire(data
)
268 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
271 cls
._toResponderQueue
.put(response
, True, timeout
)
273 sock
= cls
.openTCPConnection(timeout
)
276 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
)
277 message
= cls
.recvTCPResponseOverConnection(sock
)
278 except socket
.timeout
as e
:
279 print("Timeout: %s" % (str(e
)))
280 except socket
.error
as e
:
281 print("Network error: %s" % (str(e
)))
286 if useQueue
and not cls
._fromResponderQueue
.empty():
287 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
289 return (receivedQuery
, message
)
292 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
294 for response
in responses
:
295 cls
._toResponderQueue
.put(response
, True, timeout
)
296 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
298 sock
.settimeout(timeout
)
300 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
305 wire
= query
.to_wire()
309 sock
.send(struct
.pack("!H", len(wire
)))
315 (datalen
,) = struct
.unpack("!H", data
)
316 data
= sock
.recv(datalen
)
317 messages
.append(dns
.message
.from_wire(data
))
319 except socket
.timeout
as e
:
320 print("Timeout: %s" % (str(e
)))
321 except socket
.error
as e
:
322 print("Network error: %s" % (str(e
)))
327 if useQueue
and not cls
._fromResponderQueue
.empty():
328 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
329 return (receivedQuery
, messages
)
332 # This function is called before every tests
334 # Clear the responses counters
335 for key
in self
._responsesCounter
:
336 self
._responsesCounter
[key
] = 0
338 # Make sure the queues are empty, in case
339 # a previous test failed
340 while not self
._toResponderQueue
.empty():
341 self
._toResponderQueue
.get(False)
343 while not self
._fromResponderQueue
.empty():
344 self
._fromResponderQueue
.get(False)
347 def clearToResponderQueue(cls
):
348 while not cls
._toResponderQueue
.empty():
349 cls
._toResponderQueue
.get(False)
352 def clearFromResponderQueue(cls
):
353 while not cls
._fromResponderQueue
.empty():
354 cls
._fromResponderQueue
.get(False)
357 def clearResponderQueues(cls
):
358 cls
.clearToResponderQueue()
359 cls
.clearFromResponderQueue()
362 def generateConsoleKey():
363 return libnacl
.utils
.salsa_key()
366 def _encryptConsole(cls
, command
, nonce
):
367 if cls
._consoleKey
is None:
369 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
372 def _decryptConsole(cls
, command
, nonce
):
373 if cls
._consoleKey
is None:
375 return libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
378 def sendConsoleCommand(cls
, command
, timeout
=1.0):
379 ourNonce
= libnacl
.utils
.rand_nonce()
381 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
383 sock
.settimeout(timeout
)
385 sock
.connect(("127.0.0.1", cls
._consolePort
))
387 theirNonce
= sock
.recv(len(ourNonce
))
389 halfNonceSize
= len(ourNonce
) / 2
390 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
391 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
393 msg
= cls
._encryptConsole
(command
, writingNonce
)
394 sock
.send(struct
.pack("!I", len(msg
)))
397 (responseLen
,) = struct
.unpack("!I", data
)
398 data
= sock
.recv(responseLen
)
399 response
= cls
._decryptConsole
(data
, readingNonce
)