13 import clientsubnetoption
19 # Python2/3 compatibility hacks
21 from queue
import Queue
23 from Queue
import Queue
# NOTE(review): this excerpt of the DNSDistTest header is incomplete -- the
# embedded original line numbers skip 32 and 39 (presumably the docstring's
# triple quotes), 40/45/47/49/54/55 (presumably further class attributes) and
# 51 (presumably the rest of _config_template, including its closing quotes).
# Base test case: class-level settings shared by a dnsdist instance under test
# and the UDP/TCP responder threads that stand in for its backend.
31 class DNSDistTest(unittest
.TestCase
):
33 Set up a dnsdist instance and responder threads.
34 Queries sent to dnsdist are relayed to the responder threads,
35 who reply with the response provided by the tests themselves
36 on a queue. Responder threads also queue the queries received
37 from dnsdist on a separate queue, allowing the tests to check
38 that the queries sent from dnsdist were as expected.
# Address placed in dnsdist's "-l" listen argument by startDNSDist.
41 _dnsDistListeningAddr
= "127.0.0.1"
# Port the UDP and TCP responder threads bind to on 127.0.0.1.
42 _testServerPort
= 5350
# Responses supplied by the tests, consumed by the responder threads.
43 _toResponderQueue
= Queue()
# Queries received by the responder threads, read back by the tests.
44 _fromResponderQueue
= Queue()
# Seconds to wait for dnsdist to start; presumably used unless
# DNSDIST_FAST_TESTS is set -- TODO confirm.
46 _dnsdistStartupDelay
= 2.0
# Response counts per responder thread, keyed by thread name.
48 _responsesCounter
= {}
# The attributes from here on are documented in this comment, because a
# comment placed after the opening triple quote below would become part of
# the template string.
#  - _config_template: dnsdist configuration, %-formatted with the values of
#    the class attributes named in _config_params.
#  - _acl: each entry becomes an "--acl" argument on the dnsdist command line.
#  - _healthCheckName: queries whose name ends with it count as health checks.
#  - _healthCheckCounter: incremented by the responders for such queries.
#  - _answerUnexpected: when true, responders build a SERVFAIL answer,
#    presumably for queries with no response queued -- TODO confirm.
50 _config_template
= """
52 _config_params
= ['_testServerPort']
53 _acl
= ['127.0.0.1/32']
56 _healthCheckName
= 'a.root-servers.net.'
57 _healthCheckCounter
= 0
58 _answerUnexpected
= True
def startResponders(cls):
    """Start the UDP and TCP responder threads.

    Both threads are daemons, so they never keep the interpreter alive. Each
    receives the same arguments: ``cls._testServerPort`` and the two queues
    shared with the tests (``_toResponderQueue`` and ``_fromResponderQueue``).
    The thread objects are stored on ``cls._UDPResponder`` and
    ``cls._TCPResponder``.
    """
    print("Launching responders..")

    responderArgs = (cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue)

    # Thread.setDaemon() is deprecated since Python 3.10. Assigning the
    # ``daemon`` attribute is equivalent and also works on Python 2.6+.
    cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=responderArgs)
    cls._UDPResponder.daemon = True
    cls._UDPResponder.start()
    cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=responderArgs)
    cls._TCPResponder.daemon = True
    cls._TCPResponder.start()
# NOTE(review): several lines of this method are missing from the excerpt
# (gaps in the embedded original line numbers: 76, 80, 83, 86, 92-93, 96, 98,
# 100-101, 103-105, 107), so the comments below describe only the visible
# statements.
# Write the dnsdist configuration rendered from _config_template, validate it
# with --check-config, then launch dnsdist as a subprocess stored in
# cls._dnsdist. If dnsdist has already exited, the test run is aborted with
# its exit code.
# shutUp: presumably selects the launch variant whose stdout goes to
# os.devnull -- TODO confirm (the `if`/`else` lines are not visible).
72 def startDNSDist(cls
, shutUp
=True):
73 print("Launching dnsdist..")
74 conffile
= 'dnsdist_test.conf'
# Values substituted into _config_template, in _config_params order.
75 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
# Write the generated configuration file.
77 with
open(conffile
, 'w') as conf
:
78 conf
.write("-- Autogenerated by dnsdisttests.py\n")
79 conf
.write(cls
._config
_template
% params
)
# dnsdist command line: the binary from $DNSDISTBIN, the config file and the
# listen address, then an '--acl' pair per ACL entry. The loop header
# (original line 83) is missing; presumably it iterates cls._acl.
81 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '-C', conffile
,
82 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
84 dnsdistcmd
.extend(['--acl', acl
])
85 print(' '.join(dnsdistcmd
))
87 # validate config with --check-config, which sets client=true, possibly exposing bugs.
88 testcmd
= dnsdistcmd
+ ['--check-config']
89 output
= subprocess
.check_output(testcmd
, close_fds
=True)
# Any output other than the exact OK line fails the setup.
90 if output
!= b
'Configuration \'dnsdist_test.conf\' OK!\n':
91 raise AssertionError('dnsdist --check-config failed: %s' % output
)
# Launch dnsdist either with stdout sent to os.devnull or (line 97) with the
# inherited stdout. The branch choosing between them is not visible here.
94 with
open(os
.devnull
, 'w') as fdDevNull
:
95 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdDevNull
)
97 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True)
# Choose how long to wait for dnsdist to start. Only the default
# (_dnsdistStartupDelay) assignment is visible in this excerpt.
99 if 'DNSDIST_FAST_TESTS' in os
.environ
:
102 delay
= cls
._dnsdistStartupDelay
# A non-None poll() means dnsdist already exited: abort with its exit code.
106 if cls
._dnsdist
.poll() is not None:
108 sys
.exit(cls
._dnsdist
.returncode
)
def setUpSockets(cls):
    """Create the UDP client socket the tests use to talk to dnsdist.

    The socket is stored on ``cls._sock`` and given a 2 second timeout. It is
    then connected to 127.0.0.1 on ``cls._dnsDistPort``.
    """
    print("Setting up UDP socket..")
    udpSocket = cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    udpSocket.settimeout(2.0)
    dnsdistAddress = ("127.0.0.1", cls._dnsDistPort)
    udpSocket.connect(dnsdistAddress)
# NOTE(review): the enclosing method's header (and presumably a decorator and
# other statements, original lines 117-119 and 122-123) is not in this
# excerpt; only the body statements below are visible.
# Start the backend responder threads, then launch dnsdist with
# shutUp=cls._shutUp, then announce that the tests are starting.
120 cls
.startResponders()
121 cls
.startDNSDist(cls
._shutUp
)
124 print("Launching tests..")
# NOTE(review): original lines 129-132, 135 and 137+ are missing from this
# excerpt, so the branch bodies below are incomplete.
# Stop the dnsdist subprocess started by startDNSDist: terminate() it, then use
# poll() (None means still running) to check whether it has exited.
127 def tearDownClass(cls
):
128 if 'DNSDIST_FAST_TESTS' in os
.environ
:
133 cls
._dnsdist
.terminate()
134 if cls
._dnsdist
.poll() is None:
136 if cls
._dnsdist
.poll() is None:
141 def _ResponderIncrementCounter(cls
):
142 if threading
.currentThread().name
in cls
._responsesCounter
:
143 cls
._responsesCounter
[threading
.currentThread().name
] += 1
145 cls
._responsesCounter
[threading
.currentThread().name
] = 1
# NOTE(review): original lines 149, 152, 154, 156, 160, 164-166 and 171+ are
# missing from this excerpt, so the control flow cannot be read reliably.
# Build the answer a responder thread sends back for `request`:
#  - a query whose question count is not 1 is reported ("Skipping ...");
#  - a query whose name ends with cls._healthCheckName increments
#    cls._healthCheckCounter (presumably under `if healthCheck:`);
#  - when the test queued a response on `fromQueue`, that response is copied,
#    stamped with the request's ID, and the request is put on `toQueue` so the
#    test can inspect it;
#  - with cls._answerUnexpected set, a SERVFAIL answer is built instead.
# request: parsed query; fromQueue: responses supplied by the tests;
# toQueue: where received queries are reported back to the tests.
148 def _getResponse(cls
, request
, fromQueue
, toQueue
):
150 if len(request
.question
) != 1:
151 print("Skipping query with question count %d" % (len(request
.question
)))
153 healthCheck
= str(request
.question
[0].name
).endswith(cls
._healthCheckName
)
155 cls
._healthCheckCounter
+= 1
157 cls
._ResponderIncrementCounter
()
158 if not fromQueue
.empty():
159 response
= fromQueue
.get(True, cls
._queueTimeout
)
# Copy so the queued object itself is not mutated, then reuse the query ID.
161 response
= copy
.copy(response
)
162 response
.id = request
.id
# Report the query back to the test.
163 toQueue
.put(request
, True, cls
._queueTimeout
)
# Context lines 164-166 are missing; presumably a fallback answer built
# directly from the request -- TODO confirm.
167 response
= dns
.message
.make_response(request
)
168 elif cls
._answerUnexpected
:
169 response
= dns
.message
.make_response(request
)
170 response
.set_rcode(dns
.rcode
.SERVFAIL
)
# NOTE(review): original lines 179, 183-187 and 190+ are missing from this
# excerpt; presumably the receive/answer steps run inside a loop.
# UDP backend stand-in: bind 127.0.0.1:port with SO_REUSEPORT, receive a
# datagram, parse it (trailing data tolerated when ignoreTrailing is true),
# obtain the answer from cls._getResponse and send it back to the sender.
# fromQueue/toQueue are passed straight through to cls._getResponse.
175 def UDPResponder(cls
, port
, fromQueue
, toQueue
, ignoreTrailing
=False):
176 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
177 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
178 sock
.bind(("127.0.0.1", port
))
180 data
, addr
= sock
.recvfrom(4096)
181 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
182 response
= cls
._getResponse
(request
, fromQueue
, toQueue
)
188 sock
.sendto(response
.to_wire(), addr
)
189 sock
.settimeout(None)
# NOTE(review): many original lines are missing from this excerpt, including
# the `try:` statements the visible `except` clauses belong to, the accept
# loop, the read of the 2-byte length prefix and the sends of the payloads.
# The block is not executable as shown.
# TCP backend stand-in: bind 127.0.0.1:port with SO_REUSEPORT, accept a
# connection, and read a query framed by a 2-byte big-endian length ("!H").
# Answer it via cls._getResponse. With multipleResponses, keep sending
# further responses taken from fromQueue, each copied and given the
# request's ID. Socket errors on the connection are tolerated because some
# tests close it on purpose.
193 def TCPResponder(cls
, port
, fromQueue
, toQueue
, ignoreTrailing
=False, multipleResponses
=False):
194 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
195 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
197 sock
.bind(("127.0.0.1", port
))
198 except socket
.error
as e
:
199 print("Error binding in the TCP responder: %s" % str(e
))
204 (conn
, _
) = sock
.accept()
# `data` here presumably holds the 2-byte length prefix read just before.
211 (datalen
,) = struct
.unpack("!H", data
)
212 data
= conn
.recv(datalen
)
213 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
214 response
= cls
._getResponse
(request
, fromQueue
, toQueue
)
220 wire
= response
.to_wire()
221 conn
.send(struct
.pack("!H", len(wire
)))
# Extra responses for tests that expect several answers to one query.
224 while multipleResponses
:
225 if fromQueue
.empty():
228 response
= fromQueue
.get(True, cls
._queueTimeout
)
232 response
= copy
.copy(response
)
233 response
.id = request
.id
234 wire
= response
.to_wire()
236 conn
.send(struct
.pack("!H", len(wire
)))
238 except socket
.error
as e
:
239 # some of the tests are going to close
240 # the connection on us, just deal with it
def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
    """Send a query to dnsdist over the class UDP socket and collect the results.

    :param query: a message object with ``to_wire()``, or raw wire bytes when
        ``rawQuery`` is true.
    :param response: response for the responder thread to return; queued on
        ``cls._toResponderQueue`` only when ``useQueue`` is true.
    :param useQueue: whether the backend responder takes part (queue
        ``response``, then read back the query it received).
    :param timeout: socket timeout for this exchange and the queue waits.
    :param rawQuery: send ``query`` as-is instead of serialising it.
    :returns: ``(receivedQuery, message)``. ``receivedQuery`` is the query the
        responder saw, or None. ``message`` is the parsed answer from dnsdist,
        or None when nothing arrived before the timeout.
    """
    if useQueue:
        cls._toResponderQueue.put(response, True, timeout)

    # Apply the per-call timeout, then restore the previously configured one
    # (setUpSockets uses 2.0) rather than leaving the socket fully blocking.
    previousTimeout = cls._sock.gettimeout()
    cls._sock.settimeout(timeout)
    try:
        wire = query if rawQuery else query.to_wire()
        cls._sock.send(wire)
        data = cls._sock.recv(4096)
    except socket.timeout:
        data = None
    finally:
        cls._sock.settimeout(previousTimeout)

    receivedQuery = None
    if useQueue and not cls._fromResponderQueue.empty():
        receivedQuery = cls._fromResponderQueue.get(True, timeout)

    message = dns.message.from_wire(data) if data else None
    return (receivedQuery, message)
def openTCPConnection(cls, timeout=None):
    """Open a TCP connection to dnsdist on 127.0.0.1:``cls._dnsDistPort``.

    :param timeout: socket timeout in seconds, or None for blocking mode.
    :returns: the connected socket; the caller is responsible for closing it.
    :raises socket.error: if the connection cannot be established (the socket
        is closed before the error propagates).
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(timeout)

    try:
        sock.connect(("127.0.0.1", cls._dnsDistPort))
    except Exception:
        # Don't leak the descriptor when the connection fails.
        sock.close()
        raise
    return sock
