13 import clientsubnetoption
# Python 2/3 shim: the `Queue` class is imported from the `Queue` module on
# Python 2 and from the `queue` module on Python 3.
# NOTE(review): original lines 22-23 are not visible here; presumably a
# Python 2 counterpart of line 25 and the `else:` introducing the Python 3
# branch — TODO confirm against the full file.
19 # Python2/3 compatibility hacks
20 if sys
.version_info
[0] == 2:
21 from Queue
import Queue
24 from queue
import Queue
25 range = range # allow re-export of the builtin name
# DNSDistTest: base unittest class for the dnsdist regression tests.
# Class-level settings read by the helpers of this class:
#   _dnsDistListeningAddr / _dnsDistPort: address passed to dnsdist with -l
#   _testServerPort: port the UDP/TCP responder threads bind on 127.0.0.1
#   _toResponderQueue: responses queued by the tests for the responders
#   _fromResponderQueue: queries the responders received from dnsdist
#   _dnsdistStartupDelay: startup delay read by startDNSDist
#   _responsesCounter: number of answered queries, keyed by thread name
#   _config_template / _config_params: dnsdist configuration template,
#     %-formatted with the values of the attributes named in _config_params
#   _acl: entries passed to dnsdist with --acl (presumably one per entry)
#   _healthCheckName: queries for names ending with it are treated as health
#     checks by _getResponse and counted in _healthCheckCounter
#   _answerUnexpected: lets _getResponse answer SERVFAIL when it has no other
#     response to send (branch conditions only partly visible)
# NOTE(review): _dnsDistPort, _queueTimeout, _shutUp, _consolePort and
# _consoleKey are used by the methods below but their definitions, the
# docstring quotes and the end of _config_template are not visible here.
28 class DNSDistTest(unittest
.TestCase
):
30 Set up a dnsdist instance and responder threads.
31 Queries sent to dnsdist are relayed to the responder threads,
32 which reply with the response provided by the tests themselves
33 on a queue. Responder threads also queue the queries received
34 from dnsdist on a separate queue, allowing the tests to check
35 that the queries sent from dnsdist were as expected.
38 _dnsDistListeningAddr
= "127.0.0.1"
39 _testServerPort
= 5350
40 _toResponderQueue
= Queue()
41 _fromResponderQueue
= Queue()
43 _dnsdistStartupDelay
= 2.0
45 _responsesCounter
= {}
47 _config_template
= """
49 _config_params
= ['_testServerPort']
50 _acl
= ['127.0.0.1/32']
53 _healthCheckName
= 'a.root-servers.net.'
54 _healthCheckCounter
= 0
55 _answerUnexpected
= True
def startResponders(cls):
    """Launch the UDP and TCP responder threads that dnsdist forwards to.

    Both threads serve ``cls._testServerPort`` and share the same queues:
    the responses to send are taken from ``cls._toResponderQueue`` and the
    queries received from dnsdist are pushed to ``cls._fromResponderQueue``.
    The threads are stored as ``cls._UDPResponder`` and ``cls._TCPResponder``.
    """
    print("Launching responders..")

    # Both responders take (port, fromQueue, toQueue); the tuple is read-only.
    responderArgs = (cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue)

    # Daemon threads do not keep the interpreter alive after the tests.
    # The ``daemon`` attribute works on Python 2.6+ and 3.x, whereas
    # ``Thread.setDaemon()`` is deprecated since Python 3.10.
    cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=responderArgs)
    cls._UDPResponder.daemon = True
    cls._UDPResponder.start()

    cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=responderArgs)
    cls._TCPResponder.daemon = True
    cls._TCPResponder.start()
# startDNSDist: write the dnsdist configuration, validate it, then launch dnsdist.
#   - Writes 'dnsdist_test.conf' from cls._config_template, %-formatted with
#     the values of the attributes named in cls._config_params.
#   - Builds the command line from $DNSDISTBIN, the config file, the
#     listening address (-l host:port) and '--acl' entries.
#   - Runs it once with --check-config and raises AssertionError unless the
#     output is exactly the "Configuration ... OK!" line.
#   - Starts dnsdist with subprocess.Popen; one of the two visible Popen
#     calls discards stdout to os.devnull (presumably chosen by shutUp).
#   - If dnsdist has already exited when polled, calls sys.exit with its
#     return code.
# NOTE(review): several lines are missing from this view (the loop binding
# `acl`, the branch selecting a Popen call, the DNSDIST_FAST_TESTS delay and
# presumably a sleep) — TODO confirm against the full file.
# NOTE(review): the expected --check-config output hard-codes the config
# file name and must stay in sync with `conffile`.
69 def startDNSDist(cls
, shutUp
=True):
70 print("Launching dnsdist..")
71 conffile
= 'dnsdist_test.conf'
72 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
74 with
open(conffile
, 'w') as conf
:
75 conf
.write("-- Autogenerated by dnsdisttests.py\n")
76 conf
.write(cls
._config
_template
% params
)
78 dnsdistcmd
= [os
.environ
['DNSDISTBIN'], '-C', conffile
,
79 '-l', '%s:%d' % (cls
._dnsDistListeningAddr
, cls
._dnsDistPort
) ]
81 dnsdistcmd
.extend(['--acl', acl
])
82 print(' '.join(dnsdistcmd
))
84 # validate config with --check-config, which sets client=true, possibly exposing bugs.
85 testcmd
= dnsdistcmd
+ ['--check-config']
86 output
= subprocess
.check_output(testcmd
, close_fds
=True)
87 if output
!= b
'Configuration \'dnsdist_test.conf\' OK!\n':
88 raise AssertionError('dnsdist --check-config failed: %s' % output
)
91 with
open(os
.devnull
, 'w') as fdDevNull
:
92 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True, stdout
=fdDevNull
)
94 cls
._dnsdist
= subprocess
.Popen(dnsdistcmd
, close_fds
=True)
96 if 'DNSDIST_FAST_TESTS' in os
.environ
:
99 delay
= cls
._dnsdistStartupDelay
103 if cls
._dnsdist
.poll() is not None:
105 sys
.exit(cls
._dnsdist
.returncode
)
def setUpSockets(cls):
    """Open the UDP socket the tests use to send queries to dnsdist.

    The socket is published as ``cls._sock`` immediately. It is then given
    a 2 second timeout and connected to 127.0.0.1 on ``cls._dnsDistPort``.
    """
    print("Setting up UDP socket..")
    cls._sock = udpSocket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    udpSocket.settimeout(2.0)
    destination = ("127.0.0.1", cls._dnsDistPort)
    udpSocket.connect(destination)
# NOTE(review): body of a method whose `def` line is not visible here
# (presumably setUpClass — TODO confirm). It starts the responder threads,
# then launches dnsdist, passing cls._shutUp as the shutUp flag.
117 cls
.startResponders()
118 cls
.startDNSDist(cls
._shutUp
)
121 print("Launching tests..")
# tearDownClass: stop the dnsdist process launched for this test class.
# Visible steps: a DNSDIST_FAST_TESTS check, terminate() on dnsdist, then
# two poll() checks for whether it is still running.
# NOTE(review): the bodies of the DNSDIST_FAST_TESTS branch and of both
# `poll() is None` branches are not visible (presumably a wait and a
# kill escalation) — TODO confirm.
124 def tearDownClass(cls
):
125 if 'DNSDIST_FAST_TESTS' in os
.environ
:
130 cls
._dnsdist
.terminate()
131 if cls
._dnsdist
.poll() is None:
133 if cls
._dnsdist
.poll() is None:
# _ResponderIncrementCounter: count one more response for the calling thread
# in cls._responsesCounter, keyed by the thread's name (the responder threads
# are named 'UDP Responder' and 'TCP Responder' by startResponders).
# NOTE(review): original line 141 is missing; presumably an `else:` so that a
# thread name seen for the first time starts at 1 — TODO confirm.
# NOTE(review): threading.currentThread() is a deprecated alias (since
# Python 3.10) of threading.current_thread(), available since Python 2.6.
138 def _ResponderIncrementCounter(cls
):
139 if threading
.currentThread().name
in cls
._responsesCounter
:
140 cls
._responsesCounter
[threading
.currentThread().name
] += 1
142 cls
._responsesCounter
[threading
.currentThread().name
] = 1
# _getResponse: choose the response a responder thread sends for `request`
# (a DNS message parsed by the caller with dns.message.from_wire).
#   - Queries whose question count is not 1 are reported with a print.
#   - A query whose first question name ends with cls._healthCheckName is
#     flagged as a health check, and cls._healthCheckCounter is incremented
#     (presumably only for those — the guarding line is not visible).
#   - cls._ResponderIncrementCounter() is called.
#   - If the test queued a response on `fromQueue`, a shallow copy is used
#     with its id set to the request's id, so the queued object is not
#     modified. The request is put on `toQueue` so the test can inspect what
#     dnsdist sent.
#   - Otherwise a plain response is built with make_response (presumably for
#     health checks). When cls._answerUnexpected is set, a SERVFAIL response
#     is built instead.
# NOTE(review): several branch headers, the early exit after the print and
# the return statement are not visible here; the flow above is partly
# inferred — TODO confirm.
145 def _getResponse(cls
, request
, fromQueue
, toQueue
):
147 if len(request
.question
) != 1:
148 print("Skipping query with question count %d" % (len(request
.question
)))
150 healthCheck
= str(request
.question
[0].name
).endswith(cls
._healthCheckName
)
152 cls
._healthCheckCounter
+= 1
154 cls
._ResponderIncrementCounter
()
155 if not fromQueue
.empty():
156 response
= fromQueue
.get(True, cls
._queueTimeout
)
158 response
= copy
.copy(response
)
159 response
.id = request
.id
160 toQueue
.put(request
, True, cls
._queueTimeout
)
164 response
= dns
.message
.make_response(request
)
165 elif cls
._answerUnexpected
:
166 response
= dns
.message
.make_response(request
)
167 response
.set_rcode(dns
.rcode
.SERVFAIL
)
# UDPResponder: thread body of the UDP responder.
# Binds a UDP socket to 127.0.0.1:`port` with SO_REUSEPORT, reads
# datagrams of up to 4096 bytes, parses each one with dns.message.from_wire
# (optionally ignoring trailing data) and asks _getResponse for the answer.
# The answer's wire form is sent back to the sender's address.
# NOTE(review): the receive loop, the handling of a missing response and the
# context of settimeout(None) are not visible here — TODO confirm.
# NOTE(review): socket.SO_REUSEPORT is not available on every platform.
172 def UDPResponder(cls
, port
, fromQueue
, toQueue
, ignoreTrailing
=False):
173 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
)
174 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
175 sock
.bind(("127.0.0.1", port
))
177 data
, addr
= sock
.recvfrom(4096)
178 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
179 response
= cls
._getResponse
(request
, fromQueue
, toQueue
)
185 sock
.sendto(response
.to_wire(), addr
)
186 sock
.settimeout(None)
# TCPResponder: thread body of the TCP responder.
# Binds a TCP socket to 127.0.0.1:`port` with SO_REUSEPORT, printing
# socket.error on bind failure. It accepts connections and reads a 2-byte
# big-endian length prefix ("!H") followed by that many bytes of DNS
# message. It asks _getResponse for the answer and sends the same length
# prefix for the reply.
# With multipleResponses set, further responses queued on `fromQueue` are
# sent on the same connection, each a copy with its id set to the
# request's id.
# socket.error while serving a connection is tolerated because some tests
# close the connection on purpose (see the comment at the end).
# NOTE(review): many lines (try:, listen/accept loop, the read of the
# length prefix, the payload sends, conn.close) are not visible here —
# TODO confirm.
# NOTE(review): conn.recv(datalen) may return fewer than datalen bytes on a
# stream socket; a read loop would be needed for large messages.
190 def TCPResponder(cls
, port
, fromQueue
, toQueue
, ignoreTrailing
=False, multipleResponses
=False):
191 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
192 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
194 sock
.bind(("127.0.0.1", port
))
195 except socket
.error
as e
:
196 print("Error binding in the TCP responder: %s" % str(e
))
201 (conn
, _
) = sock
.accept()
208 (datalen
,) = struct
.unpack("!H", data
)
209 data
= conn
.recv(datalen
)
210 request
= dns
.message
.from_wire(data
, ignore_trailing
=ignoreTrailing
)
211 response
= cls
._getResponse
(request
, fromQueue
, toQueue
)
217 wire
= response
.to_wire()
218 conn
.send(struct
.pack("!H", len(wire
)))
221 while multipleResponses
:
222 if fromQueue
.empty():
225 response
= fromQueue
.get(True, cls
._queueTimeout
)
229 response
= copy
.copy(response
)
230 response
.id = request
.id
231 wire
= response
.to_wire()
233 conn
.send(struct
.pack("!H", len(wire
)))
235 except socket
.error
as e
:
236 # some of the tests are going to close
237 # the connection on us, just deal with it
# sendUDPQuery: send `query` to dnsdist over cls._sock (UDP). Returns
# (receivedQuery, message): the query the responder recorded (only fetched
# when useQueue is set and one is available) and dnsdist's parsed reply.
#   - `response` is queued for the responder first (presumably only when
#     useQueue is set — the guarding line is not visible).
#   - `timeout` is applied to cls._sock for the exchange and reset to None
#     afterwards. socket.timeout is caught (handler body not visible).
#   - The query is converted with to_wire() unless, presumably, rawQuery is
#     set (guard not visible).
# NOTE(review): the `try:` line, the guards and the initialisation of
# receivedQuery/data are not visible here — TODO confirm.
245 def sendUDPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
247 cls
._toResponderQueue
.put(response
, True, timeout
)
250 cls
._sock
.settimeout(timeout
)
254 query
= query
.to_wire()
255 cls
._sock
.send(query
)
256 data
= cls
._sock
.recv(4096)
257 except socket
.timeout
:
261 cls
._sock
.settimeout(None)
265 if useQueue
and not cls
._fromResponderQueue
.empty():
266 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
268 message
= dns
.message
.from_wire(data
)
269 return (receivedQuery
, message
)
# openTCPConnection: open a TCP connection to dnsdist at
# 127.0.0.1:cls._dnsDistPort, applying `timeout` with settimeout.
# NOTE(review): the line before settimeout (presumably a guard on `timeout`)
# and the return statement (presumably returning `sock`) are not visible
# here — TODO confirm.
272 def openTCPConnection(cls
, timeout
=None):
273 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
275 sock
.settimeout(timeout
)
277 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
# openTLSConnection: open a TLS connection to 127.0.0.1:`port`.
# When ssl.create_default_context exists, the socket is wrapped with a
# default context trusting `caCert`, with `serverName` as server_hostname.
# Otherwise it falls back to ssl.wrap_socket with CERT_REQUIRED and
# ca_certs=caCert. The wrapped socket is then connected.
# NOTE(review): ssl.wrap_socket was deprecated in Python 3.7 and removed in
# 3.12. The fallback only matters on interpreters without
# create_default_context (added in 2.7.9 / 3.4), and it does not verify the
# certificate against serverName.
# NOTE(review): the guard before settimeout, the `else:` and the return
# statement are not visible here — TODO confirm.
281 def openTLSConnection(cls
, port
, serverName
, caCert
=None, timeout
=None):
282 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
284 sock
.settimeout(timeout
)
287 if hasattr(ssl
, 'create_default_context'):
288 sslctx
= ssl
.create_default_context(cafile
=caCert
)
289 sslsock
= sslctx
.wrap_socket(sock
, server_hostname
=serverName
)
291 sslsock
= ssl
.wrap_socket(sock
, ca_certs
=caCert
, cert_reqs
=ssl
.CERT_REQUIRED
)
293 sslsock
.connect(("127.0.0.1", port
))
# sendTCPQueryOverConnection: send `query` on an already connected stream
# socket, framed with a 2-byte big-endian ("!H") length prefix.
# `response`, when given, is presumably queued for the responder first.
# With rawQuery the query is presumably sent as-is rather than via
# to_wire(). Both guards are not visible.
# NOTE(review): the line sending the payload after the prefix is not
# visible here — TODO confirm.
297 def sendTCPQueryOverConnection(cls
, sock
, query
, rawQuery
=False, response
=None, timeout
=2.0):
299 wire
= query
.to_wire()
304 cls
._toResponderQueue
.put(response
, True, timeout
)
306 sock
.send(struct
.pack("!H", len(wire
)))
# recvTCPResponseOverConnection: read one length-prefixed ("!H") DNS message
# from `sock` and parse it. When useQueue is set and the responder recorded
# a query, returns (receivedQuery, message).
# NOTE(review): the read of the 2-byte prefix, EOF handling and the return
# used when useQueue is false (callers such as sendTCPQuery assign the
# result to a single `message`) are not visible here — TODO confirm.
# NOTE(review): sock.recv(datalen) may return fewer than datalen bytes on a
# stream socket; a read loop would be needed for large messages.
310 def recvTCPResponseOverConnection(cls
, sock
, useQueue
=False, timeout
=2.0):
314 (datalen
,) = struct
.unpack("!H", data
)
315 data
= sock
.recv(datalen
)
317 message
= dns
.message
.from_wire(data
)
319 if useQueue
and not cls
._fromResponderQueue
.empty():
320 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
321 return (receivedQuery
, message
)
# sendTCPQuery: send `query` to dnsdist over a new TCP connection and return
# (receivedQuery, message).
# `response` is queued for the responder (presumably only when useQueue is
# set). The connection comes from openTCPConnection(timeout). The query is
# sent with sendTCPQueryOverConnection and the reply read with
# recvTCPResponseOverConnection. Timeouts and socket errors are printed
# rather than raised.
# NOTE(review): the `try:`/guard lines, the socket close and the
# initialisation of receivedQuery/message are not visible — TODO confirm.
326 def sendTCPQuery(cls
, query
, response
, useQueue
=True, timeout
=2.0, rawQuery
=False):
329 cls
._toResponderQueue
.put(response
, True, timeout
)
331 sock
= cls
.openTCPConnection(timeout
)
334 cls
.sendTCPQueryOverConnection(sock
, query
, rawQuery
)
335 message
= cls
.recvTCPResponseOverConnection(sock
)
336 except socket
.timeout
as e
:
337 print("Timeout: %s" % (str(e
)))
338 except socket
.error
as e
:
339 print("Network error: %s" % (str(e
)))
344 if useQueue
and not cls
._fromResponderQueue
.empty():
345 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
347 return (receivedQuery
, message
)
# sendTCPQueryWithMultipleResponses: like sendTCPQuery, but each response in
# `responses` is queued for the responder. Every reply read from the
# connection is parsed into `messages`, and the result is returned as
# (receivedQuery, messages).
# Connects to 127.0.0.1:cls._dnsDistPort with settimeout(timeout) (the guard
# around it is not visible). The query and replies use the 2-byte big-endian
# length prefix. Timeouts and socket errors are printed rather than raised.
# NOTE(review): the queueing guard, the read loop, the initialisation of
# `messages` and the payload send are not visible here — TODO confirm.
350 def sendTCPQueryWithMultipleResponses(cls
, query
, responses
, useQueue
=True, timeout
=2.0, rawQuery
=False):
352 for response
in responses
:
353 cls
._toResponderQueue
.put(response
, True, timeout
)
354 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
356 sock
.settimeout(timeout
)
358 sock
.connect(("127.0.0.1", cls
._dnsDistPort
))
363 wire
= query
.to_wire()
367 sock
.send(struct
.pack("!H", len(wire
)))
373 (datalen
,) = struct
.unpack("!H", data
)
374 data
= sock
.recv(datalen
)
375 messages
.append(dns
.message
.from_wire(data
))
377 except socket
.timeout
as e
:
378 print("Timeout: %s" % (str(e
)))
379 except socket
.error
as e
:
380 print("Network error: %s" % (str(e
)))
385 if useQueue
and not cls
._fromResponderQueue
.empty():
386 receivedQuery
= cls
._fromResponderQueue
.get(True, timeout
)
387 return (receivedQuery
, messages
)
# NOTE(review): body of the per-test setup method; its `def` line is not
# visible here (presumably setUp(self) — TODO confirm).
# Resets every per-thread response counter and the health-check counter to
# 0, then drains both responder queues so leftovers from a failed test do not
# leak into the next one.
# NOTE(review): `self._healthCheckCounter = 0` creates an instance attribute
# that shadows the class attribute incremented by the responder threads
# (`cls._healthCheckCounter += 1` in _getResponse). Reading
# self._healthCheckCounter afterwards will not see those increments;
# resetting type(self)._healthCheckCounter may be what was intended.
# NOTE(review): the responder threads also take items from _toResponderQueue,
# so `get(False)` after `empty()` can raise Empty if an item is consumed in
# between; draining until Empty is raised avoids that race.
390 # This function is called before every test
392 # Clear the responses counters
393 for key
in self
._responsesCounter
:
394 self
._responsesCounter
[key
] = 0
396 self
._healthCheckCounter
= 0
398 # Make sure the queues are empty, in case
399 # a previous test failed
400 while not self
._toResponderQueue
.empty():
401 self
._toResponderQueue
.get(False)
403 while not self
._fromResponderQueue
.empty():
404 self
._fromResponderQueue
.get(False)
def clearToResponderQueue(cls):
    """Discard every response still queued for the responder threads.

    The queue is drained with ``get_nowait()`` until ``Empty`` is raised.
    Checking ``empty()`` before ``get(False)`` is racy here because the
    responder threads take items from this queue concurrently: an item can
    vanish between the two calls, and ``get(False)`` then raises ``Empty``.
    """
    try:
        from queue import Empty
    except ImportError:  # Python 2
        from Queue import Empty

    while True:
        try:
            cls._toResponderQueue.get_nowait()
        except Empty:
            break
def clearFromResponderQueue(cls):
    """Discard every query the responder threads recorded from dnsdist.

    The queue is drained with ``get_nowait()`` until ``Empty`` is raised,
    rather than checking ``empty()`` first. The loop therefore cannot fail
    with ``Empty`` when another consumer takes an item between the check
    and the ``get``.
    """
    try:
        from queue import Empty
    except ImportError:  # Python 2
        from Queue import Empty

    while True:
        try:
            cls._fromResponderQueue.get_nowait()
        except Empty:
            break
def clearResponderQueues(cls):
    """Empty both responder queues.

    The queue of responses pending for the responders is cleared first,
    then the queue of queries they received.
    """
    for drain in (cls.clearToResponderQueue, cls.clearFromResponderQueue):
        drain()
def generateConsoleKey():
    """Return a new key for the dnsdist console.

    The key is produced by ``libnacl.utils.salsa_key()``.
    """
    key = libnacl.utils.salsa_key()
    return key
# _encryptConsole: encode `command` as UTF-8 and encrypt it with
# libnacl.crypto_secretbox, using `nonce` and cls._consoleKey.
# NOTE(review): the statement run when cls._consoleKey is None (original
# line 429) is not visible here; presumably it returns the encoded command
# unencrypted — TODO confirm.
426 def _encryptConsole(cls
, command
, nonce
):
427 command
= command
.encode('UTF-8')
428 if cls
._consoleKey
is None:
430 return libnacl
.crypto_secretbox(command
, nonce
, cls
._consoleKey
)
# _decryptConsole: decrypt `command` with libnacl.crypto_secretbox_open,
# using `nonce` and cls._consoleKey, and return the plaintext decoded as
# UTF-8.
# NOTE(review): the handling of cls._consoleKey being None (original lines
# 435-436) is not visible here — TODO confirm.
433 def _decryptConsole(cls
, command
, nonce
):
434 if cls
._consoleKey
is None:
437 result
= libnacl
.crypto_secretbox_open(command
, nonce
, cls
._consoleKey
)
438 return result
.decode('UTF-8')
# sendConsoleCommand: send `command` to the dnsdist console at
# 127.0.0.1:cls._consolePort and decrypt the reply.
# Exchange as visible here:
#   1. Generate our nonce with libnacl.utils.rand_nonce() and connect.
#      Read the server's nonce (same length as ours); an empty read raises
#      socket.error.
#   2. The reading nonce is the first half of ours plus the second half of
#      theirs. The writing nonce is the first half of theirs plus the second
#      half of ours.
#   3. The command is encrypted with the writing nonce and sent with a 4-byte
#      big-endian ("!I") length prefix. The reply uses the same framing and is
#      decrypted with the reading nonce.
# NOTE(review): the timeout guard, the rest of the nonce-size error branch,
# the payload send, the read of the response size and the return statement
# are not visible here — TODO confirm.
441 def sendConsoleCommand(cls
, command
, timeout
=1.0):
442 ourNonce
= libnacl
.utils
.rand_nonce()
444 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
446 sock
.settimeout(timeout
)
448 sock
.connect(("127.0.0.1", cls
._consolePort
))
450 theirNonce
= sock
.recv(len(ourNonce
))
451 if len(theirNonce
) != len(ourNonce
):
452 print("Received a nonce of size %d, expecting %d, console command will not be sent!" % (len(theirNonce
), len(ourNonce
)))
453 if len(theirNonce
) == 0:
454 raise socket
.error("Got EOF while reading a nonce of size %d, console command will not be sent!" % (len(ourNonce
)))
457 halfNonceSize
= int(len(ourNonce
) / 2)
458 readingNonce
= ourNonce
[0:halfNonceSize
] + theirNonce
[halfNonceSize
:]
459 writingNonce
= theirNonce
[0:halfNonceSize
] + ourNonce
[halfNonceSize
:]
460 msg
= cls
._encryptConsole
(command
, writingNonce
)
461 sock
.send(struct
.pack("!I", len(msg
)))
465 raise socket
.error("Got EOF while reading the response size")
467 (responseLen
,) = struct
.unpack("!I", data
)
468 data
= sock
.recv(responseLen
)
469 response
= cls
._decryptConsole
(data
, readingNonce
)
def compareOptions(self, a, b):
    """Assert that two lists of EDNS options are equal, element by element.

    The lengths are compared first so a missing or extra option is reported
    as such. Then the options at each position are compared pairwise.

    Uses ``assertEqual``: the ``assertEquals`` alias was deprecated and has
    been removed from ``unittest.TestCase`` in Python 3.12.
    """
    self.assertEqual(len(a), len(b))
    # Lengths are equal at this point, so zip() visits every pair.
    for optionA, optionB in zip(a, b):
        self.assertEqual(optionA, optionB)
def checkMessageNoEDNS(self, expected, received):
    """Assert that ``received`` equals ``expected`` and carries no EDNS.

    ``received.edns`` must be -1 and ``received.options`` must be empty.
    Uses ``assertEqual``, since the ``assertEquals`` alias no longer exists
    in Python 3.12+.
    """
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, -1)
    self.assertEqual(len(received.options), 0)
# checkMessageEDNSWithoutECS: assert that `received` equals `expected`, has
# edns == 0 and carries exactly `withCookies` options, each with option code
# 10 (the EDNS COOKIE option code, RFC 7873).
# NOTE(review): original line 486 is not visible here (presumably a guard
# such as `if withCookies:`) — TODO confirm.
# NOTE(review): assertEquals was removed from unittest.TestCase in Python
# 3.12; assertEqual is the supported spelling.
482 def checkMessageEDNSWithoutECS(self
, expected
, received
, withCookies
=0):
483 self
.assertEquals(expected
, received
)
484 self
.assertEquals(received
.edns
, 0)
485 self
.assertEquals(len(received
.options
), withCookies
)
487 for option
in received
.options
:
488 self
.assertEquals(option
.otype
, 10)
def checkMessageEDNSWithECS(self, expected, received):
    """Assert that ``received`` equals ``expected`` and carries exactly one
    EDNS Client Subnet option.

    Checks, in order:
    - ``received.edns`` is 0;
    - there is exactly one option;
    - its code is ``clientsubnetoption.ASSIGNED_OPTION_CODE``;
    - the options match ``expected.options`` (via ``compareOptions``).

    Uses ``assertEqual``, since the ``assertEquals`` alias no longer exists
    in Python 3.12+.
    """
    self.assertEqual(expected, received)
    self.assertEqual(received.edns, 0)
    self.assertEqual(len(received.options), 1)
    self.assertEqual(received.options[0].otype, clientsubnetoption.ASSIGNED_OPTION_CODE)
    self.compareOptions(expected.options, received.options)
def checkQueryEDNSWithECS(self, expected, received):
    """Check a received query against ``expected`` for EDNS with one ECS option.

    Delegates to ``checkMessageEDNSWithECS``.
    """
    verify = self.checkMessageEDNSWithECS
    verify(expected, received)
def checkResponseEDNSWithECS(self, expected, received):
    """Check a received response against ``expected`` for EDNS with one ECS option.

    Delegates to ``checkMessageEDNSWithECS``.
    """
    verify = self.checkMessageEDNSWithECS
    verify(expected, received)
def checkQueryEDNSWithoutECS(self, expected, received):
    """Check a received query against ``expected`` for EDNS without ECS.

    Delegates to ``checkMessageEDNSWithoutECS`` with its default cookie count.
    """
    verify = self.checkMessageEDNSWithoutECS
    verify(expected, received)
def checkResponseEDNSWithoutECS(self, expected, received, withCookies=0):
    """Check a received response against ``expected`` for EDNS without ECS.

    ``withCookies`` is forwarded positionally to ``checkMessageEDNSWithoutECS``
    as the expected number of options.
    """
    verify = self.checkMessageEDNSWithoutECS
    verify(expected, received, withCookies)
def checkQueryNoEDNS(self, expected, received):
    """Check a received query against ``expected`` for the absence of EDNS.

    Delegates to ``checkMessageNoEDNS``.
    """
    verify = self.checkMessageNoEDNS
    verify(expected, received)
def checkResponseNoEDNS(self, expected, received):
    """Check a received response against ``expected`` for the absence of EDNS.

    Delegates to ``checkMessageNoEDNS``.
    """
    verify = self.checkMessageNoEDNS
    verify(expected, received)