13 import clientsubnetoption
19 # Python2/3 compatibility hacks
21 from queue
import Queue
23 from Queue
import Queue
31 class DNSDistTest(unittest
.TestCase
):
33 Set up a dnsdist instance and responder threads.
34 Queries sent to dnsdist are relayed to the responder threads,
35 who reply with the response provided by the tests themselves
36 on a queue. Responder threads also queue the queries received
37 from dnsdist on a separate queue, allowing the tests to check
38 that the queries sent from dnsdist were as expected.
41 _dnsDistListeningAddr
= "127.0.0.1"
42 _testServerPort
= 5350
43 _toResponderQueue
= Queue()
44 _fromResponderQueue
= Queue()
46 _dnsdistStartupDelay
= 2.0
48 _responsesCounter
= {}
49 _config_template
= """
51 _config_params
= ['_testServerPort']
52 _acl
= ['127.0.0.1/32']
55 _healthCheckName
= 'a.root-servers.net.'
56 _healthCheckCounter
= 0
57 _answerUnexpected
= True
60 def startResponders(cls
):
61 print("Launching responders..")
63 cls
._UDPResponder
= threading
.Thread(name
='UDP Responder', target
=cls
.UDPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
64 cls
._UDPResponder
.setDaemon(True)
65 cls
._UDPResponder
.start()
66 cls
._TCPResponder
= threading
.Thread(name
='TCP Responder', target
=cls
.TCPResponder
, args
=[cls
._testServerPort
, cls
._toResponderQueue
, cls
._fromResponderQueue
])
67 cls
._TCPResponder
.setDaemon(True)
68 cls
._TCPResponder
.start()
71 def startDNSDist(cls
):
72 print("Launching dnsdist..")
73 confFile
= os
.path
.join('configs', 'dnsdist_%s.conf' % (cls
.__name
__))
74 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
76 with
open(confFile
, 'w') as conf
:
77 conf
.write("-- Autogenerated by dnsdisttests.py\n")
78 conf
.write(cls
._config
_template
% params
)
80 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '--supervised', '-C', confFile
,
81 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
83 dnsdistcmd
.extend(['--acl', acl
])
84 print(' '.join(dnsdistcmd
))
86 # validate config with --check-config, which sets client=true, possibly exposing bugs.
87 testcmd
= dnsdistcmd
+ ['--check-config']
89 output
= subprocess
.check_output(testcmd
, stderr
=subprocess
.STDOUT
, close_fds
=True)
90 except subprocess
.CalledProcessError
as exc
:
91 raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc
.returncode
, exc
.output
))
92 expectedOutput
= ('Configuration \'%s\' OK!\n' % (confFile
)).encode()
93 if output
!= expectedOutput
:
94 raise AssertionError('dnsdist --check-config failed: %s' % output
)
96 logFile
= os
.path
.join('configs', 'dnsdist_%s.log' % (cls
.__name
__))
97 with
open(logFile
, 'w') as fdLog
:
98 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdLog
, stderr
=fdLog
)
100 if 'DNSDIST_FAST_TESTS' in os
.environ
:
103 delay
= cls
._dnsdistStartupDelay
107 if cls
._dnsdist
.poll() is not None:
109 sys
.exit(cls
._dnsdist
.returncode
)
112 def setUpSockets(cls
):
113 print("Setting up UDP socket..")
114 cls
._sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
115 cls
._sock
.settimeout(2.0)
116 cls
._sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
121 cls
.startResponders()
125 print("Launching tests..")
128 def tearDownClass(cls
):
129 if 'DNSDIST_FAST_TESTS' in os
.environ
:
134 cls
._dnsdist
.terminate()
135 if cls
._dnsdist
.poll() is None:
137 if cls
._dnsdist
.poll() is None:
142 def _ResponderIncrementCounter(cls
):
143 if threading
.currentThread().name
in cls
._responsesCounter
:
144 cls
._responsesCounter
[threading
.currentThread().name
] += 1
146 cls
._responsesCounter
[threading
.currentThread().name
] = 1
149 def _getResponse(cls
, request
, fromQueue
, toQueue
, synthesize
=None):
151 if len(request
.question
) != 1:
152 print("Skipping query with question count %d" % (len(request
.question
)))
154 healthCheck
= str(request
.question
[0].name
).endswith(cls
._healthCheckName
)
156 cls
._healthCheckCounter
+= 1
157 response
= dns
.message
.make_response(request
)
159 cls
._ResponderIncrementCounter
()
160 if not fromQueue
.empty():
161 toQueue
.put(request
, True, cls
._queueTimeout
)
162 if synthesize
is None:
163 response
= fromQueue
.get(True, cls
._queueTimeout
)
165 response
= copy
.copy(response
)
166 response
.id = request
.id
169 if synthesize
is not None:
170 response
= dns
.message
.make_response(request
)
171 response
.set_rcode(synthesize
)
172 elif cls
._answerUnexpected
:
173 response
= dns
.message
.make_response(request
)
174 response
.set_rcode(dns
.rcode
.SERVFAIL
)
179 def UDPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, callback
=None):
180 # trailingDataResponse=True means "ignore trailing data".
181 # Other values are either False (meaning "raise an exception")
182 # or are interpreted as a response RCODE for queries with trailing data.
183 # callback is invoked for every -even healthcheck ones- query and should return a raw response
184 ignoreTrailing
= trailingDataResponse
is True
186 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
187 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
188 sock
.bind(("127.0.0.1", port
))
190 data
, addr
= sock
.recvfrom(4096)
193 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
194 except dns
.message
.TrailingJunk
as e
:
195 if trailingDataResponse
is False or forceRcode
is True:
197 print("UDP query with trailing data, synthesizing response")
198 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
199 forceRcode
= trailingDataResponse
203 wire
= callback(request
)
205 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
207 wire
= response
.to_wire()
213 sock
.sendto(wire
, addr
)
214 sock
.settimeout(None)
218 def TCPResponder(cls
, port
, fromQueue
, toQueue
, trailingDataResponse
=False, multipleResponses
=False, callback
=None):
219 # trailingDataResponse=True means "ignore trailing data".
220 # Other values are either False (meaning "raise an exception")
221 # or are interpreted as a response RCODE for queries with trailing data.
222 # callback is invoked for every -even healthcheck ones- query and should return a raw response
223 ignoreTrailing
= trailingDataResponse
is True
225 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
226 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
227 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
229 sock
.bind(("127.0.0.1", port
))
230 except socket
.error
as e
:
231 print("Error binding in the TCP responder: %s" % str(e
))
236 (conn
, _
) = sock
.accept()
243 (datalen
,) = struct
.unpack("!H", data
)
244 data
= conn
.recv(datalen
)
247 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
248 except dns
.message
.TrailingJunk
as e
:
249 if trailingDataResponse
is False or forceRcode
is True:
251 print("TCP query with trailing data, synthesizing response")
252 request
= dns
.message
.from_wire(data
, ignore_trailing
=True)
253 forceRcode
= trailingDataResponse
256 wire
= callback(request
)
258 response
= cls
._getResponse
(request
, fromQueue
, toQueue
, synthesize
=forceRcode
)
260 wire
= response
.to_wire(max_size
=65535)
266 conn
.send(struct
.pack("!H", len(wire
)))
269 while multipleResponses
:
270 if fromQueue
.empty():
273 response
= fromQueue
.get(True, cls
._queueTimeout
)
277 response
= copy
.copy(response
)
278 response
.id = request
.id
279 wire
= response
.to_wire(max_size
=65535)
281 conn
.send(struct
.pack("!H", len(wire
)))
283 except socket
.error
as e
:
284 # some of the tests are going to close
285 # the connection on us, just deal with it
293 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
295 cls
._toResponderQueue
.put(response
, True, timeout
)
298 cls
._sock
.settimeout(timeout
)
302 query
= query
.to_wire()
303 cls
._sock
.send(query
)
304 data
= cls
._sock
.recv(4096)
305 except socket
.timeout
:
309 cls
._sock
.settimeout(None)
313 if useQueue
and not cls
._fromResponderQueue
.empty():
314 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
316 message
= dns
.message
.from_wire(data
)
317 return (receivedQuery
, message
)
320 def openTCPConnection(cls
, timeout
=None):
321 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
322 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
324 sock
.settimeout(timeout
)
326 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
330 def openTLSConnection(cls
, port
, serverName
, caCert
=None, timeout
=None):
331 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
332 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
334 sock
.settimeout(timeout
)
337 if hasattr(ssl
, 'create_default_context'):
338 sslctx
= ssl
.create_default_context(cafile
=caCert
)
339 sslsock
= sslctx
.wrap_socket(sock
, server_hostname
=serverName
)
341 sslsock
= ssl
.wrap_socket(sock
, ca_certs
=caCert
, cert_reqs
=ssl
.CERT_REQUIRED
)
343 sslsock
.connect(("127.0.0.1", port
))
347 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False, response
=None, timeout
=2.0):
349 wire
= query
.to_wire()
354 cls
._toResponderQueue
.put(response
, True, timeout
)
356 sock
.send(struct
.pack("!H", len(wire
)))
360 def recvTCPResponseOverConnection(cls
, sock
, useQueue
=False, timeout
=2.0):
364 (datalen
,) = struct
.unpack("!H", data
)
365 data
= sock
.recv(datalen
)
367 message
= dns
.message
.from_wire(data
)
369 if useQueue
and not cls
._fromResponderQueue
.empty():
370 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
371 return (receivedQuery
, message
)
376 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
379 cls
._toResponderQueue
.put(response
, True, timeout
)
381 sock
= cls
.openTCPConnection(timeout
)
384 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
)
385 message
= cls
.recvTCPResponseOverConnection(sock
)
386 except socket
.timeout
as e
:
387 print("Timeout: %s" % (str(e
)))
388 except socket
.error
as e
:
389 print("Network error: %s" % (str(e
)))
394 if useQueue
and not cls
._fromResponderQueue
.empty():
395 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
397 return (receivedQuery
, message
)
400 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
402 for response
in responses
:
403 cls
._toResponderQueue
.put(response
, True, timeout
)
404 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
405 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
407 sock
.settimeout(timeout
)
409 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
414 wire
= query
.to_wire()
418 sock
.send(struct
.pack("!H", len(wire
)))
424 (datalen
,) = struct
.unpack("!H", data
)
425 data
= sock
.recv(datalen
)
426 messages
.append(dns
.message
.from_wire(data
))
428 except socket
.timeout
as e
:
429 print("Timeout: %s" % (str(e
)))
430 except socket
.error
as e
:
431 print("Network error: %s" % (str(e
)))
436 if useQueue
and not cls
._fromResponderQueue
.empty():
437 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
438 return (receivedQuery
, messages
)
441 # This function is called before every tests
443 # Clear the responses counters
444 for key
in self
._responsesCounter
:
445 self
._responsesCounter
[key
] = 0
447 self
._healthCheckCounter
= 0
449 # Make sure the queues are empty, in case
450 # a previous test failed
451 while not self
._toResponderQueue
.empty():
452 self
._toResponderQueue
.get(False)
454 while not self
._fromResponderQueue
.empty():
455 self
._fromResponderQueue
.get(False)
458 def clearToResponderQueue(cls
):
459 while not cls
._toResponderQueue
.empty():
460 cls
._toResponderQueue
.get(False)
463 def clearFromResponderQueue(cls
):
464 while not cls
._fromResponderQueue
.empty():
465 cls
._fromResponderQueue
.get(False)
468 def clearResponderQueues(cls
):
469 cls
.clearToResponderQueue()
470 cls
.clearFromResponderQueue()
473 def generateConsoleKey():
474 return libnacl
.utils
.salsa_key()
477 def _encryptConsole(cls
, command
, nonce
):
478 command
= command
.encode('UTF-8')
479 if cls
._consoleKey
is None:
481 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
484 def _decryptConsole(cls
, command
, nonce
):
485 if cls
._consoleKey
is None:
488 result
= libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
489 return result
.decode('UTF-8')
492 def sendConsoleCommand(cls
, command
, timeout
=1.0):
493 ourNonce
= libnacl
.utils
.rand_nonce()
495 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
496 sock
.setsockopt(socket
.IPPROTO_TCP
, socket
.TCP_NODELAY
, 1)
498 sock
.settimeout(timeout
)
500 sock
.connect(("127.0.0.1", cls
._consolePort
))
502 theirNonce
= sock
.recv(len(ourNonce
))
503 if len(theirNonce
) != len(ourNonce
):
504 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
505 if len(theirNonce
) == 0:
506 raise socket
.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce
)))
509 halfNonceSize
= int(len(ourNonce
) / 2)
510 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
511 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
512 msg
= cls
._encryptConsole
(command
, writingNonce
)
513 sock
.send(struct
.pack("!I", len(msg
)))
517 raise socket
.error("Got EOF while reading the response size")
519 (responseLen
,) = struct
.unpack("!I", data
)
520 data
= sock
.recv(responseLen
)
521 response
= cls
._decryptConsole
(data
, readingNonce
)
524 def compareOptions(self
, a
, b
):
525 self
.assertEquals(len(a
), len(b
))
526 for idx
in range(len(a
)):
527 self
.assertEquals(a
[idx
], b
[idx
])
529 def checkMessageNoEDNS(self
, expected
, received
):
530 self
.assertEquals(expected
, received
)
531 self
.assertEquals(received
.edns
, -1)
532 self
.assertEquals(len(received
.options
), 0)
534 def checkMessageEDNSWithoutOptions(self
, expected
, received
):
535 self
.assertEquals(expected
, received
)
536 self
.assertEquals(received
.edns
, 0)
537 self
.assertEquals(expected
.payload
, received
.payload
)
539 def checkMessageEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
540 self
.assertEquals(expected
, received
)
541 self
.assertEquals(received
.edns
, 0)
542 self
.assertEquals(expected
.payload
, received
.payload
)
543 self
.assertEquals(len(received
.options
), withCookies
)
545 for option
in received
.options
:
546 self
.assertEquals(option
.otype
, 10)
548 def checkMessageEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
549 self
.assertEquals(expected
, received
)
550 self
.assertEquals(received
.edns
, 0)
551 self
.assertEquals(expected
.payload
, received
.payload
)
552 self
.assertEquals(len(received
.options
), 1 + additionalOptions
)
554 for option
in received
.options
:
555 if option
.otype
== clientsubnetoption
.ASSIGNED_OPTION_CODE
:
558 self
.assertNotEquals(additionalOptions
, 0)
560 self
.compareOptions(expected
.options
, received
.options
)
561 self
.assertTrue(hasECS
)
563 def checkQueryEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
564 self
.checkMessageEDNSWithECS(expected
, received
, additionalOptions
)
566 def checkResponseEDNSWithECS(self
, expected
, received
, additionalOptions
=0):
567 self
.checkMessageEDNSWithECS(expected
, received
, additionalOptions
)
569 def checkQueryEDNSWithoutECS(self
, expected
, received
):
570 self
.checkMessageEDNSWithoutECS(expected
, received
)
572 def checkResponseEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
573 self
.checkMessageEDNSWithoutECS(expected
, received
, withCookies
)
575 def checkQueryNoEDNS(self
, expected
, received
):
576 self
.checkMessageNoEDNS(expected
, received
)
578 def checkResponseNoEDNS(self
, expected
, received
):
579 self
.checkMessageNoEDNS(expected
, received
)