def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
    """Open a TLS connection to 127.0.0.1:``port``.

    :param port: TCP port to connect to.
    :param serverName: host name passed as ``server_hostname`` (SNI and
        certificate name check) when ``ssl.create_default_context`` exists.
    :param caCert: path of a CA bundle used to verify the server, or None to
        use the default trust store.
    :param timeout: socket timeout in seconds, or None for blocking mode.
    :returns: the connected TLS socket; the caller must close it.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(timeout)

    sslsock = None
    try:
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            # Only for interpreters predating create_default_context. The
            # module-level ssl.wrap_socket was removed in Python 3.12, so it
            # must not run unconditionally.
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
        sslsock.connect(("127.0.0.1", port))
    except Exception:
        # Don't leak the descriptor when wrapping, connecting or the
        # handshake fails.
        if sslsock is not None:
            sslsock.close()
        sock.close()
        raise
    return sslsock
def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
    """Send one DNS query over an already-connected TCP socket.

    The message is framed by a 2-byte big-endian length prefix followed by
    the payload.

    :param sock: connected stream socket.
    :param query: a message object with ``to_wire()``, or raw wire bytes when
        ``rawQuery`` is true.
    :param rawQuery: send ``query`` as-is instead of serialising it.
    :param response: if not None, queued on ``cls._toResponderQueue`` for the
        responder thread to return.
    :param timeout: how long to block while queueing ``response``.
    """
    if rawQuery:
        wire = query
    else:
        wire = query.to_wire()

    # Queueing None would make the responder try to use None as a response;
    # only queue something when the caller supplied it.
    if response is not None:
        cls._toResponderQueue.put(response, True, timeout)

    # sendall() keeps writing until every byte is out, unlike send(). The
    # payload must follow its length prefix.
    sock.sendall(struct.pack("!H", len(wire)))
    sock.sendall(wire)
def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
    """Read one length-prefixed DNS message from a connected TCP socket.

    :param sock: connected stream socket.
    :param useQueue: also fetch the query the responder reported on
        ``cls._fromResponderQueue``, if one is waiting.
    :param timeout: how long to block while reading that queue.
    :returns: the parsed message, or None if the peer closed the connection
        before a complete, non-empty message arrived. When ``useQueue`` is
        true and a query was queued, a ``(receivedQuery, message)`` tuple.
    """
    def recvExactly(count):
        # recv() may return fewer bytes than requested; keep reading until
        # `count` bytes arrived, or return None if the peer closed first.
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    message = None
    header = recvExactly(2)
    if header is not None:
        (datalen,) = struct.unpack("!H", header)
        data = recvExactly(datalen)
        if data:
            message = dns.message.from_wire(data)

    if useQueue and not cls._fromResponderQueue.empty():
        receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, message)
    return message
def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
    """Send a single query to dnsdist over a fresh TCP connection.

    :param query: a message object with ``to_wire()``, or raw wire bytes when
        ``rawQuery`` is true.
    :param response: response for the responder thread to return; queued on
        ``cls._toResponderQueue`` only when ``useQueue`` is true.
    :param useQueue: whether to queue ``response`` and read back the query
        the responder received.
    :param timeout: connection timeout and queue wait, in seconds.
    :param rawQuery: send ``query`` as-is instead of serialising it.
    :returns: ``(receivedQuery, message)``; either element is None when it
        did not arrive. Timeouts and socket errors while sending or receiving
        are printed rather than raised.
    """
    if useQueue:
        cls._toResponderQueue.put(response, True, timeout)

    sock = cls.openTCPConnection(timeout)

    # Initialised up front so a timeout or network error still yields a value
    # instead of an unbound local.
    message = None
    try:
        cls.sendTCPQueryOverConnection(sock, query, rawQuery)
        message = cls.recvTCPResponseOverConnection(sock)
    except socket.timeout as e:
        print("Timeout: %s" % (str(e)))
    except socket.error as e:
        print("Network error: %s" % (str(e)))
    finally:
        # Always release the connection, whatever happened above.
        sock.close()

    receivedQuery = None
    if useQueue and not cls._fromResponderQueue.empty():
        receivedQuery = cls._fromResponderQueue.get(True, timeout)

    return (receivedQuery, message)
# NOTE(review): many original lines are missing from this excerpt (354, 358,
# 360, 362-365, 367-369, 371-375, 379, 384-387), including the `try:` that
# the `except` clauses belong to and the receive loop header, so the block is
# incomplete as shown.
# Queue every entry of `responses` for the responder, then open a TCP
# connection to dnsdist and send `query` framed by a 2-byte big-endian
# length. Each length-prefixed message received is parsed and appended to
# `messages`. Returns (receivedQuery, messages), where receivedQuery is the
# query the responder reported, if any was queued. Timeouts and socket errors
# are printed rather than raised.
353 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
355 for response
in responses
:
356 cls
._toResponderQueue
.put(response
, True, timeout
)
357 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
359 sock
.settimeout(timeout
)
361 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
366 wire
= query
.to_wire()
370 sock
.send(struct
.pack("!H", len(wire
)))
# `data` here presumably holds the 2-byte length prefix just read.
376 (datalen
,) = struct
.unpack("!H", data
)
377 data
= sock
.recv(datalen
)
378 messages
.append(dns
.message
.from_wire(data
))
380 except socket
.timeout
as e
:
381 print("Timeout: %s" % (str(e
)))
382 except socket
.error
as e
:
383 print("Network error: %s" % (str(e
)))
388 if useQueue
and not cls
._fromResponderQueue
.empty():
389 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
390 return (receivedQuery
, messages
)
def setUp(self):
    """Reset shared per-test state; called before every test.

    Zeroes every entry of ``_responsesCounter`` and resets the health-check
    counter. Also drains both responder queues, in case a previous test
    failed and left items behind.
    """
    try:
        from queue import Empty
    except ImportError:  # Python 2
        from Queue import Empty

    # Clear the responses counters (the dict is mutated in place).
    for key in self._responsesCounter:
        self._responsesCounter[key] = 0

    # The responders increment the *class* attribute
    # (`cls._healthCheckCounter += 1`). Assigning through `self` would only
    # create a shadowing instance attribute and never reset the real counter.
    type(self)._healthCheckCounter = 0

    # Make sure the queues are empty, in case a previous test failed. Use
    # non-blocking get() until Empty rather than checking empty() first:
    # responder threads consume _toResponderQueue concurrently, so an item
    # seen by empty() may be gone before get(False) runs.
    for pending in (self._toResponderQueue, self._fromResponderQueue):
        while True:
            try:
                pending.get(False)
            except Empty:
                break
def clearToResponderQueue(cls):
    """Discard every response still waiting in ``cls._toResponderQueue``."""
    try:
        from queue import Empty
    except ImportError:  # Python 2
        from Queue import Empty

    # Drain with non-blocking get() until it raises Empty. Checking empty()
    # first is racy: a responder thread consuming this queue may take the
    # last item between the check and get(False), which would then raise.
    while True:
        try:
            cls._toResponderQueue.get(False)
        except Empty:
            break
def clearFromResponderQueue(cls):
    """Discard every query still waiting in ``cls._fromResponderQueue``."""
    try:
        from queue import Empty
    except ImportError:  # Python 2
        from Queue import Empty

    # Same drain pattern as for the to-responder queue: non-blocking get()
    # until Empty, so a concurrent consumer can never make get(False) raise
    # after an empty() check said there was an item.
    while True:
        try:
            cls._fromResponderQueue.get(False)
        except Empty:
            break
def clearResponderQueues(cls):
    """Drain the to-responder queue, then the from-responder queue."""
    for drain in (cls.clearToResponderQueue, cls.clearFromResponderQueue):
        drain()
def generateConsoleKey():
    """Return a freshly generated console key from ``libnacl.utils.salsa_key()``."""
    freshKey = libnacl.utils.salsa_key()
    return freshKey
429 def _encryptConsole(cls
, command
, nonce
):
430 command
= command
.encode('UTF-8')
431 if cls
._consoleKey
is None:
433 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
436 def _decryptConsole(cls
, command
, nonce
):
437 if cls
._consoleKey
is None:
440 result
= libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
441 return result
.decode('UTF-8')
# NOTE(review): original lines 446, 448, 450, 452, 458-459, 465-467, 469 and
# 473+ are missing from this excerpt. The nonce send, the message send and
# the 4-byte size read presumably lived there, so the protocol steps are not
# complete as shown.
# Send `command` to the dnsdist console at 127.0.0.1:cls._consolePort and
# decrypt the reply:
#  - generate our nonce and read the server's nonce of the same length;
#  - build the reading nonce from our first half + their second half, and
#    the writing nonce from their first half + our second half;
#  - send the encrypted command prefixed by its 4-byte big-endian ("!I")
#    length;
#  - read a 4-byte response length, then decrypt that many bytes.
# Raises socket.error on EOF while reading the nonce or the response size.
444 def sendConsoleCommand(cls
, command
, timeout
=1.0):
445 ourNonce
= libnacl
.utils
.rand_nonce()
447 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
449 sock
.settimeout(timeout
)
451 sock
.connect(("127.0.0.1", cls
._consolePort
))
453 theirNonce
= sock
.recv(len(ourNonce
))
454 if len(theirNonce
) != len(ourNonce
):
455 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
456 if len(theirNonce
) == 0:
457 raise socket
.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce
)))
# Each side contributes one half of both nonces.
460 halfNonceSize
= int(len(ourNonce
) / 2)
461 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
462 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
463 msg
= cls
._encryptConsole
(command
, writingNonce
)
464 sock
.send(struct
.pack("!I", len(msg
)))
468 raise socket
.error("Got EOF while reading the response size")
470 (responseLen
,) = struct
.unpack("!I", data
)
471 data
= sock
.recv(responseLen
)
472 response
= cls
._decryptConsole
(data
, readingNonce
)
def compareOptions(self, a, b):
    """Assert that option sequences ``a`` and ``b`` match element by element.

    The lengths are checked first, so the pairwise walk covers every element.
    """
    # assertEquals was a deprecated alias of assertEqual, removed in Python 3.12.
    self.assertEqual(len(a), len(b))
    for left, right in zip(a, b):
        self.assertEqual(left, right)
def checkMessageNoEDNS(self, expected, received):
    """Assert ``received`` equals ``expected`` and carries no EDNS.

    ``received.edns`` must be -1 (dnspython's "EDNS disabled" value) and the
    options list must be empty.
    """
    # assertEquals was a deprecated alias of assertEqual, removed in Python 3.12.
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, -1)
    self.assertEqual(len(received.options), 0)
def checkMessageEDNSWithoutOptions(self, expected, received):
    """Assert ``received`` equals ``expected`` and uses EDNS version 0.

    Note: despite the name, the options list itself is not inspected here.
    """
    # assertEquals was a deprecated alias of assertEqual, removed in Python 3.12.
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
    """Assert ``received`` equals ``expected``, uses EDNS 0 and has no ECS option.

    :param withCookies: number of EDNS options expected. Every option present
        must have option code 10 (COOKIE), so any ECS option fails the check.
    """
    # assertEquals was a deprecated alias of assertEqual, removed in Python 3.12.
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(len(received.options), withCookies)

    for option in received.options:
        self.assertEqual(option.otype, 10)
def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
    """Assert ``received`` equals ``expected``, uses EDNS 0 and carries an ECS option.

    :param additionalOptions: number of options expected besides the ECS one.
        Any non-ECS option is rejected when this is 0. The options must also
        match ``expected.options`` element by element.
    """
    # assertEquals/assertNotEquals were deprecated aliases removed in Python 3.12.
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(len(received.options), 1 + additionalOptions)

    # Initialised here so the final assertion never hits an unbound name.
    hasECS = False
    for option in received.options:
        if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
            hasECS = True
        else:
            # A non-ECS option is only acceptable if extra options are expected.
            self.assertNotEqual(additionalOptions, 0)

    self.compareOptions(expected.options, received.options)
    self.assertTrue(hasECS)
def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
    """Check a query with checkMessageEDNSWithECS (EDNS 0 plus an ECS option)."""
    checker = self.checkMessageEDNSWithECS
    checker(expected, received, additionalOptions)
def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
    """Check a response with checkMessageEDNSWithECS (EDNS 0 plus an ECS option)."""
    forwarded = (expected, received, additionalOptions)
    self.checkMessageEDNSWithECS(*forwarded)
def checkQueryEDNSWithoutECS(self, expected, received):
    """Check a query with checkMessageEDNSWithoutECS, using its default options count."""
    verify = self.checkMessageEDNSWithoutECS
    verify(expected, received)
def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
    """Check a response with checkMessageEDNSWithoutECS, allowing `withCookies` options."""
    forwarded = (expected, received, withCookies)
    self.checkMessageEDNSWithoutECS(*forwarded)
def checkQueryNoEDNS(self, expected, received):
    """Check a query with checkMessageNoEDNS (no EDNS at all)."""
    verify = self.checkMessageNoEDNS
    verify(expected, received)
def checkResponseNoEDNS(self, expected, received):
    """Check a response with checkMessageNoEDNS (no EDNS at all)."""
    forwarded = (expected, received)
    self.checkMessageNoEDNS(*forwarded)