import copy
import os
import socket
import ssl
import struct
import subprocess
import sys
import threading
import time
import unittest

import clientsubnetoption
import dns
import dns.message
import libnacl
import libnacl.utils

# Python2/3 compatibility hacks
try:
    from queue import Empty, Queue
except ImportError:
    from Queue import Empty, Queue
class DNSDistTest(unittest.TestCase):
    """
    Set up a dnsdist instance and responder threads.
    Queries sent to dnsdist are relayed to the responder threads,
    who reply with the response provided by the tests themselves
    on a queue. Responder threads also queue the queries received
    from dnsdist on a separate queue, allowing the tests to check
    that the queries sent from dnsdist were as expected.
    """
    # Port dnsdist listens on; the client UDP socket and TCP connections target it.
    # NOTE(review): value not present in the extracted source, restored from
    # the upstream harness -- confirm.
    _dnsDistPort = 5340
    _dnsDistListeningAddr = "127.0.0.1"
    # Port the UDP and TCP responder threads bind to (substituted into the config).
    _testServerPort = 5350
    # Responses queued by the tests, consumed by the responder threads.
    _toResponderQueue = Queue()
    # Queries received by the responder threads, consumed by the tests.
    _fromResponderQueue = Queue()
    # Seconds the responders block on the queues.
    # NOTE(review): value not present in the extracted source -- confirm.
    _queueTimeout = 1
    # Seconds to wait after launching dnsdist before checking it is still alive.
    _dnsdistStartupDelay = 2.0
    # dnsdist subprocess handle; None until startDNSDist() has run.
    _dnsdist = None
    # Non-health-check queries handled, keyed by responder thread name.
    # Deliberately a single dict shared by every subclass (mutated, never rebound).
    _responsesCounter = {}
    # dnsdist Lua configuration, %-formatted with the attributes named in _config_params.
    _config_template = """
    """
    _config_params = ['_testServerPort']
    # Netmasks passed to dnsdist with --acl.
    _acl = ['127.0.0.1/32']
    # Console endpoint used by sendConsoleCommand(); a None key means no encryption.
    # NOTE(review): port value not present in the extracted source -- confirm.
    _consolePort = 5199
    _consoleKey = None
    # Queries whose name ends with this are answered as health checks.
    _healthCheckName = 'a.root-servers.net.'
    _healthCheckCounter = 0
    # When True, queries with no queued response are answered with SERVFAIL.
    _answerUnexpected = True

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder threads on _testServerPort."""
        print("Launching responders..")

        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        # Daemon threads never block interpreter exit; the attribute replaces
        # the deprecated setDaemon().
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @classmethod
    def startDNSDist(cls):
        """Write, validate and launch the dnsdist configuration for this class.

        The configuration is _config_template formatted with the attributes
        named in _config_params and is written to configs/dnsdist_<Class>.conf.
        dnsdist output goes to configs/dnsdist_<Class>.log.

        Raises:
            AssertionError: if ``dnsdist --check-config`` fails or does not
                print the expected "OK" line.
            SystemExit: if dnsdist has already exited after the startup delay.
        """
        print("Launching dnsdist..")
        confFile = os.path.join('configs', 'dnsdist_%s.conf' % (cls.__name__))
        params = tuple([getattr(cls, param) for param in cls._config_params])
        with open(confFile, 'w') as conf:
            conf.write("-- Autogenerated by dnsdisttests.py\n")
            conf.write(cls._config_template % params)

        dnsdistcmd = [os.environ['DNSDISTBIN'], '-C', confFile,
                      '-l', '%s:%d' % (cls._dnsDistListeningAddr, cls._dnsDistPort)]
        for acl in cls._acl:
            dnsdistcmd.extend(['--acl', acl])
        print(' '.join(dnsdistcmd))

        # validate config with --check-config, which sets client=true, possibly exposing bugs.
        testcmd = dnsdistcmd + ['--check-config']
        try:
            output = subprocess.check_output(testcmd, stderr=subprocess.STDOUT, close_fds=True)
        except subprocess.CalledProcessError as exc:
            raise AssertionError('dnsdist --check-config failed (%d): %s' % (exc.returncode, exc.output))
        expectedOutput = ('Configuration \'%s\' OK!\n' % (confFile)).encode()
        if output != expectedOutput:
            raise AssertionError('dnsdist --check-config failed: %s' % output)

        logFile = os.path.join('configs', 'dnsdist_%s.log' % (cls.__name__))
        with open(logFile, 'w') as fdLog:
            cls._dnsdist = subprocess.Popen(dnsdistcmd, close_fds=True, stdout=fdLog, stderr=fdLog)

        if 'DNSDIST_FAST_TESTS' in os.environ:
            # NOTE(review): fast-path delay not present in the extracted source -- confirm.
            delay = 0.5
        else:
            delay = cls._dnsdistStartupDelay

        time.sleep(delay)

        if cls._dnsdist.poll() is not None:
            print("dnsdist exited early, see %s" % logFile)
            sys.exit(cls._dnsdist.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the client UDP socket connected to dnsdist (2 s timeout)."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect(("127.0.0.1", cls._dnsDistPort))

    @classmethod
    def setUpClass(cls):
        """Start the responders, dnsdist and the client socket before the tests run."""
        cls.startResponders()
        cls.startDNSDist()
        cls.setUpSockets()

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop dnsdist: terminate, then kill if still running after a grace delay."""
        # NOTE(review): grace delays not present in the extracted source -- confirm.
        if 'DNSDIST_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0
        if cls._dnsdist:
            cls._dnsdist.terminate()
            if cls._dnsdist.poll() is None:
                time.sleep(delay)
                if cls._dnsdist.poll() is None:
                    cls._dnsdist.kill()
            # Reap the process in every case so no zombie is left behind.
            cls._dnsdist.wait()

    @classmethod
    def _ResponderIncrementCounter(cls):
        """Count one handled query for the calling responder thread."""
        # current_thread() replaces the deprecated currentThread() alias.
        name = threading.current_thread().name
        cls._responsesCounter[name] = cls._responsesCounter.get(name, 0) + 1

    @staticmethod
    def _recvExactly(sock, length):
        """Read ``length`` bytes from ``sock``, looping over short reads.

        Stops early if the peer closes the connection, so the result can be
        shorter than ``length`` (b'' when EOF is hit immediately).
        """
        data = b''
        while len(data) < length:
            chunk = sock.recv(length - len(data))
            if not chunk:
                break
            data += chunk
        return data

    @classmethod
    def _getResponse(cls, request, fromQueue, toQueue, synthesize=None):
        """Build the response a responder thread should send for ``request``.

        Health-check queries (name ending in _healthCheckName) are counted and
        answered with an empty response. Other queries are counted per thread.
        When a test has queued a response on ``fromQueue``, the query is put on
        ``toQueue`` and the queued response is returned with its ID rewritten.
        ``synthesize``, when not None, is an RCODE to answer with instead of
        dequeuing a response.

        Returns:
            The response message, or None when nothing should be sent (the query
            does not have exactly one question, or no response is available and
            _answerUnexpected is false).
        """
        response = None
        if len(request.question) != 1:
            print("Skipping query with question count %d" % (len(request.question)))
            return None
        healthCheck = str(request.question[0].name).endswith(cls._healthCheckName)
        if healthCheck:
            cls._healthCheckCounter += 1
            response = dns.message.make_response(request)
        else:
            cls._ResponderIncrementCounter()
            if not fromQueue.empty():
                toQueue.put(request, True, cls._queueTimeout)
                if synthesize is None:
                    response = fromQueue.get(True, cls._queueTimeout)
                    if response:
                        # Copy so the message object queued by the test is not mutated.
                        response = copy.copy(response)
                        response.id = request.id

        if not response:
            if synthesize is not None:
                response = dns.message.make_response(request)
                response.set_rcode(synthesize)
            elif cls._answerUnexpected:
                response = dns.message.make_response(request)
                response.set_rcode(dns.rcode.SERVFAIL)

        return response

    @classmethod
    def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False):
        """Serve DNS over UDP on 127.0.0.1:``port`` forever (thread target)."""
        # trailingDataResponse=True means "ignore trailing data".
        # Other values are either False (meaning "raise an exception")
        # or are interpreted as a response RCODE for queries with trailing data.
        ignoreTrailing = trailingDataResponse is True

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("127.0.0.1", port))
        try:
            while True:
                data, addr = sock.recvfrom(4096)
                forceRcode = None
                try:
                    request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
                except dns.message.TrailingJunk:
                    if trailingDataResponse is False:
                        raise
                    print("UDP query with trailing data, synthesizing response")
                    request = dns.message.from_wire(data, ignore_trailing=True)
                    forceRcode = trailingDataResponse

                response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
                if not response:
                    continue

                # Bound the send, then go back to blocking reads.
                sock.settimeout(2.0)
                sock.sendto(response.to_wire(), addr)
                sock.settimeout(None)
        finally:
            sock.close()

    @classmethod
    def _handleTCPQuery(cls, conn, fromQueue, toQueue, trailingDataResponse, multipleResponses):
        """Read one length-prefixed query from ``conn`` and write the response(s).

        Raises dns.message.TrailingJunk when the query has trailing data and
        ``trailingDataResponse`` is False; socket errors propagate to the caller.
        """
        ignoreTrailing = trailingDataResponse is True
        data = cls._recvExactly(conn, 2)
        if not data:
            return

        (datalen,) = struct.unpack("!H", data)
        data = cls._recvExactly(conn, datalen)
        forceRcode = None
        try:
            request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
        except dns.message.TrailingJunk:
            if trailingDataResponse is False:
                raise
            print("TCP query with trailing data, synthesizing response")
            request = dns.message.from_wire(data, ignore_trailing=True)
            forceRcode = trailingDataResponse

        response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
        if not response:
            return

        wire = response.to_wire()
        conn.sendall(struct.pack("!H", len(wire)))
        conn.sendall(wire)

        # Optionally stream every further queued response on the same connection.
        while multipleResponses:
            if fromQueue.empty():
                break

            response = fromQueue.get(True, cls._queueTimeout)
            if not response:
                break

            response = copy.copy(response)
            response.id = request.id
            wire = response.to_wire()
            try:
                conn.sendall(struct.pack("!H", len(wire)))
                conn.sendall(wire)
            except socket.error:
                # some of the tests are going to close
                # the connection on us, just deal with it
                break

    @classmethod
    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False):
        """Serve DNS over TCP on 127.0.0.1:``port`` forever (thread target).

        trailingDataResponse=True means "ignore trailing data". Other values are
        either False (meaning "raise an exception") or are interpreted as a
        response RCODE for queries with trailing data. With multipleResponses,
        every response still queued is sent back on the same connection.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the TCP responder: %s" % str(e))
            sock.close()
            sys.exit(1)

        sock.listen(100)
        try:
            while True:
                (conn, _) = sock.accept()
                conn.settimeout(5.0)
                try:
                    cls._handleTCPQuery(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses)
                except socket.error as e:
                    # A client that stalls past the timeout or goes away must not
                    # kill the responder thread for the remaining tests.
                    print("Network error in the TCP responder: %s" % str(e))
                finally:
                    conn.close()
        finally:
            sock.close()

    @classmethod
    def sendUDPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send ``query`` to dnsdist over UDP and wait for the answer.

        ``response`` is queued for the responder when ``useQueue`` is True.
        ``rawQuery`` sends ``query`` as-is (bytes) instead of calling to_wire().
        A falsy ``timeout`` keeps the socket's current timeout; otherwise it is
        applied for this exchange and the socket is then reset to blocking.

        Returns:
            (receivedQuery, message): the query seen by the responder (None if
            none was queued) and the parsed answer (None on timeout).
        """
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        if timeout:
            cls._sock.settimeout(timeout)

        try:
            if not rawQuery:
                query = query.to_wire()
            cls._sock.send(query)
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        receivedQuery = None
        message = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        if data:
            message = dns.message.from_wire(data)
        return (receivedQuery, message)

    @classmethod
    def openTCPConnection(cls, timeout=None):
        """Return a TCP socket connected to dnsdist; the socket is closed if connecting fails."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        try:
            sock.connect(("127.0.0.1", cls._dnsDistPort))
        except Exception:
            sock.close()
            raise
        return sock

    @classmethod
    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
        """Return a TLS socket to 127.0.0.1:``port``, verifying against ``caCert``.

        The socket is closed if connecting or the handshake fails.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        # ssl.create_default_context() exists from Python 2.7.9 / 3.4; the
        # ssl.wrap_socket() branch is only a fallback for older interpreters.
        if hasattr(ssl, 'create_default_context'):
            sslctx = ssl.create_default_context(cafile=caCert)
            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
        else:
            sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)

        try:
            sslsock.connect(("127.0.0.1", port))
        except Exception:
            sslsock.close()
            raise
        return sslsock

    @classmethod
    def sendTCPQueryOverConnection(cls, sock, query, rawQuery=False, response=None, timeout=2.0):
        """Send a length-prefixed ``query`` on ``sock``, queuing ``response`` first if given."""
        if not rawQuery:
            wire = query.to_wire()
        else:
            wire = query

        if response:
            cls._toResponderQueue.put(response, True, timeout)

        # sendall() retries short writes that send() would silently drop.
        sock.sendall(struct.pack("!H", len(wire)))
        sock.sendall(wire)

    @classmethod
    def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
        """Read one length-prefixed DNS message from ``sock``.

        Returns:
            The parsed message (None on EOF) when ``useQueue`` is False;
            otherwise (receivedQuery, message), with receivedQuery None when
            the responder has not queued a query.
        """
        message = None
        data = cls._recvExactly(sock, 2)
        if data:
            (datalen,) = struct.unpack("!H", data)
            data = cls._recvExactly(sock, datalen)
            if data:
                message = dns.message.from_wire(data)

        if useQueue:
            receivedQuery = None
            if not cls._fromResponderQueue.empty():
                receivedQuery = cls._fromResponderQueue.get(True, timeout)
            return (receivedQuery, message)
        return message

    @classmethod
    def sendTCPQuery(cls, query, response, useQueue=True, timeout=2.0, rawQuery=False):
        """Send ``query`` over a fresh TCP connection and read one answer.

        Timeouts and network errors after connecting are printed, not raised.

        Returns:
            (receivedQuery, message), either of which may be None.
        """
        message = None
        if useQueue:
            cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            message = cls.recvTCPResponseOverConnection(sock)
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)

        return (receivedQuery, message)

    @classmethod
    def sendTCPQueryWithMultipleResponses(cls, query, responses, useQueue=True, timeout=2.0, rawQuery=False):
        """Send ``query`` over TCP and collect every answer until EOF or timeout.

        Returns:
            (receivedQuery, messages): the query seen by the responder (or
            None) and the list of parsed answers, in arrival order.
        """
        if useQueue:
            for response in responses:
                cls._toResponderQueue.put(response, True, timeout)

        sock = cls.openTCPConnection(timeout)
        messages = []

        try:
            cls.sendTCPQueryOverConnection(sock, query, rawQuery)
            while True:
                data = cls._recvExactly(sock, 2)
                if not data:
                    break
                (datalen,) = struct.unpack("!H", data)
                data = cls._recvExactly(sock, datalen)
                messages.append(dns.message.from_wire(data))

        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
        except socket.error as e:
            print("Network error: %s" % (str(e)))
        finally:
            sock.close()

        receivedQuery = None
        if useQueue and not cls._fromResponderQueue.empty():
            receivedQuery = cls._fromResponderQueue.get(True, timeout)
        return (receivedQuery, messages)

    def setUp(self):
        """Reset counters and queues before every test."""
        # Clear the responses counters
        for key in self._responsesCounter:
            self._responsesCounter[key] = 0

        # Reset the counter on the class, which is what the responder threads
        # increment; assigning through self would only create a shadowing
        # instance attribute and leave the real counter untouched.
        type(self)._healthCheckCounter = 0

        # Make sure the queues are empty, in case
        # a previous test failed
        self.clearResponderQueues()

    @staticmethod
    def _drainQueue(pending):
        """Discard every item currently in ``pending`` without blocking."""
        # get_nowait() + Empty avoids the race between empty() and get(False)
        # when a responder thread takes the last item in between.
        while True:
            try:
                pending.get_nowait()
            except Empty:
                return

    @classmethod
    def clearToResponderQueue(cls):
        """Discard responses queued for the responders."""
        cls._drainQueue(cls._toResponderQueue)

    @classmethod
    def clearFromResponderQueue(cls):
        """Discard queries queued by the responders."""
        cls._drainQueue(cls._fromResponderQueue)

    @classmethod
    def clearResponderQueues(cls):
        """Empty both responder queues."""
        cls.clearToResponderQueue()
        cls.clearFromResponderQueue()

    @staticmethod
    def generateConsoleKey():
        """Return a fresh libnacl secretbox key for the dnsdist console."""
        return libnacl.utils.salsa_key()

    @classmethod
    def _encryptConsole(cls, command, nonce):
        """UTF-8 encode ``command`` and encrypt it unless no console key is set."""
        command = command.encode('UTF-8')
        if cls._consoleKey is None:
            return command
        return libnacl.crypto_secretbox(command, nonce, cls._consoleKey)

    @classmethod
    def _decryptConsole(cls, command, nonce):
        """Decrypt ``command`` (unless no console key is set) and UTF-8 decode it."""
        if cls._consoleKey is None:
            result = command
        else:
            result = libnacl.crypto_secretbox_open(command, nonce, cls._consoleKey)
        return result.decode('UTF-8')

    @classmethod
    def sendConsoleCommand(cls, command, timeout=1.0):
        """Run ``command`` on the dnsdist console and return its decoded output.

        Nonces are exchanged first. The read/write nonces combine halves of
        ours and theirs, and each message is prefixed with a 32-bit length.

        Returns:
            The decoded response, or None if the server sent a short nonce.

        Raises:
            socket.error: on EOF while reading the nonce or the response size.
        """
        ourNonce = libnacl.utils.rand_nonce()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        if timeout:
            sock.settimeout(timeout)

        # try/finally: the original never closed this socket.
        try:
            sock.connect(("127.0.0.1", cls._consolePort))
            sock.sendall(ourNonce)
            theirNonce = cls._recvExactly(sock, len(ourNonce))
            if len(theirNonce) != len(ourNonce):
                print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce), len(ourNonce)))
                if len(theirNonce) == 0:
                    raise socket.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce)))
                return None

            halfNonceSize = int(len(ourNonce) / 2)
            readingNonce = ourNonce[0:halfNonceSize] + theirNonce[halfNonceSize:]
            writingNonce = theirNonce[0:halfNonceSize] + ourNonce[halfNonceSize:]
            msg = cls._encryptConsole(command, writingNonce)
            sock.sendall(struct.pack("!I", len(msg)))
            sock.sendall(msg)
            data = cls._recvExactly(sock, 4)
            if not data:
                raise socket.error("Got EOF while reading the response size")

            (responseLen,) = struct.unpack("!I", data)
            data = cls._recvExactly(sock, responseLen)
            return cls._decryptConsole(data, readingNonce)
        finally:
            sock.close()

    def compareOptions(self, a, b):
        """Assert two option lists have the same length and equal elements in order."""
        self.assertEqual(len(a), len(b))
        for left, right in zip(a, b):
            self.assertEqual(left, right)

    def checkMessageNoEDNS(self, expected, received):
        """Assert ``received`` equals ``expected`` and carries no EDNS at all."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, -1)
        self.assertEqual(len(received.options), 0)

    def checkMessageEDNSWithoutOptions(self, expected, received):
        """Assert ``received`` equals ``expected`` and uses EDNS version 0."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)

    def checkMessageEDNSWithoutECS(self, expected, received, withCookies=0):
        """Assert EDNS0 with exactly ``withCookies`` options, all of type 10 (cookie)."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), withCookies)
        if withCookies:
            for option in received.options:
                self.assertEqual(option.otype, 10)

    def checkMessageEDNSWithECS(self, expected, received, additionalOptions=0):
        """Assert EDNS0 with one ECS option plus ``additionalOptions`` others."""
        self.assertEqual(expected, received)
        self.assertEqual(received.edns, 0)
        self.assertEqual(len(received.options), 1 + additionalOptions)
        hasECS = False
        for option in received.options:
            if option.otype == clientsubnetoption.ASSIGNED_OPTION_CODE:
                hasECS = True
            else:
                # A non-ECS option is only allowed when extra options are expected.
                self.assertNotEqual(additionalOptions, 0)

        self.compareOptions(expected.options, received.options)
        self.assertTrue(hasECS)

    def checkQueryEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkResponseEDNSWithECS(self, expected, received, additionalOptions=0):
        self.checkMessageEDNSWithECS(expected, received, additionalOptions)

    def checkQueryEDNSWithoutECS(self, expected, received):
        self.checkMessageEDNSWithoutECS(expected, received)

    def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
        self.checkMessageEDNSWithoutECS(expected, received, withCookies)

    def checkQueryNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)

    def checkResponseNoEDNS(self, expected, received):
        self.checkMessageNoEDNS(expected, received)