import copy
import os
import socket
import ssl
import struct
import subprocess
import sys
import threading
import time
import unittest

import clientsubnetoption
# NOTE(review): DNSDistTest also uses ``dns`` (dns.message, dns.rcode) and
# ``libnacl`` (libnacl.utils); their import lines are missing from this
# excerpt and must be restored for the module to load.

# Python2/3 compatibility hacks
try:
    from queue import Queue
except ImportError:
    from Queue import Queue
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # NOTE(review): this class was repaired from a corrupted excerpt that had
    # lost many lines. Lines marked "restored" were inferred from how the
    # surviving code uses them; values marked "assumed" must be confirmed
    # against version control.

    # restored/assumed: referenced when launching dnsdist and opening sockets.
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    _testServerPort = 5350
    # Responses the test wants the responders to send.
    _toResponderQueue = Queue()
    # Queries the responders received from dnsdist.
    _fromResponderQueue = Queue()
    # restored/assumed: timeout (seconds) for responder-side queue operations.
    _queueTimeout = 1
    _dnsdistStartupDelay = 2.0
    # restored: handle of the dnsdist process, set by startDNSDist.
    _dnsdist = None
    # restored/assumed: forwarded by setUpClass to startDNSDist(shutUp=...).
    _shutUp = True
    # Per-responder-thread count of non-health-check queries received.
    _responsesCounter = {}
    # Config written to dnsdist_test.conf, %-formatted with the values of the
    # attributes named in _config_params (presumably overridden by subclasses).
    _config_template = """
    """
    _config_params = ['_testServerPort']
    _acl = ['127.0.0.1/32']
    # restored/assumed: console endpoint and key used by sendConsoleCommand.
    _consolePort = 5199
    _consoleKey = None
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    _answerUnexpected = True

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder daemon threads on ``_testServerPort``."""
        print("Launching responders..")

        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder,
                                             args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        # Attribute form instead of the deprecated setDaemon().
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder,
                                             args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls, shutUp=True):
        """Write the dnsdist config, validate it, then launch dnsdist.

        The config is ``_config_template`` %-formatted with the values of the
        attributes named in ``_config_params``. dnsdist listens on
        ``_dnsDistListeningAddr:_dnsDistPort`` with one ``--acl`` per entry of
        ``_acl``. When ``shutUp`` is true, dnsdist's stdout goes to /dev/null.

        Raises:
            AssertionError: if ``--check-config`` fails or prints anything
                other than the expected "OK" line.
        Calls ``sys.exit`` with dnsdist's return code if the process has
        already exited after the startup delay.
        """
        print("Launching dnsdist..")
        conffile = 'dnsdist_test.conf'
        params = tuple([getattr(cls, param) for param in cls._config_params])
        with open(conffile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', conffile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        # restored loop header: `acl` was used but never bound in the excerpt.
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        try:
            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
        # Derived from conffile so the expected line cannot drift from the name.
        expectedOutput = ('Configuration \'%s\' OK!\n' % conffile).encode('UTF-8')
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        if shutUp:
            with open(os.devnull, 'w') as fdDevNull:
                cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdDevNull)
        else:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            # restored/assumed: the fast-mode delay value was lost.
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        # restored: give dnsdist time to start before checking it is alive.
        time.sleep(delay)

        if cls._dnsdist.poll() is not None:
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Open the shared UDP client socket connected to dnsdist."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders and dnsdist, then open the UDP client socket."""
        cls.startResponders()
        cls.startDNSDist(cls._shutUp)
        # restored: cls._sock (used by sendUDPQuery) is only created here.
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop dnsdist (terminate, then kill if it lingers) and close the UDP socket."""
        # restored/assumed: both grace-period values were lost.
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
                cls._dnsdist.wait()
        # The socket opened by setUpSockets was otherwise never closed.
        sock = getattr(cls, '_sock', None)
        if sock is not None:
            sock.close()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one query for the calling responder thread (keyed by thread name)."""
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue):
        """Pick the response for ``request``; return None to send nothing.

        Queries with other than one question are skipped. Health-check queries
        (name ending in ``_healthCheckName``) are counted and answered with an
        empty response. Other queries are counted per thread. They are then
        answered with the next response queued by the test, re-stamped with the
        query's ID, and the query is handed back to the test. Without a queued
        response, the answer is SERVFAIL if ``_answerUnexpected`` is set.
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
        else:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                response = fromQueue.get(True, cls._queueTimeout)
                if response:
                    # Copy so the queued object is not mutated by the ID change.
                    response = copy.copy(response)
                    response.id = request.id
                    toQueue.put(request, True, cls._queueTimeout)

        if not response:
            if healthCheck:
                response = dns.message.make_response(request)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False):
        """Serve DNS over UDP on 127.0.0.1:``port`` until the thread dies.

        Responses come from ``_getResponse``; queries it returns None for are dropped.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        try:
            # restored: the serving loop header was lost.
            while True:
                data, addr = sock.recvfrom(4096)
                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
                response = cls._getResponse(request, fromQueue, toQueue)
                if not response:
                    continue

                # restored/assumed: bound the send; the reset below survived.
                sock.settimeout(2.0)
                sock.sendto(response.to_wire(), addr)
                sock.settimeout(None)
        finally:
            # Previously placed after the infinite loop, hence unreachable.
            sock.close()

    @staticmethod
    def _recvExactly(sock, length):
        """Read ``length`` bytes from ``sock``, looping over short reads.

        Returns fewer bytes only if the peer closes the connection first
        (b'' if it was already closed). Socket timeouts propagate.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, ignoreTrailing=False, multipleResponses=False):
        """Serve length-prefixed DNS over TCP on 127.0.0.1:``port``.

        One query is read per connection. With ``multipleResponses``, every
        further response already queued by the test is also sent on that
        connection, re-stamped with the query's ID, before it is closed.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            # restored: nothing can be served without a bound socket.
            sock.close()
            return

        try:
            # restored: listen/accept loop and per-connection handling.
            sock.listen(100)
            while True:
                (conn, _) = sock.accept()
                # restored/assumed: per-connection timeout.
                conn.settimeout(2.0)
                try:
                    data = cls._recvExactly(conn, 2)
                    if len(data) != 2:
                        continue
                    (datalen,) = struct.unpack("!H", data)
                    data = cls._recvExactly(conn, datalen)
                    request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
                    response = cls._getResponse(request, fromQueue, toQueue)
                    if not response:
                        continue

                    wire = response.to_wire()
                    conn.sendall(struct.pack("!H", len(wire)))
                    conn.sendall(wire)

                    while multipleResponses:
                        if fromQueue.empty():
                            break

                        response = fromQueue.get(True, cls._queueTimeout)
                        if not response:
                            break

                        response = copy.copy(response)
                        response.id = request.id
                        wire = response.to_wire()
                        try:
                            conn.sendall(struct.pack("!H", len(wire)))
                            conn.sendall(wire)
                        except socket.error:
                            # some of the tests are going to close
                            # the connection on us, just deal with it
                            break
                finally:
                    conn.close()
        finally:
            sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send ``query`` to dnsdist over UDP.

        When ``useQueue`` is true, ``response`` is queued for the responder
        first. ``rawQuery`` means ``query`` is already wire bytes.

        Returns:
            (receivedQuery, message): the query the responder saw (or None)
            and the parsed reply (None on timeout).
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist (timeout applied if given)."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        sock.connect(("127.0.0.1", cls._dnsDistPort))
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """Return a TLS socket connected to 127.0.0.1:``port``.

        Uses ``ssl.create_default_context`` (certificate + hostname checks
        against ``serverName``) when available. Otherwise it falls back to
        ``ssl.wrap_socket``, which checks the certificate only.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            # Old interpreters only (ssl.wrap_socket was removed in 3.12).
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        sslsock.connect(("127.0.0.1", port))
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """Send one length-prefixed query on ``sock``, queueing ``response`` if given."""
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """Read one length-prefixed response from ``sock``.

        Returns the parsed message (None if the peer closed first). When
        ``useQueue`` is set and the responder queued a query, returns
        ``(receivedQuery, message)`` instead.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if len(data) == 2:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
            return (receivedQuery, message)
        return message

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send ``query`` over a fresh TCP connection; return (receivedQuery, message)."""
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """Send ``query`` over TCP and collect responses until the peer closes.

        Returns (receivedQuery, messages) where ``messages`` is a list.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)
        sock = cls.openTCPConnection(timeout)

        messages = []
        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            while True:
                data = cls._recvExactly(sock, 2)
                if len(data) != 2:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        """Reset the per-test counters and drain both responder queues."""
        # This function is called before every test.
        # Clear the responses counters; snapshot the keys because responder
        # threads may add entries concurrently.
        for key in list(self._responsesCounter):
            self._responsesCounter[key] = 0

        # Reset on the class: _getResponse increments cls._healthCheckCounter,
        # so assigning through self would only create a shadowing attribute.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @classmethod
    def clearToResponderQueue(cls):
        """Discard every response still waiting for the responders."""
        while not cls._toResponderQueue.empty():
            cls._toResponderQueue.get(False)

    @classmethod
    def clearFromResponderQueue(cls):
        """Discard every query the responders queued for the tests."""
        while not cls._fromResponderQueue.empty():
            cls._fromResponderQueue.get(False)

    @classmethod
    def clearResponderQueues(cls):
        """Drain both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a new random key from ``libnacl.utils.salsa_key``."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode ``command`` and secretbox it with ``_consoleKey`` (plain if no key)."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Open ``command`` with ``_consoleKey`` (as-is if no key) and UTF-8 decode it."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """Run ``command`` on the dnsdist console and return its decoded output.

        Exchanges nonces with the console, sends the encrypted command with a
        4-byte length prefix, and decrypts the length-prefixed reply.

        Returns None if the peer's nonce has the wrong (non-zero) size.

        Raises:
            socket.error: on EOF while reading the nonce or the response size.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if timeout:
            sock.settimeout(timeout)

        try:
            sock.connect(("127.0.0.1", cls._consolePort))
            # restored: our nonce must be sent before reading the peer's.
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            # Each direction uses half of each side's nonce.
            halfNonceSize = len(ourNonce) // 2
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if len(data) != 4:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert that option lists ``a`` and ``b`` are element-wise equal."""
        self.assertEqual(len(a), len(b))
        for optionA, optionB in zip(a, b):
            self.assertEqual(optionA, optionB)

    def checkMessageNoEDNS(self, expected, received):
        """Assert equality and that ``received`` has no EDNS and no options."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        """Assert equality and that ``received`` uses EDNS version 0."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """Assert EDNS0 with exactly ``withCookies`` options, all of code 10."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        """Assert EDNS0 with an ECS option plus ``additionalOptions`` others.

        Options must also match ``expected.options`` in order.
        """
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)